Skip to content

Commit

Permalink
polish(pu): polish comments in env_wrappers.py and ding_env_wrapper.py
Browse files Browse the repository at this point in the history
  • Loading branch information
puyuan1996 committed Oct 27, 2023
1 parent e6eea3d commit ef288f0
Show file tree
Hide file tree
Showing 2 changed files with 768 additions and 395 deletions.
134 changes: 130 additions & 4 deletions ding/envs/env/ding_env_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,27 @@


class DingEnvWrapper(BaseEnv):
"""
Overview:
This is a wrapper for the BaseEnv class, used to provide a consistent environment interface.
Interfaces:
__init__, reset, step, close, seed, random_action, _wrap_env, __repr__, create_collector_env_cfg,
create_evaluator_env_cfg, enable_save_replay, observation_space, action_space, reward_space, clone
"""

def __init__(self, env: gym.Env = None, cfg: dict = None, seed_api: bool = True, caller: str = 'collector') -> None:
"""
You can pass in either an env instance, or a config to create an env instance:
- An env instance: Parameter `env` must not be `None`, but should be the instance.
Do not support subprocess env manager; Thus usually used in simple env.
- A config to create an env instance: Parameter `cfg` dict must contain `env_id`.
Overview:
Initialize the DingEnvWrapper. Either an environment instance or a config to create the environment
instance should be passed in:
- An environment instance: The `env` parameter must not be `None`, but should be the instance.
It does not support subprocess environment manager. Thus, it is usually used in simple environments.
- A config to create an environment instance: The `cfg` parameter must contain `env_id`.
Arguments:
- env (:obj:`gym.Env`): An environment instance to be wrapped.
- cfg (:obj:`dict`): The configuration dictionary to create an environment instance.
- seed_api (:obj:`bool`): Whether to use seed API, default is True.
- caller (:obj:`str`): The caller of this method, default is 'collector'.
"""
self._env = None
self._raw_env = env
Expand Down Expand Up @@ -59,7 +73,14 @@ def __init__(self, env: gym.Env = None, cfg: dict = None, seed_api: bool = True,
self._replay_path = None

# override

def reset(self) -> None:
"""
Overview:
Resets the state of the environment. If the environment is not initialized, it will be created first.
Returns:
- obs (:obj:`Dict`): The new observation after reset.
"""
if not self._init_flag:
self._env = gym.make(self._cfg.env_id)
self._wrap_env(self._caller)
Expand Down Expand Up @@ -108,6 +129,12 @@ def reset(self) -> None:

# override
def close(self) -> None:
"""
Overview:
Clean up the environment by closing and deleting it.
This method should be called when the environment is no longer needed.
Failing to call this method can lead to memory leaks.
"""
try:
self._env.close()
del self._env
Expand All @@ -116,12 +143,27 @@ def close(self) -> None:

# override
def seed(self, seed: int, dynamic_seed: bool = True) -> None:
"""
Overview:
Set the seed for the environment.
Arguments:
- seed (:obj:`int`): The seed to set.
- dynamic_seed (:obj:`bool`): Whether to use dynamic seed, default is True.
"""
self._seed = seed
self._dynamic_seed = dynamic_seed
np.random.seed(self._seed)

# override
def step(self, action: Union[np.int64, np.ndarray]) -> BaseEnvTimestep:
"""
Overview:
Execute the given action in the environment, and return the timestep (observation, reward, done, info).
Arguments:
- action (:obj:`Union[np.int64, np.ndarray]`): The action to execute in the environment.
Returns:
- timestep (:obj:`BaseEnvTimestep`): The timestep after the action execution.
"""
action = self._judge_action_type(action)
if self._cfg.act_scale:
action = affine_transform(action, min_val=self._env.action_space.low, max_val=self._env.action_space.high)
Expand All @@ -137,6 +179,15 @@ def step(self, action: Union[np.int64, np.ndarray]) -> BaseEnvTimestep:
return BaseEnvTimestep(obs, rew, done, info)

def _judge_action_type(self, action: Union[np.ndarray, dict]) -> Union[np.ndarray, dict]:
"""
Overview:
Ensure the action taken by the agent is of the correct type.
This method is used to standardize different action types to a common format.
Arguments:
- action (Union[np.ndarray, dict]): The action taken by the agent.
Returns:
- action (Union[np.ndarray, dict]): The formatted action.
"""
if isinstance(action, int):
return action
elif isinstance(action, np.int64):
Expand All @@ -161,6 +212,12 @@ def _judge_action_type(self, action: Union[np.ndarray, dict]) -> Union[np.ndarra
)

def random_action(self) -> np.ndarray:
"""
Overview:
Return a random action from the action space of the environment.
Returns:
- action (:obj:`np.ndarray`): The random action.
"""
random_action = self.action_space.sample()
if isinstance(random_action, np.ndarray):
pass
Expand All @@ -177,6 +234,12 @@ def random_action(self) -> np.ndarray:
return random_action

def _wrap_env(self, caller: str = 'collector') -> None:
"""
Overview:
Wrap the environment according to the configuration.
Arguments:
- caller (:obj:`str`): The caller of this method, default is 'collector'.
"""
# wrapper_cfgs: Union[str, List]
wrapper_cfgs = self._cfg.env_wrapper
if isinstance(wrapper_cfgs, str):
Expand All @@ -191,42 +254,105 @@ def _wrap_env(self, caller: str = 'collector') -> None:
self._env = wrapper(self._env)

def __repr__(self) -> str:
"""
Overview:
Return the string representation of the instance.
Returns:
- str (:obj:`str`): The string representation of the instance.
"""
return "DI-engine Env({}), generated by DingEnvWrapper".format(self._cfg.env_id)

@staticmethod
def create_collector_env_cfg(cfg: dict) -> List[dict]:
"""
Overview:
Create a list of environment configuration for collectors based on the input configuration.
Arguments:
- cfg (:obj:`dict`): The input configuration dictionary.
Returns:
- env_cfgs (:obj:`List[dict]`): The list of environment configurations for collectors.
"""
actor_env_num = cfg.pop('collector_env_num')
cfg = copy.deepcopy(cfg)
cfg.is_train = True
return [cfg for _ in range(actor_env_num)]

@staticmethod
def create_evaluator_env_cfg(cfg: dict) -> List[dict]:
"""
Overview:
Create a list of environment configuration for evaluators based on the input configuration.
Arguments:
- cfg (:obj:`dict`): The input configuration dictionary.
Returns:
- env_cfgs (:obj:`List[dict]`): The list of environment configurations for evaluators.
"""
evaluator_env_num = cfg.pop('evaluator_env_num')
cfg = copy.deepcopy(cfg)
cfg.is_train = False
return [cfg for _ in range(evaluator_env_num)]

def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
"""
Overview:
Enable the save replay functionality. The replay will be saved at the specified path.
Arguments:
- replay_path (:obj:`Optional[str]`): The path to save the replay, default is None.
"""
if replay_path is None:
replay_path = './video'
self._replay_path = replay_path

@property
def observation_space(self) -> gym.spaces.Space:
"""
Overview:
Return the observation space of the wrapped environment.
The observation space represents the range and shape of possible observations
that the environment can provide to the agent.
Note:
If the data type of the observation space is float64, it's converted to float32
for better compatibility with most machine learning libraries.
Returns:
- observation_space (gym.spaces.Space): The observation space of the environment.
"""
if self._observation_space.dtype == np.float64:
self._observation_space.dtype = np.float32
return self._observation_space

@property
def action_space(self) -> gym.spaces.Space:
"""
Overview:
Return the action space of the wrapped environment.
The action space represents the range and shape of possible actions
that the agent can take in the environment.
Returns:
- action_space (gym.spaces.Space): The action space of the environment.
"""
return self._action_space

@property
def reward_space(self) -> gym.spaces.Space:
"""
Overview:
Return the reward space of the wrapped environment.
The reward space represents the range and shape of possible rewards
that the agent can receive as a result of its actions.
Returns:
- reward_space (gym.spaces.Space): The reward space of the environment.
"""
return self._reward_space

def clone(self, caller: str = 'collector') -> BaseEnv:
"""
Overview:
Clone the current environment wrapper, creating a new environment with the same settings.
Arguments:
- caller (str): A string representing the caller of this method, default is 'collector'.
Returns:
- DingEnvWrapper: A new instance of the environment with the same settings.
"""
try:
spec = copy.deepcopy(self._raw_env.spec)
raw_env = CloudPickleWrapper(self._raw_env)
Expand Down
Loading

0 comments on commit ef288f0

Please sign in to comment.