Source code for genesis_forge.managers.reward_manager

import torch
import genesis as gs
from typing import Iterator, TypedDict, Callable, Any

from genesis_forge.genesis_env import GenesisEnv
from genesis_forge.managers.base import BaseManager
from genesis_forge.managers.config import RewardConfigItem


class RewardConfig(TypedDict):
    """Defines a reward item."""

    fn: Callable[[GenesisEnv, ...], torch.Tensor]
    """Function that will be called to calculate a reward for the environments."""

    params: dict[str, Any]
    """Additional parameters to pass to the function."""

    weight: float
    """The weight of the reward item."""


[docs] class RewardManager(BaseManager): """ Handles calculating and logging the rewards for the environment. This works with a dictionary configuration of reward handlers. For each dictionary item, a function will be called to calculate a reward value for the environment. Args: env: The environment to manage the rewards for. reward_cfg: A dictionary of reward conditions. logging_enabled: Whether to log the rewards to tensorboard. logging_tag: The section name used to log the rewards to tensorboard. Example with ManagedEnvironment:: class MyEnv(ManagedEnvironment): def config(self): self.reward_manager = RewardManager( self, cfg={ "Default pose": { "fn": mdp.rewards.dof_similar_to_default, "weight": -0.1, }, "Base height": { "fn": mdp.rewards.base_height, "params": { "target_height": 0.135 }, "weight": -100.0, }, }, ) Example using the reward manager directly:: class MyEnv(GenesisEnv): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.reward_manager = RewardManager( self, cfg={ "Base height": { "fn": mdp.rewards.base_height, "params": { "target_height": 0.135 }, "weight": -100.0, }, }, ) def build(self): super().build() self.reward_manager.build() def step(self, actions: torch.Tensor): super().step(actions) rewards = self.reward_manager.step() # ... other step logic ... return obs, rewards, terminations, timeouts, info def reset(self, envs_idx: list[int] | None = None): super().reset(envs_idx) # ... other reset logic ... return obs, info """ def __init__( self, env: GenesisEnv, cfg: dict[str, RewardConfig], logging_enabled: bool = True, logging_tag: str = "Rewards", ): super().__init__(env, type="reward") self.logging_enabled = logging_enabled self.logging_tag = logging_tag # Wrap config items self.cfg: dict[str, RewardConfigItem] = {} for name, cfg in cfg.items(): self.cfg[name] = RewardConfigItem(cfg, env) # Initialize buffers self._reward_buf = torch.zeros( (env.num_envs,), device=gs.device, dtype=gs.tc_float ) self._episode_seconds = torch.zeros( (self.env.num_envs,), device=gs.device, dtype=gs.tc_float ) self._episode_mean: dict[str, torch.Tensor] = dict() self._episode_data: dict[str, torch.Tensor] = dict() for name in self.cfg.keys(): self._episode_data[name] = torch.zeros( (env.num_envs,), device=gs.device, dtype=gs.tc_float ) @property def rewards(self) -> torch.Tensor: """ The rewards calculated for the most recent step. Shape is (num_envs,). """ return self._reward_buf @property def episode_data(self) -> dict[str, torch.Tensor]: """ Get the accumulated reward data for the current episode of all environments. """ return self._episode_data """ Helpers """
[docs] def last_episode_mean_reward(self, name: str, before_weight: bool = True) -> float: """ Get the last mean reward for an episode for a given reward name. The mean reward is only calculated when episodes end/reset. Args: name: The name of the reward to get the mean for. before_weight: If True, this will be the base reward value before the weight was applied. Returns: The last mean reward for an episode for a given reward name. """ rew = self._episode_mean.get(name, 0.0) if before_weight: rew /= self.cfg[name].weight return rew
""" Lifecycle Operations """
[docs] def build(self): """ Build any config item function classes. """ for cfg in self.cfg.values(): cfg.build()
[docs] def step(self) -> torch.Tensor: """ Calculate the rewards for this step Returns: The rewards for the environments. Shape is (num_envs,). """ if not self.enabled: return self._reward_buf dt = self.env.dt self._reward_buf[:] = 0.0 self._episode_seconds += dt for name, cfg in self.cfg.items(): # Don't calculate reward if the weight is zero if cfg.weight == 0: continue # Get reward value from function weight = cfg.weight * dt value = cfg.fn(self.env, **cfg.params) * weight # Add to reward buffer self._reward_buf += value # Add to episode data for logging (if enabled) if self.logging_enabled: self._episode_data[name] += value return self._reward_buf
[docs] def reset(self, envs_idx: list[int] | None = None): """Log the reward mean values at the end of the episode""" if envs_idx is None: envs_idx = torch.arange(self.env.num_envs, device=gs.device) if self.enabled and self.logging_enabled: logging_dict = self.env.extras[self.env.extras_logging_key] episode_seconds = self._episode_seconds[envs_idx] for name, value in self._episode_data.items(): # Don't log items that have zero weight cfg = self.cfg[name] if cfg.weight != 0: # Calculate average for each environment value[envs_idx] /= episode_seconds # Take the mean across all episodes episode_mean = torch.mean(value[envs_idx]) self._episode_mean[name] = episode_mean.item() logging_dict[f"{self.logging_tag} / {name}"] = episode_mean # Reset episodic data self._episode_data[name][envs_idx] = 0.0 # Reset episode seconds to nearly zero, to prevent divide by zero errors self._episode_seconds[envs_idx] = 1e-10
""" Dict-like operations """
[docs] def __getitem__(self, name: str) -> RewardConfigItem: """Get a reward config item by name.""" return self.cfg[name]
[docs] def __setitem__(self, name: str, value: RewardConfigItem): """Set a reward config item by name.""" self.cfg[name] = value
[docs] def __delitem__(self, name: str): """Delete a reward config item by name.""" del self.cfg[name]
[docs] def __contains__(self, name: str) -> bool: """Check if a reward config item exists by name.""" return name in self.cfg
[docs] def __iter__(self) -> Iterator[str]: """Iterate over the reward config item names.""" return iter(self.cfg.keys())
[docs] def __len__(self) -> int: """Get the number of reward config items.""" return len(self.cfg)