Termination

class genesis_forge.managers.TerminationManager(env: GenesisEnv, term_cfg: dict[str, TerminationConfig], logging_enabled: bool = True, logging_tag: str = 'Terminations')

Bases: BaseManager

Calculates and logs the “dones” (terminations or truncations) for the environments.

The manager is configured with a dictionary of termination conditions. For each entry, the function given by its "fn" key is called with the entry's "params" to calculate a termination signal for every environment.
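The evaluation loop can be pictured with a minimal pure-Python sketch. It assumes each condition is called as fn(env, **params) and returns one boolean per environment, and that an environment is done as soon as any condition fires. Plain lists stand in for the torch tensors the real manager uses; evaluate_terminations, below_height and above_height are hypothetical names, not part of genesis_forge.

```python
from typing import Any, Callable


# Hypothetical helper, not the genesis_forge API: evaluate every configured
# condition and OR the results into one flag per environment.
def evaluate_terminations(
    env: Any, term_cfg: dict[str, dict[str, Any]], num_envs: int
) -> tuple[list[bool], dict[str, list[bool]]]:
    combined = [False] * num_envs
    per_term: dict[str, list[bool]] = {}
    for name, cfg in term_cfg.items():
        fn: Callable[..., list[bool]] = cfg["fn"]
        params = cfg.get("params", {})
        signal = fn(env, **params)  # assumed calling convention: fn(env, **params)
        per_term[name] = signal
        combined = [c or s for c, s in zip(combined, signal)]
    return combined, per_term


# Toy conditions reading from a dict that stands in for the environment.
def below_height(env, min_height):
    return [h < min_height for h in env["heights"]]


def above_height(env, max_height):
    return [h > max_height for h in env["heights"]]


env = {"heights": [0.9, 0.02, 0.4]}
term_cfg = {
    "Min Height": {"fn": below_height, "params": {"min_height": 0.05}},
    "Max Height": {"fn": above_height, "params": {"max_height": 0.8}},
}
terminated, per_term = evaluate_terminations(env, term_cfg, num_envs=3)
print(terminated)  # [True, True, False]
```

Keeping the per-condition results around, and not only the combined flag, is what makes per-condition logging possible.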

Parameters:
  • env – The environment to manage the termination for.

  • term_cfg – A dictionary of termination conditions.

  • logging_enabled – Whether to log the termination signals to TensorBoard.

  • logging_tag – The section tag used to log the termination signals to TensorBoard.
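The following sketch shows one way per-condition signals could become scalar metrics grouped under logging_tag. It assumes that the logged value is the fraction of environments for which each condition fired and that keys take the form "<tag>/<condition name>". Both are illustrative assumptions, and termination_metrics is a hypothetical helper.

```python
# Hypothetical sketch: turn per-condition signals into scalar metrics keyed
# under the logging tag. The "<tag>/<name>" key format is an assumption.
def termination_metrics(
    per_term: dict[str, list[bool]], logging_tag: str = "Terminations"
) -> dict[str, float]:
    metrics = {}
    for name, signal in per_term.items():
        # Fraction of environments for which this condition fired this step.
        metrics[f"{logging_tag}/{name}"] = sum(signal) / len(signal)
    return metrics


per_term = {
    "Min Height": [False, True, False, True],
    "Rolled over": [False, False, False, True],
}
print(termination_metrics(per_term))
# {'Terminations/Min Height': 0.5, 'Terminations/Rolled over': 0.25}
```

With a real writer, each key/value pair could be passed to torch.utils.tensorboard.SummaryWriter.add_scalar(tag, scalar_value, global_step), so that all conditions appear together in one TensorBoard section.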

Example with ManagedEnvironment:

class MyEnv(ManagedEnvironment):
    def config(self):
        self.termination_manager = TerminationManager(
            self,
            term_cfg={
                "Min Height": {
                    "fn": mdp.terminations.min_height,
                    "params": {"min_height": 0.05},
                },
            },
        )
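A custom condition can presumably be plugged into term_cfg in the same way as the built-in mdp.terminations functions. This sketch assumes the manager calls it as fn(env, **params) and expects one boolean per environment. ToyEnv, its base_pos attribute and out_of_bounds are all hypothetical stand-ins for your own environment's state.

```python
from dataclasses import dataclass, field


@dataclass
class ToyEnv:
    # Hypothetical state: one (x, y, z) base position per environment.
    base_pos: list[tuple[float, float, float]] = field(default_factory=list)


def out_of_bounds(env: ToyEnv, max_xy: float) -> list[bool]:
    """Hypothetical condition: terminate if the base leaves a square arena."""
    return [abs(x) > max_xy or abs(y) > max_xy for x, y, _ in env.base_pos]


# Entry shaped like the "Min Height" entry above.
term_cfg = {"Out of bounds": {"fn": out_of_bounds, "params": {"max_xy": 5.0}}}

env = ToyEnv(base_pos=[(0.0, 1.0, 0.3), (6.2, 0.0, 0.3), (-1.0, -5.5, 0.3)])
cfg = term_cfg["Out of bounds"]
print(cfg["fn"](env, **cfg["params"]))  # [False, True, True]
```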

Example using the termination manager directly:

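from collections.abc import Sequence

import torch
from genesis_forge.managers import TerminationManager

class MyEnv(GenesisEnv):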
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.termination_manager = TerminationManager(
            self,
            term_cfg={
                "Min Height": {
                    "fn": mdp.terminations.min_height,
                    "params": {"min_height": 0.5},
                },
                "Rolled over": {
                    "fn": mdp.terminations.max_angle,
                    "params": {"quat_threshold": 0.35},
                },
            },
        )

    def build(self):
        super().build()
        self.termination_manager.build()

    def step(self, actions: torch.Tensor):
        super().step(actions)
        # ...handle actions, then compute obs, rewards and info...

        # Calculate dones (terminated or truncated)
        terminated, truncated = self.termination_manager.step()
        dones = terminated | truncated
        reset_env_idx = dones.nonzero(as_tuple=False).reshape((-1,))

        # Reset environments
        if reset_env_idx.numel() > 0:
            self.reset(reset_env_idx)

        return obs, rewards, terminated, truncated, info

    def reset(self, envs_idx: Sequence[int] | None = None):
        super().reset(envs_idx)
        # ...do reset logic here, computing obs and info...

        self.termination_manager.reset(envs_idx)
        return obs, info

build()

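Build any termination functions in the configuration that are implemented as classes. Call this from the environment's build() method, as in the example above.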
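The following sketch is one guess at what building a class-based condition could mean. It assumes a class given as "fn" is instantiated once with the environment and its params, and that the resulting instance is then called with only the environment on every step. HeightBelow and build_terms are hypothetical names, and this constructor convention is an assumption, not the genesis_forge contract.

```python
import inspect
from typing import Any


class HeightBelow:
    """Hypothetical class-based condition: parameters are bound once, at build time."""

    def __init__(self, env: Any, min_height: float):
        self.min_height = min_height

    def __call__(self, env: Any) -> list[bool]:
        return [h < self.min_height for h in env["heights"]]


def build_terms(env: Any, term_cfg: dict[str, dict[str, Any]]) -> dict[str, dict[str, Any]]:
    """Instantiate class entries once; plain functions pass through unchanged."""
    built = {}
    for name, cfg in term_cfg.items():
        fn = cfg["fn"]
        if inspect.isclass(fn):
            # Assumed convention: constructor takes (env, **params); the instance
            # is then called with just the env on every step.
            built[name] = {"fn": fn(env, **cfg.get("params", {})), "params": {}}
        else:
            built[name] = cfg
    return built


env = {"heights": [0.3, 0.01]}
built = build_terms(env, {"Min Height": {"fn": HeightBelow, "params": {"min_height": 0.05}}})
term = built["Min Height"]
print(term["fn"](env, **term["params"]))  # [False, True]
```

A class is useful when a condition needs state that should be set up once, such as cached indices or buffers, instead of being recomputed on every call.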

reset(envs_idx: list[int] | None = None)

Notify the manager that one or more environments have been reset. envs_idx holds the indices of the environments that were reset.
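The following toy model shows the kind of per-environment bookkeeping a reset hook typically clears. It assumes the manager keeps per-environment state that should be cleared only for the reset indices, and that envs_idx=None means every environment. ToyDoneBuffer is hypothetical and is not how TerminationManager is implemented.

```python
from collections.abc import Sequence
from typing import Optional


class ToyDoneBuffer:
    """Hypothetical per-environment buffer, not the genesis_forge implementation."""

    def __init__(self, num_envs: int):
        self.terminated = [False] * num_envs
        self.truncated = [False] * num_envs

    def reset(self, envs_idx: Optional[Sequence[int]] = None) -> None:
        # Assumption: None means "every environment was reset".
        idx = range(len(self.terminated)) if envs_idx is None else envs_idx
        for i in idx:
            self.terminated[i] = False
            self.truncated[i] = False


buf = ToyDoneBuffer(num_envs=4)
buf.terminated = [True, False, True, False]
buf.truncated = [False, True, False, False]
buf.reset([0, 1])
print(buf.terminated, buf.truncated)
# [False, False, True, False] [False, False, False, False]
```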

step() → tuple[torch.Tensor, torch.Tensor]

Calculate the termination and truncation signals for the current step.

Returns:
  • terminated – The termination signals for the environments. Shape is (num_envs,).

  • truncated – The truncation signals for the environments. Shape is (num_envs,).
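The two returned signals follow the usual reinforcement-learning split. "Terminated" means a condition marked a genuine end state, and "truncated" means the episode was cut off, for example by a time limit. This pure-Python sketch assumes truncation comes from an episode-length limit, which is an illustrative assumption. split_dones is a hypothetical helper that also mirrors the dones and reset-index steps from the class example.

```python
def split_dones(
    condition_fired: list[bool], episode_length: list[int], max_episode_length: int
) -> tuple[list[bool], list[bool], list[bool], list[int]]:
    # terminated: some termination condition fired (a genuine end state).
    terminated = list(condition_fired)
    # truncated: the episode hit the time limit (assumed source of truncation).
    truncated = [n >= max_episode_length for n in episode_length]
    # dones corresponds to terminated | truncated in the step() example.
    dones = [a or b for a, b in zip(terminated, truncated)]
    # Same role as dones.nonzero(as_tuple=False).reshape((-1,)): indices to reset.
    reset_env_idx = [i for i, d in enumerate(dones) if d]
    return terminated, truncated, dones, reset_env_idx


terminated, truncated, dones, reset_env_idx = split_dones(
    condition_fired=[False, True, False, False],
    episode_length=[10, 50, 1000, 999],
    max_episode_length=1000,
)
print(reset_env_idx)  # [1, 2]
```

Keeping the signals separate matters for learning. Many RL algorithms still bootstrap the value target from the next state after a truncation but not after a termination.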

property dones: torch.Tensor

The done signals (terminated or truncated) for the environments. Shape is (num_envs,).

property terminated: torch.Tensor

The termination signals for the environments. Shape is (num_envs,).

property truncated: torch.Tensor

The truncation signals for the environments. Shape is (num_envs,).