# Source code for genesis_forge.wrappers.skrl

import torch
from gymnasium import spaces
from typing import Any, Tuple

from skrl.envs.wrappers.torch.base import Wrapper as SkrlWrapper
from genesis_forge.wrappers.wrapper import Wrapper as GenesisWrapper


class SkrlEnvWapper(SkrlWrapper, GenesisWrapper):
    """
    A wrapper that makes your genesis forge environment compatible with the
    skrl training framework.

    Every method delegates to the wrapped environment held in ``self._env``;
    only :meth:`step` alters the returned values (it adds a second dimension
    to the reward/termination/timeout tensors).
    """

    # NOTE(review): the class name spelling ("Wapper") is kept as-is because
    # it is the public name callers import.

    can_be_wrapped = False

    @property
    def action_space(self) -> spaces.Space:
        """The action space of the environment."""
        return self._env.action_space

    @property
    def observation_space(self) -> spaces.Space:
        """The observation space of the environment."""
        return self._env.observation_space

    def reset(self) -> Tuple[torch.Tensor, Any]:
        """Reset the environment

        :return: Observation, info
        :rtype: torch.Tensor and any other info
        """
        return self._env.reset()

    def step(
        self, actions: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Any]:
        """Perform a step in the environment

        :param actions: The actions to perform
        :type actions: torch.Tensor

        :return: Observation, reward, terminated, truncated, info
        :rtype: tuple of torch.Tensor and any other info
        """
        obs, rewards, terminations, timeouts, extras = self._env.step(actions)

        # Expand rewards, terminations and timeouts to the shape (num_envs, 1)
        rewards = rewards.unsqueeze(1)
        terminations = terminations.unsqueeze(1)
        timeouts = timeouts.unsqueeze(1)

        return obs, rewards, terminations, timeouts, extras

    def state(self) -> torch.Tensor:
        """Get the environment state

        :return: State
        :rtype: torch.Tensor
        """
        # Use ``_env`` like every other method of this wrapper, rather than
        # relying on a separate ``env`` attribute being resolvable.
        return self._env.state()

    def render(self, *args, **kwargs) -> Any:
        """
        Not implemented for Genesis Forge environments.

        Accepts and ignores any arguments; always returns ``None``.
        """
        return None

    def close(self) -> None:
        """Close the environment"""
        return self._env.close()

    def build(self) -> None:
        """Build the environment"""
        self._env.build()