Source code for genesis_forge.managers.command.command_manager

from typing import Tuple, Callable

import os
import torch
import genesis as gs

from genesis_forge.genesis_env import GenesisEnv
from genesis_forge.managers.base import BaseManager
from genesis_forge.gamepads import Gamepad

CommandRangeValue = Tuple[float, float]
CommandRange = CommandRangeValue | dict[str, CommandRangeValue]

THIS_DIR = os.path.dirname(os.path.abspath(__file__))


class CommandManager(BaseManager):
    """
    Generates a command from a uniform distribution of values.

    You can use this to regularly sample commands for each environment, such as
    a height command, a target destination, or a specific pose.

    Args:
        env: The environment to control
        range: The number range, or dict of ranges, to generate target command(s) for
        resample_time_sec: The time interval between changing the command

    Example::

        class MyEnv(GenesisEnv):
            def config(self):
                # Create a height command
                self.height_command = CommandManager(self, range=(0.1, 0.2))

                # Rewards
                RewardManager(
                    self,
                    logging_enabled=True,
                    cfg={
                        "base_height_target": {
                            "weight": -50.0,
                            "fn": rewards.base_height,
                            "params": {
                                "height_command": self.height_command,
                            },
                        },
                        # ... other rewards ...
                    },
                )

                # Observations
                ObservationManager(
                    self,
                    cfg={
                        "height_cmd": {"fn": self.height_command.observation},
                        # ... other observations ...
                    },
                )
    """

    def __init__(
        self,
        env: GenesisEnv,
        range: CommandRange,
        resample_time_sec: float = 5.0,
    ):
        super().__init__(env, type="command")

        self._range = range
        # Goes through the property setter, which also derives the step count
        self.resample_time_sec = resample_time_sec
        self._external_controller = None
        self._gamepad_cfg = None
        self._gamepad_axis_command_buffer = None

        # One command column per configured range (a bare tuple counts as one)
        num_ranges = len(range) if isinstance(range, dict) else 1
        self._command = torch.zeros(env.num_envs, num_ranges, device=gs.device)

        # Maps each range name to its column in the command buffer
        self._range_idx = {}
        if isinstance(range, dict):
            self._range_idx = {key: i for i, key in enumerate(range.keys())}

    # Properties

    @property
    def command(self) -> torch.Tensor:
        """
        The desired command value. Shape is (num_envs, num_ranges).

        When an external controller is installed, its return value for the
        current environment step count is returned instead of the internally
        sampled buffer.
        """
        if self._external_controller is not None:
            return self._external_controller(self.env.step_count)
        return self._command

    @property
    def range(self) -> CommandRange:
        """The range of values to generate target command(s) for."""
        return self._range

    @range.setter
    def range(self, range: CommandRange):
        """
        Set the range of values to generate target command(s) for.

        The new range must have the same structure as the current one, because
        the command buffer and the key-to-column mapping are created once in
        the constructor.

        Raises:
            ValueError: If the number of ranges, the container type, or the
                dict keys differ from the current range.
        """
        # Validate the shape of the range
        num = len(range) if isinstance(range, dict) else 1
        if num != self._command.shape[1]:
            raise ValueError(
                f"Cannot change the shape of the CommandManager range. Expected size: {self._command.shape[1]}, got {num}"
            )

        # Validate the range types match
        if type(range) is not type(self._range):
            raise ValueError(
                f"Cannot change the base type of the CommandManager range. Expected type: {type(self._range)}, got {type(range)}"
            )

        # Validate that the dict keys match the current range dict keys
        if isinstance(range, dict):
            if set(range.keys()) != set(self._range.keys()):
                raise ValueError(
                    f"Cannot change the dict keys of the CommandManager range. Expected keys: {set(self._range.keys())}, got {set(range.keys())}"
                )

        self._range = range

    @property
    def resample_time_sec(self) -> float:
        """The time interval (in seconds) between changing the command for each environment."""
        return self._resample_time_sec

    @resample_time_sec.setter
    def resample_time_sec(self, resample_time_sec: float):
        """
        Set the time interval (in seconds) between changing the command for each environment.

        The interval is converted to a whole number of environment steps using
        the environment's ``dt``.
        """
        self._resample_time_sec = resample_time_sec
        # The step count is used as a modulo divisor in step(); an interval
        # shorter than one step would truncate to 0, so resample every step
        # in that case instead of dividing by zero.
        self._resample_steps = max(1, int(resample_time_sec / self.env.dt))

    # Operations
def get_command(self, range_key: str) -> torch.Tensor:
    """
    Return the command values that belong to one named range.

    Only usable when the manager was configured with a dict of ranges. The
    values are read from the internally sampled command buffer.

    Args:
        range_key: Name of the range whose command column is wanted.

    Returns:
        The column of the command buffer registered for ``range_key``.

    Raises:
        ValueError: If the configured range is not a dict.
        KeyError: If ``range_key`` is not a configured range name.
    """
    if isinstance(self._range, dict):
        column = self._range_idx[range_key]
        return self._command[:, column]
    raise ValueError("The range is not a dict")
