# Source code for genesis_forge.managers.observation_manager

from __future__ import annotations

import torch
import numpy as np
from gymnasium import spaces
import genesis as gs
from typing import TypedDict, Callable, Any
from genesis_forge.genesis_env import GenesisEnv
from genesis_forge.managers.base import BaseManager
from genesis_forge.managers.config import ObservationConfigItem


class _ObservationConfigRequired(TypedDict):
    """Required keys of :class:`ObservationConfig` (split out so the rest can be optional)."""

    #: Function that will be called to generate an observation, returning a value for
    #: each environment. It is invoked with ``env=<the environment>`` plus any ``params``.
    fn: Callable[..., torch.Tensor]


class ObservationConfig(_ObservationConfigRequired, total=False):
    """Defines an observation item.

    Only ``fn`` is required; ``params``, ``scale`` and ``noise`` may be omitted.
    """

    #: Additional keyword parameters to pass to the function.
    params: dict[str, Any]

    #: The scale to multiply the observation by. If None, no scale will be applied.
    scale: float | None

    #: The noise scale to add to the observation. For every element, a number is drawn
    #: uniformly from [-1, 1], multiplied by this scale, and added to the observation.
    #: If None, the observation manager's own ``noise`` setting (if any) is used instead.
    noise: float | None


class ObservationManager(BaseManager):
    """
    Defines the observations and observation space for the environment.

    Args:
        env: The environment.
        cfg: The configuration for the observation manager, mapping each observation
            name to its ``ObservationConfig``.
        name: The name to categorize the observations under, generally used for
            asymmetrical RL. It's required to have one observation manager named "policy".
        history_len: The total number of observation frames (the current one plus
            previous ones) concatenated into each output, newest first.
            Defaults to 1, i.e. only the current observation.
        noise: The default noise scale for observation items whose own ``noise`` is None.
            Each element receives ``uniform(-1, 1) * noise``. If None, such items get no noise.

    Example with ManagedEnvironment::

        class MyEnv(ManagedEnvironment):
            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)

            def config(self):
                ObservationManager(
                    self,
                    cfg={
                        "velocity_cmd": {"fn": self.velocity_command.observation},
                        "robot_ang_vel": {
                            "fn": utils.entity_ang_vel,
                            "params": {"entity": self.robot},
                            "noise": 0.1,
                        },
                        "robot_lin_vel": {
                            "fn": utils.entity_lin_vel,
                            "params": {"entity": self.robot},
                            "noise": 0.1,
                        },
                        "robot_projected_gravity": {
                            "fn": utils.entity_projected_gravity,
                            "params": {"entity": self.robot},
                            "noise": 0.1,
                        },
                        "robot_dofs_position": {
                            "fn": self.action_manager.get_dofs_position,
                            "noise": 0.01,
                        },
                        "actions": {"fn": lambda env: env.actions},
                    },
                )

    Example using the observation manager directly::

        class MyEnv(GenesisEnv):
            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)
                self.observation_manager = ObservationManager(
                    self,
                    cfg={
                        "velocity_cmd": {"fn": self.velocity_command.observation},
                        "robot_ang_vel": {
                            "fn": utils.entity_ang_vel,
                            "params": {"entity": self.robot},
                            "noise": 0.1,
                        },
                        "robot_lin_vel": {
                            "fn": utils.entity_lin_vel,
                            "params": {"entity": self.robot},
                            "noise": 0.1,
                        },
                        "robot_projected_gravity": {
                            "fn": utils.entity_projected_gravity,
                            "params": {"entity": self.robot},
                            "noise": 0.1,
                        },
                        "robot_dofs_position": {
                            "fn": self.action_manager.get_dofs_position,
                            "noise": 0.01,
                        },
                        "actions": {"fn": lambda env: env.actions},
                    },
                )

            @property
            def observation_space(self):
                return self.observation_manager.observation_space

            def build(self):
                super().build()
                self.observation_manager.build()

            def step(self, actions: torch.Tensor):
                super().step(actions)
                # ... step logic ...
                obs = self.observation_manager.get_observations()
                return obs, rewards, terminations, timeouts, info

            def reset(self, envs_idx: list[int] | None = None):
                super().reset(envs_idx)
                # ... reset logic ...
                obs = self.observation_manager.get_observations()
                return obs, info
    """

    def __init__(
        self,
        env: GenesisEnv,
        cfg: dict[str, ObservationConfig],
        name: str = "policy",
        history_len: int | None = None,
        noise: float | None = None,
    ):
        super().__init__(env, "observation")
        if history_len is not None and history_len < 1:
            raise ValueError("history_len must be greater than 0")

        self._name = name
        self.noise = noise
        self._observation_size = 1
        self._observation_space: spaces.Space | None = None
        self._history_len = history_len if history_len is not None else 1
        # Per-step observation buffers, newest first; allocated in build().
        self._history: list[torch.Tensor] = []
        # Pre-allocated target the history is concatenated into; allocated in build().
        self._history_output: torch.Tensor | None = None

        # Wrap each raw config dict so it can be built and called uniformly.
        self.cfg: dict[str, ObservationConfigItem] = {
            item_name: ObservationConfigItem(item_cfg, env)
            for item_name, item_cfg in cfg.items()
        }

    # ------------------------------------------------------------------
    # Properties
    # ------------------------------------------------------------------

    @property
    def name(self) -> str:
        """
        The name to categorize the observations under.

        This is generally used for asymmetrical RL and it's required to have one
        observation manager named "policy".
        """
        return self._name

    @property
    def observation_space(self) -> spaces.Space:
        """The observation space (None until :meth:`build` has been called)."""
        return self._observation_space

    # ------------------------------------------------------------------
    # Public methods
    # ------------------------------------------------------------------

    def build(self):
        """
        Determine the observation space and set up the buffers.

        When the manager is disabled, a 1-element placeholder space is used and no
        observation functions are evaluated.
        """
        if not self.enabled:
            self._observation_size = 1
            self._observation_space = spaces.Box(
                low=-np.inf,
                high=np.inf,
                shape=(1,),
                dtype=np.float32,
            )
            return

        # Build the observation functions; this evaluates each one once to learn its width.
        single_obs_size = self._setup_observation_functions()
        self._observation_size = single_obs_size * self._history_len
        self._observation_space = spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=(self._observation_size,),
            dtype=np.float32,
        )

        # One zero-filled buffer per history frame.
        shape = (self.env.num_envs, single_obs_size)
        self._history = [
            torch.zeros(shape, device=gs.device) for _ in range(self._history_len)
        ]
        self._history_output = torch.zeros(
            (self.env.num_envs, self._observation_size),
            device=gs.device,
        )

    def get_observations(
        self, values: dict[str, float | torch.Tensor] | None = None
    ) -> torch.Tensor:
        """
        Generate current observations for all environments.

        Optionally, you can provide the observation values directly as a dictionary of
        values, and this method will return the formatted/scaled (without noise) tensor
        for the policy. This is useful for manual deployments or troubleshooting.

        Args:
            values: (optional) If provided, these values will be used instead of fetching
                observations from the config functions. It's expected that this dict
                contains a key for every observation configuration. These values will be
                scaled, based on the configuration, but not receive any noise. The caller's
                tensors are never modified.

        Returns:
            A new tensor of shape ``(num_envs, observation_size)`` holding the history
            frames concatenated newest first.

        Raises:
            RuntimeError: If called on an enabled manager before :meth:`build`.
        """
        if not self.enabled:
            return torch.zeros(
                (self.env.num_envs, self._observation_size), device=gs.device
            )
        if not self._history:
            raise RuntimeError(
                "ObservationManager.build() must be called before get_observations()"
            )

        # Overwrite the oldest frame, then rotate it to the front only once it has been
        # filled successfully, so a failing observation never shrinks the history.
        # NOTE(review): frames are not cleared when individual environments reset;
        # confirm whether that is handled elsewhere when history_len > 1.
        oldest = self._history[-1]
        self._perform_observation(oldest, values)
        self._history.insert(0, self._history.pop())

        # Copy the frames into the pre-allocated output buffer (avoids torch.cat allocations).
        offset = 0
        for frame in self._history:
            width = frame.shape[1]
            self._history_output[:, offset : offset + width] = frame
            offset += width
        # Clone so callers don't see the buffer change on the next call.
        return self._history_output.clone()

    # ------------------------------------------------------------------
    # Private methods
    # ------------------------------------------------------------------

    def _setup_observation_functions(self) -> int:
        """
        Build every observation config item and measure the size of one observation frame.

        Each function is evaluated once, and the last dimension of its result is taken
        as its width.

        Returns:
            The summed width of all observation items.

        Raises:
            TypeError: If a config item's function is not callable.
        """
        size = 0
        for name, cfg in self.cfg.items():
            try:
                cfg.build()
                if not callable(cfg.fn):
                    raise TypeError(f"Observation function {name} is not callable")
                value = cfg.fn(env=self.env, **cfg.params)
                value_size = value.shape[-1]
                if value_size > 0:
                    size += value_size
            except Exception:
                print(f"Error generating observation for '{name}'")
                raise
        return size

    def _perform_observation(
        self,
        output: torch.Tensor,
        override_values: dict[str, float | torch.Tensor] | None = None,
    ) -> torch.Tensor:
        """
        Perform a round of observations, writing them side by side into ``output``.

        Scale and noise are applied out of place, so tensors returned by observation
        functions (or passed in as overrides) are never mutated.

        Args:
            output: The ``(num_envs, single_obs_size)`` tensor to fill with the observations.
            override_values: If given, values to use instead of calling the config functions.
                They are scaled but receive no noise. Scalars are treated as 1-element values.

        Returns:
            ``output``, filled in place.

        Raises:
            ValueError: If ``override_values`` is missing a configured observation name.
        """
        offset = 0
        has_overrides = override_values is not None
        for name, cfg in self.cfg.items():
            try:
                # Get values
                if has_overrides:
                    if name not in override_values:
                        raise ValueError(f"Value '{name}' not found in override values")
                    value = override_values[name]
                    if not isinstance(value, torch.Tensor):
                        value = torch.tensor(value, device=gs.device)
                    if value.ndim == 0:
                        # A bare scalar has no last dimension to slice by.
                        value = value.reshape(1)
                else:
                    value = cfg.fn(env=self.env, **cfg.params)

                # Apply scale (out of place: the tensor may be owned by the caller/env).
                scale = cfg.scale
                if scale is not None and scale != 1.0:
                    value = value * scale

                # Add noise unless values were supplied directly. An explicit per-item
                # noise (including 0.0) takes precedence over the manager-wide default.
                if not has_overrides:
                    noise = cfg.noise if cfg.noise is not None else self.noise
                    if noise is not None and noise != 0.0:
                        value = value + torch.empty_like(value).uniform_(-1, 1) * noise

                # Copy directly into the output buffer.
                value_size = value.shape[-1]
                if value_size > 0:
                    output[:, offset : offset + value_size] = value
                    offset += value_size
            except Exception:
                print(f"Error generating observation for '{name}'")
                raise
        return output