Source code for genesis_forge.wrappers.rsl_rl
from __future__ import annotations
import torch
from tensordict import TensorDict
from typing import Any, Union, Optional
import genesis as gs
from importlib import metadata
from genesis_forge.genesis_env import GenesisEnv
from genesis_forge.wrappers.wrapper import Wrapper
[docs]
class RslRlWrapper(Wrapper):
"""
A wrapper that makes your genesis forge environment compatible with the rsl_rl training framework.
IMPORTANT: This should be the last wrapper, as the change in the step and get_observations methods might break other wrappers.
What it does:
- Combines the terminated and truncated tensors into a single tensor (i.e. `terminated | truncated`).
- Add the truncated tensor to the extras dictionary as "time_outs".
- Returns observations and extras from the `get_observations` method.
Args:
env: The environment to wrap.
cfg: The configuration for the wrapper that will be passed to the neptune or wandb logger
"""
can_be_wrapped = False
def __init__(self, env: GenesisEnv, cfg: dict | object = {}):
super().__init__(env)
self.rsl3 = False
self.cfg = cfg
try:
major_version = int(metadata.version("rsl-rl-lib").split(".")[0])
if major_version >= 3:
self.rsl3 = True
except:
pass
@property
def device(self) -> str:
return gs.device
[docs]
def step(
self, actions: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict[str, Any]]:
"""
Returns a single "dones" tensor, instead of the terminated and truncated tensors (via `terminated | truncated`).
Add the truncated tensor to the extras dictionary as "time_outs".
"""
(
obs,
rewards,
terminated,
truncated,
extras,
) = super().step(actions)
# Combine terminated and truncated
dones = terminated | truncated
# Add observations and timeouts to extras
if extras is None:
extras = {}
extras = self._add_observations_to_extras(obs, extras)
obs = self._format_obs_group(obs, extras)
return obs, rewards, dones, extras
[docs]
def reset(self):
"""
Converts observations into a TensorDict for rsl_rl 3.0+
"""
(obs, extras) = self.env.reset()
obs = self._format_obs_group(obs, extras)
return obs, extras
[docs]
def get_observations(self):
"""
Returns observations as well as an extras dictionary with the observations added to the `extras["observations"]["critic"]` key.
"""
obs = self.env.get_observations()
# rsl_rl 3.0+ just returns the observations
if self.rsl3:
obs = self._format_obs_group(obs, self.env.extras)
return obs
# Earlier versions of rsl_rl adds critic observations to the extras dictionary
extras = self._add_observations_to_extras(obs, self.env.extras)
return obs, extras
def _add_observations_to_extras(
self, obs: torch.Tensor, extras: Optional[dict[str, Any]]
):
"""
Add the observations to the extras dictionary.
"""
if extras is None:
extras = {}
if "observations" not in extras:
extras["observations"] = {}
if "critic" not in extras["observations"]:
extras["observations"]["critic"] = obs
return extras
def _format_obs_group(
self, obs: torch.Tensor, extras: Optional[dict[str, Any]]
) -> Union[torch.Tensor, TensorDict]:
"""
If we're using rsl_rl 3.0+, put the observations into a TensorDict
"""
if self.rsl3:
if extras is not None and "observations" in extras:
if isinstance(extras["observations"], TensorDict):
obs = extras["observations"]
else:
obs = TensorDict(extras["observations"], device=gs.device)
else:
obs = TensorDict(
{"policy": obs},
batch_size=[obs.shape[0]],
device=gs.device,
)
return obs