[docs] def set_command( self, range_key: str, value: torch.Tensor, envs_idx: list[int] | None = None, ): """ Update a command value for selected environments. """ if not isinstance(self._range, dict): raise ValueError("The range is not a dict") if envs_idx is None: self._command[:, self._range_idx[range_key]] = value else: self._command[envs_idx, self._range_idx[range_key]] = value
[docs] def increment_range( self, range_key: str, increment: float | tuple[float, float], limit: float | tuple[float, float] = None, ): """ Increment a command range target values by the given amount, with an optional limit. Both increment and limit can be passed as a single float or a tuple of two floats. When they are tuples, they represent the increment and limit for both min and max. When they are a single float, they will be applied to both min and max as (-value, +value). Example:: # Increment the range by 0.5 (+/-) for both min and max command_manager.increment_range("height", 0.5) # Increment the range min by -0.25 and max by 1.0 command_manager.increment_range("height", (-0.25, 1.0)) # Increment the range by (-0.25, 1.0) and keep min above -0.5 and max below 2.0 command_manager.increment_range("height", (-0.25, 1.0), (-0.5, 2.0)) Args: range_key: The key of the command range to increment. increment: The amount to increment the min & max range values by (+/-). If this is a tuple, it will be: `(min_increment, max_increment)`. If this is a single float, it will be applied to both min and max as (-increment, +increment). limit: Do not set the values beyond this limit range. This is a tuple of `(min_limit, max_limit)`. 
""" # Get the range to increment, and ensure it is a list, not a tuple (tuples are immutable) if not isinstance(self.range, dict): raise ValueError("Cannot increment a non-dict range item") range_item = self._range.get(range_key, None) if range_item is None: raise ValueError(f"Range item {range_key} not found") range_item = list[float](range_item) # Ensure it's a list, not a tuple # Expand single increment/limit values to tuples if not isinstance(increment, (list, tuple)): increment = (-increment, increment) if limit is not None and not isinstance(limit, (list, tuple)): limit = (-limit, limit) # Increment the range values for i, value in enumerate(range_item): value += increment[i] if limit is not None: if increment[i] > 0: value = min(value, limit[i]) else: value = max(value, limit[i]) range_item[i] = value self._range[range_key] = range_item
def get_command_idx(self, key: str) -> int:
    """
    Return the command-buffer column index registered for a named range.

    Only usable when the manager was configured with a dict of ranges.

    Args:
        key: Name of the range to look up.

    Raises:
        ValueError: If the configured range is not a dict.
        KeyError: If ``key`` is not a configured range name.
    """
    if isinstance(self._range, dict):
        return self._range_idx[key]
    raise ValueError("The range is not a dict")
def step(self):
    """
    Resample the command for every environment that is due for a new one.

    Does nothing while the manager is disabled or while an external
    controller supplies the command. Otherwise, the environments whose
    episode length is a whole multiple of the resample step count are
    passed to ``resample_command``.
    """
    if not self.enabled:
        return
    if self._external_controller is not None:
        return

    is_due = self.env.episode_length % self._resample_steps == 0
    due_env_ids = is_due.nonzero(as_tuple=False).reshape((-1,))
    self.resample_command(due_env_ids)
[docs] def reset(self, env_ids: list[int] | None = None): """One or more environments have been reset""" if not self.enabled: return if env_ids is None: env_ids = torch.arange(self.env.num_envs, device=gs.device) self.resample_command(env_ids)
def observation(self, env: GenesisEnv) -> torch.Tensor:
    """
    Observation function that reports the current command of every environment.

    Args:
        env: The environment requesting the observation. It is not used;
            the value comes from this manager's ``command`` property.

    Returns:
        The value of the ``command`` property.
    """
    current_command = self.command
    return current_command
def use_external_controller(self, controller: Callable[[int], torch.Tensor]):
    """
    Bypass the internal command controller, and generate the command values with an external control function.
    This can be used to connect a gamepad, joystick, or other external controller to the command manager.

    While a controller is installed, the ``command`` property returns the
    controller's result and ``step`` no longer resamples the command.

    Args:
        controller: A function that takes the step index and returns a tensor of command values
                    with the shape (num_envs, num_ranges).

    Example::

        N_ENVS = 1
        MIN_HEIGHT = 0.1
        MAX_HEIGHT = 0.2

        # Create environment
        class MyEnv(GenesisEnv):
            def config(self):
                self.height_command = CommandManager(self, range=(MIN_HEIGHT, MAX_HEIGHT))
                # ...

        # Setup gamepad
        gamepad = Gamepad(GAMEPAD_PRODUCT)
        cmd_buffer = torch.zeros((N_ENVS, 1), device=gs.device)
        def gamepad_controller(_step):
            a_pressed = "A" in gamepad.state.buttons
            cmd_buffer[:, 0] = MAX_HEIGHT if a_pressed else MIN_HEIGHT
            return cmd_buffer

        # Create environment & connect gamepad
        env = MyEnv(num_envs=N_ENVS)
        env.build()
        env.command_manager.use_external_controller(gamepad_controller)
    """
    self._external_controller = controller
