"""
Termination functions for the Genesis environment.
Each of these should return a boolean tensor indicating which environments should terminate, in the tensor shape (num_envs,).
"""
from __future__ import annotations

import math
from typing import Literal

import torch

from genesis_forge.genesis_env import GenesisEnv
from genesis_forge.utils import entity_projected_gravity
from genesis_forge.managers import (
    ActuatorManager,
    ContactManager,
    EntityManager,
    TerrainManager,
)
def timeout(env: GenesisEnv) -> torch.Tensor:
    """
    Terminate the environment if the episode length exceeds the maximum episode length.

    Args:
        env: The Genesis environment. Reads ``env.episode_length`` (a tensor of
            per-environment step counters), ``env.max_episode_length`` and
            ``env.num_envs``.

    Returns:
        torch.Tensor: Boolean tensor of shape (num_envs,), True where
        ``episode_length > max_episode_length``. When ``max_episode_length`` is
        None no environment times out and an all-False tensor is returned.
    """
    if env.max_episode_length is None:
        # Allocate on the same device as the episode counters so this result can
        # be OR-ed with the other termination tensors without a device mismatch.
        return torch.zeros(
            env.num_envs, dtype=torch.bool, device=env.episode_length.device
        )
    return env.episode_length > env.max_episode_length
def bad_orientation(
    env: GenesisEnv,
    limit_angle: float = 40.0,
    entity_attr: str = "robot",
    entity_manager: EntityManager = None,
    grace_steps: int = 0,
) -> torch.Tensor:
    """
    Terminate the environment if the robot is tipping over too much.

    This function uses projected gravity to detect when the robot has tilted
    beyond a safe threshold. When the robot is perfectly upright, projected
    gravity should be [0, 0, -1] in the body frame. As the robot tilts, the
    x,y components grow and the z component moves towards +1.

    The tilt angle is computed as ``atan2(|g_xy|, -g_z)``. It is 0 when upright,
    90 degrees when lying on a side and 180 degrees when fully inverted, so an
    upside-down robot is detected as well.

    Args:
        env: The Genesis environment containing the robot
        limit_angle: Maximum allowed tilt angle in degrees (default: 40 degrees)
        entity_attr: The attribute name of the entity in the environment.
            This isn't necessary if `entity_manager` is provided.
        entity_manager: The entity manager for the entity. When provided, its
            projected gravity is used instead of looking up `entity_attr`.
        grace_steps: Environments whose ``episode_length`` is <= ``grace_steps``
            are never terminated by this check (default: 0, which still skips
            ``episode_length == 0``). This gives the robot a chance to stabilize
            before tilt detection is active.

    Returns:
        torch.Tensor: Boolean tensor indicating which environments should terminate
    """
    in_grace_period = env.episode_length <= grace_steps

    # Get the projected gravity vector in body frame
    if entity_manager is not None:
        projected_gravity = entity_manager.get_projected_gravity()
    else:
        entity = getattr(env, entity_attr)
        projected_gravity = entity_projected_gravity(entity)

    # Horizontal magnitude and downward component of gravity in the body frame.
    tilt_xy = torch.norm(projected_gravity[:, :2], dim=1)
    down = -projected_gravity[:, 2]
    # atan2 spans [0, pi]. asin(|g_xy|) would fold inverted poses back towards 0
    # (|g_xy| ~ 0 when belly-up) and saturates at 90 degrees, so it cannot tell
    # "on its side" from "upside down".
    tilt_angle = torch.atan2(tilt_xy, down)

    # Terminate if tilt angle exceeds the limit
    return (~in_grace_period) & (tilt_angle > math.radians(limit_angle))
def is_upsidedown(
    env: GenesisEnv,
    threshold: float = 0.5,
    entity_attr: str = "robot",
    entity_manager: EntityManager = None,
    grace_steps: int = 0,
) -> torch.Tensor:
    """
    Terminate when the robot is belly-up (inverted).

    Only the z component of gravity projected into the body frame is inspected.
    An upright body sees roughly [0, 0, -1] and a belly-up body roughly
    [0, 0, +1]; side-lying poses keep z below the threshold.

    Args:
        env: The Genesis environment
        threshold: Terminate when projected_gravity[:, 2] exceeds this value
        entity_attr: Entity attribute on `env`, used only when
            `entity_manager` is not provided
        entity_manager: The entity manager for the robot; takes precedence
            over `entity_attr`
        grace_steps: Environments with ``episode_length <= grace_steps`` are
            never terminated by this check

    Returns:
        torch.Tensor: Boolean tensor, True where the z component is above
        `threshold` and the environment is past its grace period.
    """
    past_grace = ~(env.episode_length <= grace_steps)

    if entity_manager is None:
        gravity = entity_projected_gravity(getattr(env, entity_attr))
    else:
        gravity = entity_manager.get_projected_gravity()

    flipped = gravity[:, 2] > threshold
    return past_grace & flipped
