diff --git a/examples/cartpole/callbacks.yaml b/examples/cartpole/callbacks.yaml index 383736a7..810ff445 100644 --- a/examples/cartpole/callbacks.yaml +++ b/examples/cartpole/callbacks.yaml @@ -2,7 +2,18 @@ callbacks: - id: "CheckpointCallback" args: { "save_freq": 500, - "save_path": "./checkpoints/", + "save_path": "./results/checkpoints/", "name_prefix": "ppo", "save_replay_buffer": True + } + - id: "EvalCallback" + args: { + "eval_env": {"id": "CartPole-v1","env_num":4}, + "n_eval_episodes": 4, + "eval_freq":500, + "log_path": "./results/eval_log_path", + "best_model_save_path": "./results/best_model/", + "deterministic": True, + "render": True, + "asynchronous": True, } \ No newline at end of file diff --git a/examples/cartpole/train_ppo.py b/examples/cartpole/train_ppo.py index 8b61a923..739f804d 100644 --- a/examples/cartpole/train_ppo.py +++ b/examples/cartpole/train_ppo.py @@ -49,4 +49,4 @@ def evaluation(agent): if __name__ == "__main__": agent = train() - # evaluation(agent) + evaluation(agent) diff --git a/openrl/envs/common/build_envs.py b/openrl/envs/common/build_envs.py index 00a7b0ac..fb06964a 100644 --- a/openrl/envs/common/build_envs.py +++ b/openrl/envs/common/build_envs.py @@ -40,6 +40,7 @@ def _make_env() -> Env: env = wrapper(env) else: raise NotImplementedError + return env return _make_env diff --git a/openrl/envs/common/registration.py b/openrl/envs/common/registration.py index 23cb8399..fc430aba 100644 --- a/openrl/envs/common/registration.py +++ b/openrl/envs/common/registration.py @@ -18,11 +18,11 @@ from typing import Callable, Optional import gymnasium as gym -from gymnasium import Env import openrl from openrl.envs.vec_env import ( AsyncVectorEnv, + BaseVecEnv, RewardWrapper, SyncVectorEnv, VecMonitorWrapper, @@ -40,7 +40,7 @@ def make( render_mode: Optional[str] = None, make_custom_envs: Optional[Callable] = None, **kwargs, -) -> Env: +) -> BaseVecEnv: if render_mode in [None, "human", "rgb_array"]: convert_render_mode = render_mode elif render_mode in ["group_human", "group_rgb_array"]: diff --git a/openrl/envs/connect3/__init__.py b/openrl/envs/connect3/__init__.py index 771dc9bb..00beaed0 100644 --- a/openrl/envs/connect3/__init__.py +++ b/openrl/envs/connect3/__init__.py @@ -30,10 +30,7 @@ def make_connect3_envs( render_mode: Optional[Union[str, List[str]]] = None, **kwargs, ) -> List[Callable[[], Env]]: - from openrl.envs.wrappers import ( - RemoveTruncated, - Single2MultiAgentWrapper, - ) + from openrl.envs.wrappers import RemoveTruncated, Single2MultiAgentWrapper env_wrappers = [ Single2MultiAgentWrapper, @@ -48,5 +45,3 @@ def make_connect3_envs( **kwargs, ) return env_fns - - diff --git a/openrl/envs/connect3/connect3_env.py b/openrl/envs/connect3/connect3_env.py index c73fe90d..c401f556 100644 --- a/openrl/envs/connect3/connect3_env.py +++ b/openrl/envs/connect3/connect3_env.py @@ -1,7 +1,7 @@ from typing import Any, Optional -import numpy as np import gymnasium as gym +import numpy as np from gymnasium import Env, spaces @@ -12,8 +12,7 @@ def make( ) -> Env: # create Connect3 environment from id if id == "connect3": - env = Connect3Env(env_name=id, - args=kwargs) + env = Connect3Env(env_name=id, args=kwargs) return env @@ -21,7 +20,12 @@ def make( def check_if_win(state, check_row_pos, check_col_pos, all_args): def check_if_win_direction(now_state, direction, row_pos, col_pos, args): def check_if_valid(x_pos, y_pos): - return x_pos >= 0 and x_pos <= (args["row"] - 1) and y_pos >= 0 and y_pos <= (args["col"] - 1) + return ( + x_pos >= 0 + and x_pos 
<= (args["row"] - 1) + and y_pos >= 0 + and y_pos <= (args["col"] - 1) + ) check_who = now_state[row_pos][col_pos] counting = 1 @@ -29,7 +33,10 @@ def check_if_valid(x_pos, y_pos): while True: new_row_pos = row_pos + bias_num * direction[0] new_col_pos = col_pos + bias_num * direction[1] - if not check_if_valid(new_row_pos, new_col_pos) or now_state[new_row_pos][new_col_pos] != check_who: + if ( + not check_if_valid(new_row_pos, new_col_pos) + or now_state[new_row_pos][new_col_pos] != check_who + ): break else: counting += 1 @@ -38,7 +45,10 @@ def check_if_valid(x_pos, y_pos): while True: new_row_pos = row_pos + bias_num * direction[0] new_col_pos = col_pos + bias_num * direction[1] - if not check_if_valid(new_row_pos, new_col_pos) or now_state[new_row_pos][new_col_pos] != check_who: + if ( + not check_if_valid(new_row_pos, new_col_pos) + or now_state[new_row_pos][new_col_pos] != check_who + ): break else: counting += 1 @@ -50,7 +60,9 @@ def check_if_valid(x_pos, y_pos): directions = [(1, 0), (0, 1), (1, 1), (1, -1)] # 横 竖 右下 右上 for direction in directions: - if check_if_win_direction(state, direction, check_row_pos, check_col_pos, all_args): + if check_if_win_direction( + state, direction, check_row_pos, check_col_pos, all_args + ): return True return False @@ -69,36 +81,40 @@ def __init__(self, env_name, args): obs_space_low = np.zeros(obs_space_dim) - 1e6 obs_space_high = np.zeros(obs_space_dim) + 1e6 - obs_space_type = 'float64' + obs_space_type = "float64" sobs_space_dim = obs_space_dim * args["num_agents"] sobs_space_low = np.zeros(sobs_space_dim) - 1e6 sobs_space_high = np.zeros(sobs_space_dim) + 1e6 if args["num_agents"] > 1: - self.action_space = [spaces.Discrete(self.action_size) for _ in range(args["num_agents"])] - self.observation_space = [spaces.Box( - low=obs_space_low, - high=obs_space_high, - dtype=obs_space_type) for _ in range(args["num_agents"])] - self.share_observation_space = [spaces.Box( - low=sobs_space_low, - high=sobs_space_high, - dtype=obs_space_type) for _ in range(args["num_agents"])] + self.action_space = [ + spaces.Discrete(self.action_size) for _ in range(args["num_agents"]) + ] + self.observation_space = [ + spaces.Box(low=obs_space_low, high=obs_space_high, dtype=obs_space_type) + for _ in range(args["num_agents"]) + ] + self.share_observation_space = [ + spaces.Box( + low=sobs_space_low, high=sobs_space_high, dtype=obs_space_type + ) + for _ in range(args["num_agents"]) + ] else: self.action_space = spaces.Discrete(self.action_size) self.observation_space = spaces.Box( - low=obs_space_low, - high=obs_space_high, - dtype=obs_space_type) + low=obs_space_low, high=obs_space_high, dtype=obs_space_type + ) self.share_observation_space = spaces.Box( - low=sobs_space_low, - high=sobs_space_high, - dtype=obs_space_type) + low=sobs_space_low, high=sobs_space_high, dtype=obs_space_type + ) def step(self, action, is_enemy=True): # 传入action为0~8的数字 row_pos, col_pos = action // self.col, action % self.col - assert self.state[row_pos][col_pos] == 0, "({}, {}) pos has already be taken".format(row_pos, col_pos) + assert ( + self.state[row_pos][col_pos] == 0 + ), "({}, {}) pos has already be taken".format(row_pos, col_pos) self.state[row_pos][col_pos] = 2 if is_enemy else 1 done, have_winner = False, False @@ -110,28 +126,28 @@ def step(self, action, is_enemy=True): if done: if have_winner: reward = (-1) * self.reward if is_enemy else self.reward - winner = 'enemy' if is_enemy else 'self' + winner = "enemy" if is_enemy else "self" else: - winner = 'tie' + winner = "tie" 
reward = 0 else: reward = 0 - winner = 'no' - info = {'who_win': winner} + winner = "no" + info = {"who_win": winner} return self.state.flatten().copy(), reward, done, False, info def check_if_finish(self): return (self.state == 0).sum() == 0 - def reset(self, seed=None, options=None, set_who_first=None): + def reset(self, seed=None, options=None, set_who_first=None): self.state = np.zeros([self.row, self.col]) # 0无棋子,1我方棋子,2敌方棋子 if set_who_first is not None: who_first = set_who_first else: if np.random.random() > 0.5: - who_first = 'enemy' + who_first = "enemy" else: - who_first = 'self' + who_first = "self" obs = self.state.flatten().copy() # return obs, {"who_first": who_first} return obs, {} @@ -146,13 +162,8 @@ def close(self): pass -if __name__ == '__main__': - args = { - "row": 3, - "col": 3, - "num_to_win": 3, - "num_agents": 1 - } +if __name__ == "__main__": + args = {"row": 3, "col": 3, "num_to_win": 3, "num_agents": 1} env = Connect3Env(env_name="connect3", args=args) obs, info = env.reset() obs, reward, done, _, info = env.step(1, is_enemy=True) diff --git a/openrl/envs/vec_env/async_venv.py b/openrl/envs/vec_env/async_venv.py index bc07e50b..3fb27f83 100644 --- a/openrl/envs/vec_env/async_venv.py +++ b/openrl/envs/vec_env/async_venv.py @@ -6,7 +6,7 @@ from enum import Enum from multiprocessing import Queue from multiprocessing.connection import Connection -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union import gymnasium as gym import numpy as np @@ -21,7 +21,7 @@ from gymnasium.vector.utils import CloudpickleWrapper, clear_mpi_env_vars from numpy.typing import NDArray -from openrl.envs.vec_env.base_venv import BaseVecEnv +from openrl.envs.vec_env.base_venv import BaseVecEnv, VecEnvIndices from openrl.envs.vec_env.utils.numpy_utils import ( concatenate, create_empty_array, @@ -32,6 +32,8 @@ read_from_shared_memory, write_to_shared_memory, ) +from openrl.envs.wrappers.base_wrapper import BaseWrapper +from openrl.envs.wrappers.util import is_wrapped class AsyncState(Enum): @@ -608,19 +610,65 @@ def call_fetch(self, timeout: Union[int, float, None] = None) -> list: return results - def call(self, name: str, *args, **kwargs) -> List[Any]: - """Call a method, or get a property, from each parallel environment. + def exec_func_send(self, func: Callable, indices, *args, **kwargs): + """Calls the method with name asynchronously and apply args and kwargs to the method. Args: - name (str): Name of the method or property to call. + func: a function. + indices: Indices of the environments to call the method on. *args: Arguments to apply to the method call. **kwargs: Keyword arguments to apply to the method call. + Raises: + ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called). + AlreadyPendingCallError: Calling `call_send` while waiting for a pending call to complete + """ + self._assert_is_running() + if self._state != AsyncState.DEFAULT: + raise AlreadyPendingCallError( + ( + "Calling `exec_func_send` while waiting " + f"for a pending call to `{self._state.value}` to complete." + ), + str(self._state.value), + ) + + for pipe in self.parent_pipes: + pipe.send(("_func_exec", (func, indices, args, kwargs))) + self._state = AsyncState.WAITING_CALL + + def exec_func_fetch(self, timeout: Union[int, float, None] = None) -> list: + """Calls all parent pipes and waits for the results. 
+ + Args: + timeout: Number of seconds before the call to `step_fetch` times out. + If `None` (default), the call to `step_fetch` never times out. + Returns: List of the results of the individual calls to the method or property for each environment. + + Raises: + NoAsyncCallError: Calling `call_fetch` without any prior call to `call_send`. + TimeoutError: The call to `call_fetch` has timed out after timeout second(s). """ - self.call_send(name, *args, **kwargs) - return self.call_fetch() + self._assert_is_running() + if self._state != AsyncState.WAITING_CALL: + raise NoAsyncCallError( + "Calling `exec_func_fetch` without any prior call to `exec_func_send`.", + AsyncState.WAITING_CALL.value, + ) + + if not self._poll(timeout): + self._state = AsyncState.DEFAULT + raise mp.TimeoutError( + f"The call to `call_fetch` has timed out after {timeout} second(s)." + ) + + results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes]) + self._raise_if_errors(successes) + self._state = AsyncState.DEFAULT + + return results def get_attr(self, name: str): """Get a property from each parallel environment. @@ -670,6 +718,16 @@ def set_attr(self, name: str, values: Union[List[Any], Tuple[Any], object]): _, successes = zip(*[pipe.recv() for pipe in self.parent_pipes]) self._raise_if_errors(successes) + # def env_is_wrapped( + # self, wrapper_class: Type[BaseWrapper], indices: VecEnvIndices = None + # ) -> List[bool]: + # """Check if worker environments are wrapped with a given wrapper""" + # indices = self._get_indices(indices) + # results = self.exec_func( + # is_wrapped, indices=indices, wrapper_class=wrapper_class + # ) + # return [results[i] for i in indices] + def _worker( index: int, @@ -771,6 +829,15 @@ def prepare_obs(observation): True, ) ) + elif command == "_func_exec": + function, indices, args, kwargs = data + if index in indices: + if callable(function): + pipe.send((function(env, *args, **kwargs), True)) + else: + pipe.send((function, True)) + else: + pipe.send((None, True)) elif command == "_call": name, args, kwargs = data if name in ["reset", "step", "seed", "close"]: @@ -783,6 +850,7 @@ def prepare_obs(observation): pipe.send((function(*args, **kwargs), True)) else: pipe.send((function, True)) + elif command == "_setattr": name, value = data setattr(env, name, value) diff --git a/openrl/envs/vec_env/base_venv.py b/openrl/envs/vec_env/base_venv.py index 1ec66ad5..c6f618ae 100644 --- a/openrl/envs/vec_env/base_venv.py +++ b/openrl/envs/vec_env/base_venv.py @@ -18,16 +18,22 @@ import sys import warnings from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Sequence, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Type, Union import gymnasium as gym import numpy as np from openrl.envs.vec_env.utils.numpy_utils import single_random_action from openrl.envs.vec_env.utils.util import prepare_available_actions, tile_images +from openrl.envs.wrappers.base_wrapper import BaseWrapper +from openrl.envs.wrappers.util import is_wrapped IN_COLAB = "google.colab" in sys.modules +# Define type aliases here to avoid circular import +# Used when we want to access one or more VecEnv +VecEnvIndices = Union[None, int, Iterable[int]] + class BaseVecEnv( ABC, @@ -233,6 +239,52 @@ def call(self, name: str, *args, **kwargs) -> List[Any]: self.call_send(name, *args, **kwargs) return self.call_fetch() + def exec_func_send(self, func: Callable, indices, *args, **kwargs): + """Calls the method with name asynchronously and apply args and kwargs to 
the method. + + Args: + func: a function. + indices: Indices of the environments to call the method on. + *args: Arguments to apply to the method call. + **kwargs: Keyword arguments to apply to the method call. + + Raises: + ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called). + AlreadyPendingCallError: Calling `call_send` while waiting for a pending call to complete + """ + + def exec_func_fetch(self, timeout: Union[int, float, None] = None) -> list: + """Calls all parent pipes and waits for the results. + + Args: + timeout: Number of seconds before the call to `step_fetch` times out. + If `None` (default), the call to `step_fetch` never times out. + + Returns: + List of the results of the individual calls to the method or property for each environment. + + Raises: + NoAsyncCallError: Calling `call_fetch` without any prior call to `call_send`. + TimeoutError: The call to `call_fetch` has timed out after timeout second(s). + """ + + def exec_func( + self, func: Callable, indices: List[int], *args, **kwargs + ) -> List[Any]: + """Call a method, or get a property, from each parallel environment. + + Args: + func : Name of the method to call. + indices: Indices of the environments to call the method on. + *args: Arguments to apply to the method call. + **kwargs: Keyword arguments to apply to the method call. + + Returns: + List of the results of the individual calls to the method or property for each environment. + """ + self.exec_func_send(func, indices, *args, **kwargs) + return self.exec_func_fetch() + def get_attr(self, name: str): """Get a property from each parallel environment. @@ -278,3 +330,26 @@ def random_action(self, infos: Optional[List[Dict[str, Any]]] = None): for env_index in range(self.parallel_env_num) ] ) + + def env_is_wrapped( + self, wrapper_class: Type[BaseWrapper], indices: VecEnvIndices = None + ) -> List[bool]: + """Check if worker environments are wrapped with a given wrapper""" + indices = self._get_indices(indices) + results = self.exec_func( + is_wrapped, indices=indices, wrapper_class=wrapper_class + ) + return [results[i] for i in indices] + + def _get_indices(self, indices: VecEnvIndices) -> Iterable[int]: + """ + Convert a flexibly-typed reference to environment indices to an implied list of indices. + + :param indices: refers to indices of envs. + :return: the implied list of indices. 
+ """ + if indices is None: + indices = range(self.parallel_env_num) + elif isinstance(indices, int): + indices = [indices] + return indices diff --git a/openrl/envs/vec_env/sync_venv.py b/openrl/envs/vec_env/sync_venv.py index af8b18eb..85066973 100644 --- a/openrl/envs/vec_env/sync_venv.py +++ b/openrl/envs/vec_env/sync_venv.py @@ -16,19 +16,20 @@ """""" from copy import deepcopy -from typing import Any, Callable, Iterable, List, Optional, Sequence, Union +from typing import Any, Callable, Iterable, List, Optional, Sequence, Type, Union import numpy as np from gymnasium import Env from gymnasium.core import ActType from gymnasium.spaces import Space -from openrl.envs.vec_env.base_venv import BaseVecEnv +from openrl.envs.vec_env.base_venv import BaseVecEnv, VecEnvIndices from openrl.envs.vec_env.utils.numpy_utils import ( concatenate, create_empty_array, iterate_action, ) +from openrl.envs.wrappers.base_wrapper import BaseWrapper class SyncVectorEnv(BaseVecEnv): @@ -273,6 +274,29 @@ def env_name(self): else: return self.envs[0].unwrapped.spec.id + def exec_func(self, func: Callable, indices: List[int], *args, **kwargs) -> tuple: + """Calls the method with name and applies args and kwargs. + + Args: + func: The method name + *args: The method args + **kwargs: The method kwargs + + Returns: + Tuple of results + """ + results = [] + for i, env in enumerate(self.envs): + if i in indices: + if callable(func): + results.append(func(env, *args, **kwargs)) + else: + results.append(func) + else: + results.append(None) + + return tuple(results) + def call(self, name, *args, **kwargs) -> tuple: """Calls the method with name and applies args and kwargs. diff --git a/openrl/envs/vec_env/wrappers/base_wrapper.py b/openrl/envs/vec_env/wrappers/base_wrapper.py index cbb343a0..bd1e16d4 100644 --- a/openrl/envs/vec_env/wrappers/base_wrapper.py +++ b/openrl/envs/vec_env/wrappers/base_wrapper.py @@ -16,20 +16,32 @@ """""" - -from typing import Any, Dict, Optional, Sequence, SupportsFloat, Tuple, TypeVar, Union +from abc import ABC +from typing import ( + Any, + Dict, + List, + Optional, + Sequence, + SupportsFloat, + Tuple, + Type, + TypeVar, + Union, +) import numpy as np from gymnasium import spaces from gymnasium.core import ActType, ObsType, WrapperActType, WrapperObsType from gymnasium.utils import seeding -from openrl.envs.vec_env.base_venv import BaseVecEnv +from openrl.envs.vec_env.base_venv import BaseVecEnv, VecEnvIndices +from openrl.envs.wrappers import BaseWrapper ArrayType = TypeVar("ArrayType") -class VecEnvWrapper(BaseVecEnv): +class VecEnvWrapper(BaseVecEnv, ABC): """Wraps the vectorized environment to allow a modular transformation. This class is the base class for all wrappers for vectorized environments. The subclass @@ -179,6 +191,11 @@ def np_random(self) -> np.random.Generator: def np_random(self, value: np.random.Generator): self._np_random = value + def env_is_wrapped( + self, wrapper_class: Type[BaseWrapper], indices: VecEnvIndices = None + ) -> List[bool]: + return self.env.env_is_wrapped(wrapper_class, indices=indices) + class VectorObservationWrapper(VecEnvWrapper): """Wraps the vectorized environment to allow a modular transformation of the observation. 
Equivalent to :class:`gym.ObservationWrapper` for vectorized environments.""" diff --git a/openrl/envs/wrappers/util.py b/openrl/envs/wrappers/util.py index 119de315..a2e185f0 100644 --- a/openrl/envs/wrappers/util.py +++ b/openrl/envs/wrappers/util.py @@ -16,11 +16,14 @@ """""" -from typing import Any +from typing import Any, Optional, Type +import gymnasium as gym import numpy as np from gymnasium.spaces.box import Box +from openrl.envs.wrappers.base_wrapper import BaseWrapper + def nest_expand_dim(input: Any) -> Any: if isinstance(input, (np.ndarray, float, int)): @@ -35,3 +38,32 @@ def nest_expand_dim(input: Any) -> Any: return [input] else: raise NotImplementedError("Not support type: {}".format(type(input))) + + +def unwrap_wrapper( + env: gym.Env, wrapper_class: Type[BaseWrapper] +) -> Optional[BaseWrapper]: + """ + Retrieve a ``BaseWrapper`` object by recursively searching. + + :param env: Environment to unwrap + :param wrapper_class: Wrapper to look for + :return: Environment unwrapped till ``wrapper_class`` if it has been wrapped with it + """ + env_tmp = env + while isinstance(env_tmp, BaseWrapper): + if isinstance(env_tmp, wrapper_class): + return env_tmp + env_tmp = env_tmp.env + return None + + +def is_wrapped(env: gym.Env, wrapper_class: Type[BaseWrapper]) -> bool: + """ + Check if a given environment has been wrapped with a given wrapper. + + :param env: Environment to check + :param wrapper_class: Wrapper class to look for + :return: True if environment has been wrapped with ``wrapper_class``. + """ + return unwrap_wrapper(env, wrapper_class) is not None diff --git a/openrl/runners/common/rl_agent.py b/openrl/runners/common/rl_agent.py index 2b7587c2..45f0521b 100644 --- a/openrl/runners/common/rl_agent.py +++ b/openrl/runners/common/rl_agent.py @@ -26,13 +26,13 @@ from openrl.modules.common import BaseNet from openrl.runners.common.base_agent import BaseAgent, SelfAgent +from openrl.utils.callbacks import CallbackFactory from openrl.utils.callbacks.callbacks import ( BaseCallback, CallbackList, ConvertCallback, ProgressBarCallback, ) -from openrl.utils.callbacks.callbacks_factory import CallbackFactory from openrl.utils.type_aliases import MaybeCallback diff --git a/openrl/utils/callbacks/__init__.py b/openrl/utils/callbacks/__init__.py index 663cfed7..1d13979e 100644 --- a/openrl/utils/callbacks/__init__.py +++ b/openrl/utils/callbacks/__init__.py @@ -15,3 +15,8 @@ # limitations under the License. """""" + + +from openrl.utils.callbacks.callbacks_factory import CallbackFactory + +__all__ = ["CallbackFactory"] diff --git a/openrl/utils/callbacks/callbacks.py b/openrl/utils/callbacks/callbacks.py index e496919a..bce5f9b4 100644 --- a/openrl/utils/callbacks/callbacks.py +++ b/openrl/utils/callbacks/callbacks.py @@ -1,6 +1,5 @@ # Modified from https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/callbacks.py -import os import warnings from abc import ABC, abstractmethod from typing import Any, Callable, Dict, List, Optional, Union @@ -21,9 +20,8 @@ # if the progress bar is used tqdm = None -from openrl.envs.vec_env import BaseVecEnv, SyncVectorEnv +from openrl.envs.vec_env import BaseVecEnv from openrl.runners.common.base_agent import BaseAgent -from openrl.utils.evaluation import evaluate_policy class BaseCallback(ABC): @@ -258,209 +256,6 @@ def _on_step(self) -> bool: return True -class EvalCallback(EventCallback): - """ - Callback for evaluating an agent. - - .. 
warning:: - - When using multiple environments, each call to ``env.step()`` - will effectively correspond to ``n_envs`` steps. - To account for that, you can use ``eval_freq = max(eval_freq // n_envs, 1)`` - - :param eval_env: The environment used for initialization - :param callback_on_new_best: Callback to trigger - when there is a new best model according to the ``mean_reward`` - :param callback_after_eval: Callback to trigger after every evaluation - :param n_eval_episodes: The number of episodes to test the agent - :param eval_freq: Evaluate the agent every ``eval_freq`` call of the callback. - :param log_path: Path to a folder where the evaluations (``evaluations.npz``) - will be saved. It will be updated at each evaluation. - :param best_model_save_path: Path to a folder where the best model - according to performance on the eval env will be saved. - :param deterministic: Whether the evaluation should - use a stochastic or deterministic actions. - :param render: Whether to render or not the environment during evaluation - :param verbose: Verbosity level: 0 for no output, 1 for indicating information about evaluation results - :param warn: Passed to ``evaluate_policy`` (warns if ``eval_env`` has not been - wrapped with a Monitor wrapper) - """ - - def __init__( - self, - eval_env: Union[gym.Env, BaseVecEnv], - callback_on_new_best: Optional[BaseCallback] = None, - callback_after_eval: Optional[BaseCallback] = None, - n_eval_episodes: int = 5, - eval_freq: int = 10000, - log_path: Optional[str] = None, - best_model_save_path: Optional[str] = None, - deterministic: bool = True, - render: bool = False, - verbose: int = 1, - warn: bool = True, - ): - super().__init__(callback_after_eval, verbose=verbose) - - self.callback_on_new_best = callback_on_new_best - if self.callback_on_new_best is not None: - # Give access to the parent - self.callback_on_new_best.parent = self - - self.n_eval_episodes = n_eval_episodes - self.eval_freq = eval_freq - self.best_mean_reward = -np.inf - self.last_mean_reward = -np.inf - self.deterministic = deterministic - self.render = render - self.warn = warn - - # Convert to BaseVecEnv for consistency - if not isinstance(eval_env, BaseVecEnv): - eval_env = SyncVectorEnv([lambda: eval_env]) - - self.eval_env = eval_env - self.best_model_save_path = best_model_save_path - # Logs will be written in ``evaluations.npz`` - if log_path is not None: - log_path = os.path.join(log_path, "evaluations") - self.log_path = log_path - self.evaluations_results = [] - self.evaluations_timesteps = [] - self.evaluations_length = [] - # For computing success rate - self._is_success_buffer = [] - self.evaluations_successes = [] - - def _init_callback(self) -> None: - # Does not work in some corner cases, where the wrapper is not the same - if not isinstance(self.training_env, type(self.eval_env)): - warnings.warn( - "Training and eval env are not of the same type" - f"{self.training_env} != {self.eval_env}" - ) - - # Create folders if needed - if self.best_model_save_path is not None: - os.makedirs(self.best_model_save_path, exist_ok=True) - if self.log_path is not None: - os.makedirs(os.path.dirname(self.log_path), exist_ok=True) - - # Init callback called on new best model - if self.callback_on_new_best is not None: - self.callback_on_new_best.init_callback(self.agent) - - def _log_success_callback( - self, locals_: Dict[str, Any], globals_: Dict[str, Any] - ) -> None: - """ - Callback passed to the ``evaluate_policy`` function - in order to log the success rate (when 
applicable), - for instance when using HER. - - :param locals_: - :param globals_: - """ - info = locals_["info"] - - if locals_["done"]: - maybe_is_success = info.get("is_success") - if maybe_is_success is not None: - self._is_success_buffer.append(maybe_is_success) - - def _on_step(self) -> bool: - continue_training = True - - if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0: - # Reset success rate buffer - self._is_success_buffer = [] - - episode_rewards, episode_lengths = evaluate_policy( - self.agent, - self.eval_env, - n_eval_episodes=self.n_eval_episodes, - render=self.render, - deterministic=self.deterministic, - return_episode_rewards=True, - warn=self.warn, - callback=self._log_success_callback, - ) - - if self.log_path is not None: - self.evaluations_timesteps.append(self.num_time_steps) - self.evaluations_results.append(episode_rewards) - self.evaluations_length.append(episode_lengths) - - kwargs = {} - # Save success log if present - if len(self._is_success_buffer) > 0: - self.evaluations_successes.append(self._is_success_buffer) - kwargs = dict(successes=self.evaluations_successes) - - np.savez( - self.log_path, - timesteps=self.evaluations_timesteps, - results=self.evaluations_results, - ep_lengths=self.evaluations_length, - **kwargs, - ) - - mean_reward, std_reward = np.mean(episode_rewards), np.std(episode_rewards) - mean_ep_length, std_ep_length = np.mean(episode_lengths), np.std( - episode_lengths - ) - self.last_mean_reward = mean_reward - - if self.verbose >= 1: - print( - f"Eval num_timesteps={self.num_time_steps}, " - f"episode_reward={mean_reward:.2f} +/- {std_reward:.2f}" - ) - print(f"Episode length: {mean_ep_length:.2f} +/- {std_ep_length:.2f}") - # Add to current Logger - self.logger.record("eval/mean_reward", float(mean_reward)) - self.logger.record("eval/mean_ep_length", mean_ep_length) - - if len(self._is_success_buffer) > 0: - success_rate = np.mean(self._is_success_buffer) - if self.verbose >= 1: - print(f"Success rate: {100 * success_rate:.2f}%") - self.logger.record("eval/success_rate", success_rate) - - # Dump log so the evaluation results are printed with the correct timestep - self.logger.record( - "time/total_timesteps", self.num_time_steps, exclude="tensorboard" - ) - self.logger.dump(self.num_time_steps) - - if mean_reward > self.best_mean_reward: - if self.verbose >= 1: - print("New best mean reward!") - if self.best_model_save_path is not None: - self.agent.save( - os.path.join(self.best_model_save_path, "best_model") - ) - self.best_mean_reward = mean_reward - # Trigger callback on new best model, if needed - if self.callback_on_new_best is not None: - continue_training = self.callback_on_new_best.on_step() - - # Trigger callback after every evaluation, if needed - if self.callback is not None: - continue_training = continue_training and self._on_event() - - return continue_training - - def update_child_locals(self, locals_: Dict[str, Any]) -> None: - """ - Update the references to the local variables. 
- - :param locals_: the local variables during rollout collection - """ - if self.callback: - self.callback.update_locals(locals_) - - class StopTrainingOnRewardThreshold(BaseCallback): """ Stop the training once a threshold in episodic reward diff --git a/openrl/utils/callbacks/callbacks_factory.py b/openrl/utils/callbacks/callbacks_factory.py index 5ce512cb..a49417e4 100644 --- a/openrl/utils/callbacks/callbacks_factory.py +++ b/openrl/utils/callbacks/callbacks_factory.py @@ -2,9 +2,11 @@ from openrl.utils.callbacks.callbacks import BaseCallback, CallbackList from openrl.utils.callbacks.checkpoint_callback import CheckpointCallback +from openrl.utils.callbacks.eval_callback import EvalCallback callbacks_dict = { "CheckpointCallback": CheckpointCallback, + "EvalCallback": EvalCallback, } @@ -17,6 +19,8 @@ def get_callbacks( callbacks = [callbacks] callbacks_list = [] for callback in callbacks: + if callback["id"] not in callbacks_dict: + raise ValueError(f"Callback {callback['id']} not found") callbacks_list.append(callbacks_dict[callback["id"]](**callback["args"])) return CallbackList(callbacks_list) diff --git a/openrl/utils/callbacks/checkpoint_callback.py b/openrl/utils/callbacks/checkpoint_callback.py index 34038e14..38fd4d44 100644 --- a/openrl/utils/callbacks/checkpoint_callback.py +++ b/openrl/utils/callbacks/checkpoint_callback.py @@ -81,6 +81,7 @@ def _on_step(self) -> bool: print(f"Saving model checkpoint to {model_path}") if ( + # TODO: add buffer save support self.save_replay_buffer and hasattr(self.agent, "replay_buffer") and self.agent.replay_buffer is not None diff --git a/openrl/utils/callbacks/eval_callback.py b/openrl/utils/callbacks/eval_callback.py new file mode 100644 index 00000000..a3ad3cdd --- /dev/null +++ b/openrl/utils/callbacks/eval_callback.py @@ -0,0 +1,253 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""""" +import os +import warnings +from typing import Any, Dict, Optional, Union + +import gymnasium as gym +import numpy as np + +from openrl.envs.common import make +from openrl.envs.vec_env import BaseVecEnv, SyncVectorEnv +from openrl.envs.wrappers.monitor import Monitor +from openrl.utils.callbacks.callbacks import BaseCallback, EventCallback +from openrl.utils.evaluation import evaluate_policy + +env_wrappers = [ + Monitor, +] + + +def _make_env( + env: Union[str, Dict[str, Any]], render: bool, asynchronous: bool +) -> BaseVecEnv: + if isinstance(env, str): + env = {"id": env, "env_num": 1} + envs = make( + env["id"], + env_num=env["env_num"], + render_mode="group_human" if render else None, + env_wrappers=env_wrappers, + asynchronous=asynchronous, + ) + return envs + + +class EvalCallback(EventCallback): + """ + Callback for evaluating an agent. + + .. warning:: + + When using multiple environments, each call to ``env.step()`` + will effectively correspond to ``n_envs`` steps. 
+ To account for that, you can use ``eval_freq = max(eval_freq // n_envs, 1)`` + + :param eval_env: The environment used for initialization + :param callback_on_new_best: Callback to trigger + when there is a new best model according to the ``mean_reward`` + :param callback_after_eval: Callback to trigger after every evaluation + :param n_eval_episodes: The number of episodes to test the agent + :param eval_freq: Evaluate the agent every ``eval_freq`` call of the callback. + :param log_path: Path to a folder where the evaluations (``evaluations.npz``) + will be saved. It will be updated at each evaluation. + :param best_model_save_path: Path to a folder where the best model + according to performance on the eval env will be saved. + :param deterministic: Whether the evaluation should + use a stochastic or deterministic actions. + :param render: Whether to render or not the environment during evaluation + :param verbose: Verbosity level: 0 for no output, 1 for indicating information about evaluation results + :param warn: Passed to ``evaluate_policy`` (warns if ``eval_env`` has not been + wrapped with a Monitor wrapper) + """ + + def __init__( + self, + eval_env: Union[str, Dict[str, Any], gym.Env, BaseVecEnv], + callback_on_new_best: Optional[BaseCallback] = None, + callback_after_eval: Optional[BaseCallback] = None, + n_eval_episodes: int = 5, + eval_freq: int = 10000, + log_path: Optional[str] = None, + best_model_save_path: Optional[str] = None, + deterministic: bool = True, + render: bool = False, + asynchronous: bool = True, + verbose: int = 1, + warn: bool = True, + ): + super().__init__(callback_after_eval, verbose=verbose) + + self.callback_on_new_best = callback_on_new_best + if self.callback_on_new_best is not None: + # Give access to the parent + self.callback_on_new_best.parent = self + + self.n_eval_episodes = n_eval_episodes + self.eval_freq = eval_freq + self.best_mean_reward = -np.inf + self.last_mean_reward = -np.inf + self.deterministic = deterministic + self.render = render + self.warn = warn + + if isinstance(eval_env, str) or isinstance(eval_env, dict): + eval_env = _make_env(eval_env, render, asynchronous) + # Convert to BaseVecEnv for consistency + if not isinstance(eval_env, BaseVecEnv): + eval_env = SyncVectorEnv([lambda: eval_env]) + + self.eval_env = eval_env + self.best_model_save_path = best_model_save_path + # Logs will be written in ``evaluations.npz`` + if log_path is not None: + log_path = os.path.join(log_path, "evaluations") + self.log_path = log_path + self.evaluations_results = [] + self.evaluations_time_steps = [] + self.evaluations_length = [] + # For computing success rate + self._is_success_buffer = [] + self.evaluations_successes = [] + + def _init_callback(self) -> None: + # Does not work in some corner cases, where the wrapper is not the same + if not isinstance(self.training_env, type(self.eval_env)): + warnings.warn( + "Training and eval env are not of the same type" + f"{self.training_env} != {self.eval_env}" + ) + + # Create folders if needed + if self.best_model_save_path is not None: + os.makedirs(self.best_model_save_path, exist_ok=True) + if self.log_path is not None: + os.makedirs(os.path.dirname(self.log_path), exist_ok=True) + + # Init callback called on new best model + if self.callback_on_new_best is not None: + self.callback_on_new_best.init_callback(self.agent) + + def _log_success_callback( + self, locals_: Dict[str, Any], globals_: Dict[str, Any] + ) -> None: + """ + Callback passed to the ``evaluate_policy`` function + in order 
to log the success rate (when applicable), + for instance when using HER. + + :param locals_: + :param globals_: + """ + info = locals_["info"] + + if locals_["done"]: + maybe_is_success = info.get("is_success") + if maybe_is_success is not None: + self._is_success_buffer.append(maybe_is_success) + + def _on_step(self) -> bool: + continue_training = True + + if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0: + # Reset success rate buffer + self._is_success_buffer = [] + + episode_rewards, episode_lengths = evaluate_policy( + self.agent, + self.eval_env, + n_eval_episodes=self.n_eval_episodes, + render=self.render, + deterministic=self.deterministic, + return_episode_rewards=True, + warn=self.warn, + callback=self._log_success_callback, + ) + + if self.log_path is not None: + self.evaluations_time_steps.append(self.num_time_steps) + self.evaluations_results.append(episode_rewards) + self.evaluations_length.append(episode_lengths) + + kwargs = {} + # Save success log if present + if len(self._is_success_buffer) > 0: + self.evaluations_successes.append(self._is_success_buffer) + kwargs = dict(successes=self.evaluations_successes) + + np.savez( + self.log_path, + timesteps=self.evaluations_time_steps, + results=self.evaluations_results, + ep_lengths=self.evaluations_length, + **kwargs, + ) + + mean_reward, std_reward = np.mean(episode_rewards), np.std(episode_rewards) + mean_ep_length, std_ep_length = np.mean(episode_lengths), np.std( + episode_lengths + ) + self.last_mean_reward = mean_reward + + if self.verbose >= 1: + print( + f"Eval num_timesteps={self.num_time_steps}, " + f"episode_reward={mean_reward:.2f} +/- {std_reward:.2f}" + ) + print(f"Episode length: {mean_ep_length:.2f} +/- {std_ep_length:.2f}") + + if len(self._is_success_buffer) > 0: + success_rate = np.mean(self._is_success_buffer) + if self.verbose >= 1: + print(f"Success rate: {100 * success_rate:.2f}%") + + if mean_reward > self.best_mean_reward: + if self.verbose >= 1: + print("New best mean reward!") + if self.best_model_save_path is not None: + self.agent.save( + os.path.join(self.best_model_save_path, "best_model") + ) + with open( + os.path.join(self.best_model_save_path, f"best_model_info.txt"), + "w", + ) as f: + f.write(f"best model at step: {self.num_time_steps}\n") + f.write(f"best model reward: {mean_reward}\n") + self.best_mean_reward = mean_reward + # Trigger callback on new best model, if needed + if self.callback_on_new_best is not None: + continue_training = self.callback_on_new_best.on_step() + + # Trigger callback after every evaluation, if needed + if self.callback is not None: + continue_training = continue_training and self._on_event() + + return continue_training + + def update_child_locals(self, locals_: Dict[str, Any]) -> None: + """ + Update the references to the local variables. 
+ + :param locals_: the local variables during rollout collection + """ + if self.callback: + self.callback.update_locals(locals_) + + def _on_training_end(self): + self.eval_env.close() diff --git a/openrl/utils/evaluation.py b/openrl/utils/evaluation.py index 7f9066ac..c079d37a 100644 --- a/openrl/utils/evaluation.py +++ b/openrl/utils/evaluation.py @@ -7,12 +7,11 @@ import numpy as np from openrl.envs.vec_env import BaseVecEnv, SyncVectorEnv, is_vecenv_wrapped -from openrl.envs.vec_env.wrappers.vec_monitor_wrapper import VecMonitorWrapper from openrl.utils import type_aliases def evaluate_policy( - model: "type_aliases.AgentActor", + agent: "type_aliases.AgentActor", env: Union[gym.Env, BaseVecEnv], n_eval_episodes: int = 10, deterministic: bool = True, @@ -21,7 +20,9 @@ def evaluate_policy( reward_threshold: Optional[float] = None, return_episode_rewards: bool = False, warn: bool = True, -) -> Union[Tuple[float, float], Tuple[List[float], List[int]]]: +) -> Union[ + Tuple[np.ndarray, np.ndarray], Tuple[float, float], Tuple[List[float], List[int]] +]: """ Runs policy for ``n_eval_episodes`` episodes and returns average reward. If a vector env is passed in, this divides the episodes to evaluate onto the @@ -36,7 +37,7 @@ def evaluate_policy( results as well. You can avoid this by wrapping environment with ``Monitor`` wrapper before anything else. - :param model: The RL agent you want to evaluate. This can be any object + :param agent: The RL agent you want to evaluate. This can be any object that implements a `predict` method, such as an RL algorithm (``BaseAlgorithm``) or policy (``BasePolicy``). :param env: The gym environment or ``BaseVecEnv`` environment. @@ -58,14 +59,12 @@ def evaluate_policy( """ is_monitor_wrapped = False # Avoid circular import - from stable_baselines3.common.monitor import Monitor + from openrl.envs.wrappers.monitor import Monitor if not isinstance(env, BaseVecEnv): env = SyncVectorEnv([lambda: env]) - is_monitor_wrapped = ( - is_vecenv_wrapped(env, VecMonitorWrapper) or env.env_is_wrapped(Monitor)[0] - ) + is_monitor_wrapped = env.env_is_wrapped(Monitor, indices=0)[0] if not is_monitor_wrapped and warn: warnings.warn( @@ -78,7 +77,7 @@ def evaluate_policy( UserWarning, ) - n_envs = env.num_envs + n_envs = env.parallel_env_num episode_rewards = [] episode_lengths = [] @@ -88,19 +87,22 @@ def evaluate_policy( [(n_eval_episodes + i) // n_envs for i in range(n_envs)], dtype="int" ) - current_rewards = np.zeros(n_envs) + current_rewards = np.zeros([n_envs, env.agent_num]) current_lengths = np.zeros(n_envs, dtype="int") - observations = env.reset() + # get the train_env, and will set it back after evaluation + train_env = agent.get_env() + agent.set_env(env) + observations, info = env.reset() states = None - episode_starts = np.ones((env.num_envs,), dtype=bool) + episode_starts = np.ones((env.parallel_env_num,), dtype=bool) + while (episode_counts < episode_count_targets).any(): - actions, states = model.predict( + actions, states = agent.act( observations, - state=states, - episode_start=episode_starts, deterministic=deterministic, ) observations, rewards, dones, infos = env.step(actions) + rewards = np.squeeze(rewards, axis=-1) current_rewards += rewards current_lengths += 1 for i in range(n_envs): @@ -119,24 +121,34 @@ def evaluate_policy( # Atari wrapper can send a "done" signal when # the agent loses a life, but it does not correspond # to the true end of episode - if "episode" in info.keys(): + assert "final_info" in info.keys(), ( + "final_info should be in 
info keys", + info.keys(), + ) + assert "episode" in info["final_info"].keys(), ( + "episode should be in final_info keys", + info["final_info"].keys(), + ) + if "episode" in info["final_info"].keys(): # Do not trust "done" with episode endings. # Monitor wrapper includes "episode" key in info if environment # has been wrapped with it. Use those rewards instead. - episode_rewards.append(info["episode"]["r"]) - episode_lengths.append(info["episode"]["l"]) + episode_rewards.append(info["final_info"]["episode"]["r"]) + episode_lengths.append(info["final_info"]["episode"]["l"]) # Only increment at the real end of an episode episode_counts[i] += 1 else: episode_rewards.append(current_rewards[i]) episode_lengths.append(current_lengths[i]) episode_counts[i] += 1 + current_rewards[i] = 0 current_lengths[i] = 0 - if render: - env.render() - + # if render: + # env.render() + # set env to train_env + agent.set_env(train_env) mean_reward = np.mean(episode_rewards) std_reward = np.std(episode_rewards) if reward_threshold is not None: