import torch
from typing import TypedDict, Callable, Any
import genesis as gs
from genesis_forge.genesis_env import GenesisEnv
from genesis_forge.managers.base import BaseManager
from genesis_forge.managers.config import TerminationConfigItem
class TerminationConfig(TypedDict):
"""Defines a termination condition."""
fn: Callable[[GenesisEnv, ...], torch.Tensor]
"""Function that will be called to calculate a termination signal for the environment."""
params: dict[str, Any]
"""Additional parameters to pass to the function."""
time_out: bool
"""Set to True if a positive result is a time out and not a termination."""
[docs]
class TerminationManager(BaseManager):
"""
Handles calculating and logging the "dones" (termination or truncation) for the environments.
This works with a dictionary configuration of termination conditions. For each dictionary item,
a function will be called to calculate a termination signal for the environment.
Args:
env: The environment to manage the termination for.
term_cfg: A dictionary of termination conditions.
logging_enabled: Whether to log the termination signals to tensorboard.
logging_tag: The section tag used to log the termination signals to tensorboard.
Example with ManagedEnvironment::
class MyEnv(ManagedEnvironment):
def config(self):
self.termination_manager = TerminationManager(
self,
term_cfg={
"Min Height": {
"fn": mdp.terminations.min_height,
"params": {"min_height": 0.05},
},
},
)
Example using the termination manager directly::
class MyEnv(GenesisEnv):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.termination_manager = TerminationManager(
self,
term_cfg={
"Min Height": {
"fn": mdp.terminations.min_height,
"params": {"min_height": 0.5},
},
"Rolled over": {
"fn": mdp.terminations.max_angle,
"params": { "quat_threshold": 0.35 },
},
},
)
def build(self):
super().build()
self.termination_manager.build()
def step(self, actions: torch.Tensor):
super().step(actions)
# ...handle actions...
# Calculate dones (terminated or truncated)
terminated, truncated = self.termination_manager.step()
dones = terminated | truncated
reset_env_idx = dones.nonzero(as_tuple=False).reshape((-1,))
# Reset environments
if reset_env_idx.numel() > 0:
self.reset(reset_env_idx)
return obs, rewards, terminated, truncated, info
def reset(self, envs_idx: Sequence[int] = None):
super().reset(envs_idx)
# ...do reset logic here...x
self.termination_manager.reset(envs_idx)
return obs, info
"""
def __init__(
self,
env: GenesisEnv,
term_cfg: dict[str, TerminationConfig],
logging_enabled: bool = True,
logging_tag: str = "Terminations",
):
super().__init__(env, type="termination")
self.term_cfg = term_cfg
self.logging_enabled = logging_enabled
self.logging_tag = logging_tag
# Wrap config items
self.term_cfg: dict[str, TerminationConfigItem] = {}
for name, cfg in term_cfg.items():
self.term_cfg[name] = TerminationConfigItem(cfg, env)
# Buffers
self._terminated_buf = torch.zeros(
env.num_envs, device=gs.device, dtype=torch.bool
)
self._truncated_buf = torch.zeros_like(self._terminated_buf)
"""
Properties
"""
@property
def dones(self) -> torch.Tensor:
"""The termination signals for the environments. Shape is (num_envs,)."""
return self._terminated_buf | self._truncated_buf
@property
def terminated(self) -> torch.Tensor:
"""The termination signals for the environments. Shape is (num_envs,)."""
return self._terminated_buf
@property
def truncated(self) -> torch.Tensor:
"""The truncation signals for the environments. Shape is (num_envs,)."""
return self._truncated_buf
"""
Operations
"""
[docs]
def build(self):
"""
Build any config item function classes.
"""
for cfg in self.term_cfg.values():
cfg.build()
[docs]
def step(self) -> tuple[torch.Tensor, torch.Tensor]:
"""
Calculate the termination/truncation signals for this step
Returns:
terminated - The termination signals for the environments. Shape is (num_envs,).
truncated - The truncation signals for the environments. Shape is (num_envs,).
"""
if not self.enabled:
return self._terminated_buf, self._truncated_buf
self._terminated_buf[:] = False
self._truncated_buf[:] = False
logging_dict = self.env.extras[self.env.extras_logging_key]
for name, term_item in self.term_cfg.items():
try:
# Get termination value
params = term_item.params
value = term_item.fn(self.env, **params)
# Add to the correct buffer using in-place operations
if term_item.time_out:
self._truncated_buf |= value
else:
self._terminated_buf |= value
# Logging, if there are some terminations and timeouts
dones = value.nonzero(as_tuple=True)[0]
if self.logging_enabled and dones.numel() > 0:
logging_dict[f"{self.logging_tag} / {name}"] = (
value.float().mean().detach()
)
except Exception as e:
print(f"Error calculating termination for '{name}'")
raise e
self.env.extras["terminations"] = self._terminated_buf
self.env.extras["time_outs"] = self._truncated_buf
return self._terminated_buf, self._truncated_buf