def base_height_below_minimum(
    env: GenesisEnv,
    minimum_height: float = 0.05,
    entity_attr: str = "robot",
    entity_manager: EntityManager = None,
) -> torch.Tensor:
    """
    Terminate the environment if the robot's base height falls below a minimum threshold.

    The height is the z component (index 2) of the base position, compared
    strictly: a base exactly at `minimum_height` does not terminate.

    Args:
        env: The Genesis environment containing the robot
        minimum_height: Minimum allowed base height in meters
        entity_attr: The attribute name of the entity in the environment.
            This isn't necessary if `entity_manager` is provided.
        entity_manager: The entity manager for the entity. When given, its
            `base_pos` is used instead of calling `get_pos()` on the entity.

    Returns:
        torch.Tensor: Boolean tensor indicating which environments should terminate
    """
    if entity_manager is None:
        base_pos = getattr(env, entity_attr).get_pos()
    else:
        base_pos = entity_manager.base_pos

    height = base_pos[:, 2]
    return height < minimum_height
def out_of_bounds(
    env: GenesisEnv,
    terrain_manager: TerrainManager,
    subterrain: str | None = None,
    border_margin: float = 0.5,
    entity_attr: str = "robot",
    entity_manager: EntityManager = None,
) -> torch.Tensor:
    """
    Terminate if the entity's base position is outside of the terrain.

    Only the x and y coordinates of the base position are checked against the
    bounds returned by ``terrain_manager.get_bounds(subterrain)``.

    Args:
        env: The Genesis environment containing the robot
        terrain_manager: The terrain manager to check for out of bounds
        subterrain: The subterrain to keep the robot inside of. It is forwarded
            unchanged to ``terrain_manager.get_bounds``.
        border_margin: Distance (in meters) by which every terrain edge is moved
            inward before the check. This terminates the episode before the
            robot falls off the terrain.
        entity_attr: The attribute name of the entity in the environment.
            This isn't necessary if `entity_manager` is provided.
        entity_manager: The entity manager for the entity. When provided, its
            `base_pos` is used instead of calling `get_pos()` on the entity.

    Returns:
        torch.Tensor: Boolean tensor, True where the base x or y position lies
        strictly outside the margin-adjusted bounds.
    """
    # Get the entity's base position
    if entity_manager is not None:
        position = entity_manager.base_pos
    else:
        entity = getattr(env, entity_attr)
        position = entity.get_pos()

    # Get terrain bounds and shrink them inward by the margin
    x_min, x_max, y_min, y_max = terrain_manager.get_bounds(subterrain)
    x_min_bound, x_max_bound = x_min + border_margin, x_max - border_margin
    y_min_bound, y_max_bound = y_min + border_margin, y_max - border_margin

    # Check bounds
    x_pos, y_pos = position[:, 0], position[:, 1]
    return (
        (x_pos < x_min_bound)
        | (x_pos > x_max_bound)
        | (y_pos < y_min_bound)
        | (y_pos > y_max_bound)
    )
def dof_control_force_limit(
    _env: GenesisEnv,
    actuator_manager: ActuatorManager,
    threshold: float | None = None,
) -> torch.Tensor:
    """
    Terminate if any joint's commanded actuator force exceeds a limit (+/-).

    The check uses the control/output force (what the actuator commands), not
    the measured joint load, which makes it suitable for teaching policies to
    stay within rated motor torque.

    Args:
        _env: The Genesis environment (not used by this check)
        actuator_manager: Actuator manager for the controlled joints
        threshold: Force/torque limit (in simulator units). If None, the value
            returned by ``actuator_manager.get_dofs_max_force()`` is used.

    Returns:
        Boolean tensor, True for environments where the absolute commanded
        force of at least one joint (last dimension) is above the limit.
    """
    commanded = actuator_manager.get_dofs_control_force()
    limit = actuator_manager.get_dofs_max_force() if threshold is None else threshold
    over_limit = commanded.abs() > limit
    return over_limit.any(dim=-1)
def dof_velocity_limit(
    _env: GenesisEnv,
    actuator_manager: ActuatorManager,
    threshold: float,
    unit: Literal["rpm", "rad"] = "rad",
) -> torch.Tensor:
    """
    Terminate if any of the actuator_manager's joints moves faster than a speed limit.

    Args:
        _env: The Genesis environment (not used by this check)
        actuator_manager: Actuator manager for the controlled joints
        threshold: Speed limit in the units given by `unit`
        unit: The speed units
            - `"rad"` for radians per second (default)
            - `"rpm"` for revolutions per minute

    Returns:
        Boolean tensor, True for environments where the absolute velocity of at
        least one joint (last dimension) is above the limit.

    Raises:
        ValueError: If `unit` is neither ``"rad"`` nor ``"rpm"``.
    """
    # Validate explicitly: an `assert` is stripped under `python -O`, which would
    # silently treat an unknown unit as rad/s.
    if unit not in ("rad", "rpm"):
        raise ValueError(f"Unknown velocity unit '{unit}'. Use 'rad' or 'rpm'.")
    if unit == "rpm":
        # 1 rev/min = 2*pi rad / 60 s
        threshold = threshold * (2 * math.pi / 60)
    vel = actuator_manager.get_dofs_velocity()
    return torch.any(torch.abs(vel) > threshold, dim=-1)