from __future__ import annotations
import torch
import genesis as gs
from genesis_forge.genesis_env import GenesisEnv
from genesis_forge.managers.base import BaseManager
from genesis.utils.geom import (
transform_by_quat,
inv_quat,
)
from genesis_forge.managers.config import ConfigItem, ResetMdpFnClass
from typing import TypedDict, Callable, Any, TYPE_CHECKING
if TYPE_CHECKING:
from genesis.engine.entities import RigidEntity
ResetConfigFn = Callable[[GenesisEnv, RigidEntity, list[int], ...], None]
class EntityResetConfig(TypedDict):
"""Defines an entity reset item."""
fn: ResetConfigFn | ResetMdpFnClass
"""
Function, or class function, that will be called on reset.
Args:
env: The environment instance.
entity: The entity instance.
envs_idx: The environment ids for which the entity is to be reset.
**params: Additional parameters to pass to the function from the params dictionary.
"""
params: dict[str, Any]
"""Additional parameters to pass to the function."""
weight: float
"""The weight of the reward item."""
[docs]
class EntityManager(BaseManager):
"""
Provides options for resetting an entity and adding noise and randomization to its state.
Args:
env: The environment instance.
entity_attr: The attribute name of the environment that the entity is stored in.
on_reset: The reset configuration for the entity.
Example::
class MyEnv(ManagedEnvironment):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def config(self):
self.entity_manager = EntityManager(
self,
entity_attr="robot",
on_reset={
"position": {
"fn": reset.randomize_terrain_position,
"params": {
"terrain_manager": self.terrain_manager,
"subterrain": self._target_terrain,
"height_offset": 0.15,
},
},
},
)
"""
def __init__(
self,
env: GenesisEnv,
entity_attr: str,
on_reset: dict[str, EntityResetConfig],
):
super().__init__(env, type="entity")
if hasattr(env, "add_entity_manager"):
env.add_entity_manager(self)
self.entity: RigidEntity | None = None
self.on_reset = on_reset
self._entity_attr = entity_attr
# Wrap config items
self.on_reset: dict[str, ConfigItem] = {}
for name, cfg in on_reset.items():
self.on_reset[name] = ConfigItem(cfg, env)
# Buffers
self._global_gravity = torch.tensor(
[0.0, 0.0, -1.0], device=gs.device, dtype=gs.tc_float
).repeat(env.num_envs, 1)
self._base_pos = torch.zeros(
(env.num_envs, 3), device=gs.device, dtype=gs.tc_float
)
self._base_quat = torch.zeros(
(env.num_envs, 4), device=gs.device, dtype=gs.tc_float
)
self._inv_base_quat = torch.zeros_like(self._base_quat)
"""
Properties
"""
@property
def base_pos(self) -> torch.Tensor:
"""
The position of the entities base link.
"""
return self._base_pos
@property
def base_quat(self) -> torch.Tensor:
"""
The quaternion of the entity's base link.
"""
return self._base_quat
@property
def inv_base_quat(self) -> torch.Tensor:
"""
The inverse of the entity's base link quaternion.
"""
return self._inv_base_quat
"""
Helpers
"""
[docs]
def get_projected_gravity(self) -> torch.Tensor:
"""
The projected gravity of the entity's base link, in the entity's local frame.
"""
return transform_by_quat(self._global_gravity, self._inv_base_quat)
[docs]
def get_linear_velocity(self) -> torch.Tensor:
"""
The linear velocity of the entity's base link, in the entity's local frame.
"""
return transform_by_quat(self.entity.get_vel(), self._inv_base_quat)
[docs]
def get_angular_velocity(self) -> torch.Tensor:
"""
The angular velocity of the entity's base link, in the entity's local frame.
"""
return transform_by_quat(self.entity.get_ang(), self._inv_base_quat)
"""
Operations.
"""
[docs]
def build(self):
"""
Build the entity manager.
"""
self.entity = getattr(self.env, self._entity_attr)
self._cached_calcs()
# Build reset function classes
for cfg in self.on_reset.values():
cfg.build(entity=self.entity)
[docs]
def step(self):
"""
Run some common shared calculations at each step.
"""
self._cached_calcs()
[docs]
def reset(self, envs_idx: list[int] | None = None):
"""
Call all reset functions
"""
if not self.enabled:
return
if envs_idx is None:
envs_idx = torch.arange(self.env.num_envs, device=gs.device)
for name, cfg in self.on_reset.items():
try:
cfg.execute(envs_idx)
except Exception as e:
print(f"Error resetting entity with config: '{name}'")
raise e
"""
Implementation
"""
def _cached_calcs(self):
"""
Calculate and cache some common values
"""
self._base_pos[:] = self.entity.get_pos()
self._base_quat[:] = self.entity.get_quat()
self._inv_base_quat = inv_quat(self._base_quat)