Source code for genesis_forge.wrappers.skrl
import torch
from gymnasium import spaces
from typing import Any, Tuple
from skrl.envs.wrappers.torch.base import Wrapper as SkrlWrapper
from genesis_forge.wrappers.wrapper import Wrapper as GenesisWrapper
[docs]
class SkrlEnvWapper(SkrlWrapper, GenesisWrapper):
    """
    A wrapper that makes your genesis forge environment compatible with the skrl training framework.

    Every call is forwarded to the wrapped environment stored in ``self._env``.
    That attribute is never assigned in this class; it is presumably set by
    skrl's base ``Wrapper.__init__`` (confirm). ``step`` reshapes the reward,
    termination and timeout vectors into column tensors before returning them.

    NOTE: the class name is misspelled ("Wapper"). It is kept unchanged for
    backward compatibility, and ``SkrlEnvWrapper`` is defined as an alias
    right after this class.
    """

    # NOTE(review): presumably tells genesis_forge that no further wrapper may
    # be stacked on top of this one (skrl wrapper should be outermost) — confirm.
    can_be_wrapped = False

    @property
    def action_space(self) -> spaces.Space:
        """The action space of the wrapped environment."""
        return self._env.action_space

    @property
    def observation_space(self) -> spaces.Space:
        """The observation space of the wrapped environment."""
        return self._env.observation_space

    def reset(self) -> Tuple[torch.Tensor, Any]:
        """Reset the wrapped environment.

        :return: Observation, info (whatever the wrapped env's ``reset`` returns)
        :rtype: torch.Tensor and any other info
        """
        return self._env.reset()

    def step(
        self, actions: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Any]:
        """Perform a step in the environment.

        The wrapped environment's ``timeouts`` output is returned in the
        "truncated" position of skrl's step tuple.

        :param actions: The actions to perform
        :type actions: torch.Tensor
        :return: Observation, reward, terminated, truncated, info
        :rtype: tuple of torch.Tensor and any other info
        """
        obs, rewards, terminations, timeouts, extras = self._env.step(actions)
        # Expand rewards, terminations and timeouts to the shape (num_envs, 1),
        # adding a trailing column dimension to each per-env vector.
        rewards = rewards.unsqueeze(1)
        terminations = terminations.unsqueeze(1)
        timeouts = timeouts.unsqueeze(1)
        return obs, rewards, terminations, timeouts, extras

    def state(self) -> torch.Tensor:
        """Get the environment state from the wrapped environment.

        :return: State
        :rtype: torch.Tensor
        """
        # Use ``self._env`` like every other forwarding method; the previous
        # ``self.env`` is never assigned here and may not exist on the instance.
        return self._env.state()

    def render(self, *args, **kwargs) -> Any:
        """
        Not implemented for Genesis Forge environments; always returns ``None``.
        """
        return None

    def close(self) -> None:
        """Close the wrapped environment."""
        return self._env.close()

    def build(self) -> None:
        """Build the wrapped environment."""
        self._env.build()


# Correctly spelled alias; the original (misspelled) name remains the canonical
# class so existing imports keep working.
SkrlEnvWrapper = SkrlEnvWapper