def use_gamepad(
    self,
    gamepad: Gamepad,
    range_axis: int | dict[str, int],
):
    """
    A wrapper around use_external_controller that converts a gamepad joystick axis to a command value.

    Args:
        gamepad: The gamepad to use.
        range_axis: The axis or dict of axes to use for the command value.
                    This should match the range init param.
                    A single int drives only the first command column.

    Raises:
        ValueError: If ``range_axis`` is a dict but the configured range is not.
        KeyError: If ``range_axis`` is a dict that lacks one of the range keys.
        TypeError: If ``range_axis`` is neither an int nor a dict.

    Example::

        # Create environment
        class MyEnv(GenesisEnv):
            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)
                self.height_command = CommandManager(self, range=(0.1, 0.2))
                # ...

        # Connect gamepad
        gamepad = Gamepad(GAMEPAD_PRODUCT)

        # Create environment & connect gamepad
        env = MyEnv(num_envs=1)
        env.build()

        # Connect joystick axis 3 to the height command
        env.height_command.use_gamepad(gamepad, range_axis=3)

    Example with multiple ranges::

        # Create environment
        class MyEnv(GenesisEnv):
            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)
                self.height_command = CommandManager(
                    self,
                    range={
                        "cmd1": (-2.0, 2.0),
                        "cmd2": (1.0, 5.0),
                    }
                )
                # ...

        # Connect gamepad
        gamepad = Gamepad(GAMEPAD_PRODUCT)

        # Create environment & connect gamepad
        env = MyEnv(num_envs=1)
        env.build()

        # Connect joystick axes 2 and 3 to the individual ranges
        env.height_command.use_gamepad(
            gamepad,
            range_axis={
                "cmd1": 2,
                "cmd2": 3,
            })
    """
    # Build and validate the axis map BEFORE changing any state, so a bad
    # argument cannot leave the manager with the gamepad controller installed
    # but no gamepad configuration behind it.
    if isinstance(range_axis, int):
        axis_map = [range_axis]
    elif isinstance(range_axis, dict):
        if not isinstance(self._range, dict):
            raise ValueError("Cannot map a dict of axes onto a non-dict range")
        # Follow the range's key order so each axis lines up with its column
        axis_map = [range_axis[key] for key in self._range]
    else:
        raise TypeError(
            f"range_axis must be an int or a dict of ints, got {type(range_axis)}"
        )

    self._gamepad_cfg = {
        "gamepad": gamepad,
        "axis_map": axis_map,
    }
    self._gamepad_axis_command_buffer = torch.zeros_like(
        self._command, device=gs.device
    )
    # Install the controller last, once everything it reads is in place
    self._external_controller = self._gamepad_axis_command
def resample_command(self, env_ids: list[int]):
    """
    Draw a fresh command for each of the given environments.

    Every command column is filled with values sampled uniformly from the
    range currently configured for that column.

    Args:
        env_ids: The environments that receive a new command.
    """
    # Read the ranges at call time; they may have been changed after
    # construction (for example by curriculum training).
    bounds = (
        list(self._range.values())
        if isinstance(self._range, dict)
        else [self._range]
    )

    for column in range(self._command.shape[1]):
        samples = torch.empty(len(env_ids), device=gs.device)
        samples.uniform_(*bounds[column])
        self._command[env_ids, column] = samples
# Implementation

def _gamepad_axis_command(self, step_count: int) -> torch.Tensor:
    """
    Build the command tensor from the current gamepad axis readings.

    Each mapped axis reading is rescaled into the range of its command
    column and written into the gamepad command buffer, which is returned.
    When no gamepad has been configured, the buffer is returned unchanged.

    Args:
        step_count: The step index supplied by the caller. It is not used.
    """
    buffer = self._gamepad_axis_command_buffer
    config = self._gamepad_cfg
    if config is None:
        return buffer

    pad = config["gamepad"]
    mapped_axes = config["axis_map"]

    def rescale(reading: float, low: float, high: float) -> float:
        # Linear map that sends a reading of -1 to `low` and +1 to `high`
        return (reading + 1) * (high - low) / 2 + low

    if isinstance(self._range, dict):
        bounds = list(self._range.values())
    else:
        bounds = [self._range]

    for column, axis in enumerate(mapped_axes):
        # Axes without a matching range are ignored
        if column >= len(bounds):
            continue
        buffer[:, column] = rescale(pad.state.axis(axis), *bounds[column])
    return buffer