LukeDarlow's picture
Welcome to the CTM. This is the first commit of the public repo. Enjoy!
68b32f4
raw
history blame
1.27 kB
import gymnasium as gym
import numpy as np
class MaskVelocityWrapper(gym.Wrapper):
    """
    Wrapper that zeroes out selected entries of every observation.

    Each observation returned by ``reset`` and ``step`` is multiplied
    element-wise by a fixed 0/1 mask chosen from the wrapped environment's
    spec id. The masked entries are the ones the Gymnasium documentation
    lists as velocities:

    * ``CartPole-v1``: entries 1 and 3 are zeroed.
    * ``Acrobot-v1``: entries 4 and 5 are zeroed.

    Any other environment id raises ``NotImplementedError`` the first time
    an observation is produced (i.e. on ``reset``).
    """

    # Masks are constant, so build them once instead of on every step.
    # They are only ever read (``observation * mask`` allocates a new array).
    _CARTPOLE_MASK = np.array([1, 0, 1, 0], dtype="float32")
    _ACROBOT_MASK = np.array([1, 1, 1, 1, 0, 0], dtype="float32")

    def __init__(self, env: gym.Env) -> None:
        """Wrap ``env``; no extra state is stored on the wrapper."""
        super().__init__(env)

    def reset(self, **kwargs):
        """Reset the wrapped env and return ``(masked_observation, info)``."""
        obs, info = self.env.reset(**kwargs)
        return self._apply_velocity_mask(obs), info

    def step(self, action):
        """
        Step the wrapped env and mask the resulting observation.

        Returns the usual 5-tuple ``(obs, reward, terminated, truncated,
        info)``; only ``obs`` is modified.
        """
        obs, reward, terminated, truncated, info = self.env.step(action)
        return self._apply_velocity_mask(obs), reward, terminated, truncated, info

    def _apply_velocity_mask(self, observation):
        """
        Dispatch to the mask matching the wrapped environment's spec id.

        Raises:
            NotImplementedError: if the environment has no spec or its id
                is not one of the supported environments.
        """
        # ``spec`` can be None for environments not created via a registry
        # entry; treat that as "unsupported" rather than failing with an
        # opaque AttributeError.
        spec = self.env.spec
        gym_id = spec.id if spec is not None else None

        if gym_id == "CartPole-v1":
            return self._apply_velocity_mask_cartpole(observation)
        if gym_id == "Acrobot-v1":
            return self._apply_velocity_mask_acrobot(observation)

        raise NotImplementedError(
            f"MaskVelocityWrapper has no velocity mask for environment "
            f"{gym_id!r}; supported ids are 'CartPole-v1' and 'Acrobot-v1'."
        )

    def _apply_velocity_mask_cartpole(self, observation):
        """Return ``observation`` with entries 1 and 3 set to zero."""
        return observation * self._CARTPOLE_MASK

    def _apply_velocity_mask_acrobot(self, observation):
        """Return ``observation`` with entries 4 and 5 set to zero."""
        return observation * self._ACROBOT_MASK