From b8fa94d24b55f3c804b877dc2fdfad5aee4482bd Mon Sep 17 00:00:00 2001
From: Louay-Ben-nessir
Date: Thu, 19 Dec 2024 09:20:04 +0100
Subject: [PATCH 1/2] feat: initial puffer

---
 mava/configs/arch/sebulba.yaml                |   4 +-
 mava/configs/default/ff_ippo_sebulba.yaml     |   2 +-
 mava/configs/env/rware_puffer.yaml            |  24 ++
 .../env/scenario/puffer-rware-tiny-2ag.yaml   |   6 +
 mava/systems/ppo/sebulba/ff_ippo.py           |   8 +-
 mava/utils/make_env.py                        |  66 +++-
 mava/wrappers/__init__.py                     |   2 +
 mava/wrappers/gym.py                          | 326 +++++++++++++-----
 requirements/requirements.txt                 |   1 +
 9 files changed, 339 insertions(+), 100 deletions(-)
 create mode 100644 mava/configs/env/rware_puffer.yaml
 create mode 100644 mava/configs/env/scenario/puffer-rware-tiny-2ag.yaml

diff --git a/mava/configs/arch/sebulba.yaml b/mava/configs/arch/sebulba.yaml
index 52ee0ffbf..e927a447d 100644
--- a/mava/configs/arch/sebulba.yaml
+++ b/mava/configs/arch/sebulba.yaml
@@ -2,7 +2,7 @@ architecture_name: sebulba

 # --- Training ---
-num_envs: 32 # number of environments per thread.
+num_envs: 512 # number of environments per thread.

 # --- Evaluation ---
 evaluation_greedy: False # Evaluate the policy greedily. If True the policy will select
@@ -15,7 +15,7 @@ absolute_metric: True # Whether the absolute metric should be computed. For more
                       # on the absolute metric please see: https://arxiv.org/abs/2209.10485

 # --- Sebulba devices config ---
-n_threads_per_executor: 2 # num of different threads/env batches per actor
+n_threads_per_executor: 1 # num of different threads/env batches per actor
 actor_device_ids: [0] # ids of actor devices
 learner_device_ids: [0] # ids of learner devices
 rollout_queue_size : 5
diff --git a/mava/configs/default/ff_ippo_sebulba.yaml b/mava/configs/default/ff_ippo_sebulba.yaml
index d0ecfae97..346bdf81b 100644
--- a/mava/configs/default/ff_ippo_sebulba.yaml
+++ b/mava/configs/default/ff_ippo_sebulba.yaml
@@ -3,7 +3,7 @@ defaults:
   - arch: sebulba
   - system: ppo/ff_ippo
   - network: mlp # [mlp, continuous_mlp, cnn]
-  - env: lbf_gym # [rware_gym, lbf_gym, smaclite_gym]
+  - env: rware_puffer # [rware_gym, lbf_gym, smaclite_gym, rware_puffer]
   - _self_

 hydra:
diff --git a/mava/configs/env/rware_puffer.yaml b/mava/configs/env/rware_puffer.yaml
new file mode 100644
index 000000000..3ec510e6d
--- /dev/null
+++ b/mava/configs/env/rware_puffer.yaml
@@ -0,0 +1,24 @@
+# ---Environment Configs---
+defaults:
+  - _self_
+  - scenario: puffer-rware-tiny-2ag
+
+env_name: PufferRobotWarehouse # Used for logging purposes.
+
+
+# Defines the metric that will be used to evaluate the performance of the agent.
+# This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
+eval_metric: episode_return
+
+# Whether the environment observations encode implicit agent IDs. If True, the AgentID wrapper is not used.
+# This should not be changed.
+implicit_agent_id: False
+# Whether or not to log the winrate of this environment. This should not be changed as not all
+# environments have a winrate metric.
+log_win_rate: False
+
+# Whether or not to sum the returned rewards over all of the agents.
+use_shared_rewards: True + +kwargs: + max_episode_steps: 500 \ No newline at end of file diff --git a/mava/configs/env/scenario/puffer-rware-tiny-2ag.yaml b/mava/configs/env/scenario/puffer-rware-tiny-2ag.yaml new file mode 100644 index 000000000..d05ca16b5 --- /dev/null +++ b/mava/configs/env/scenario/puffer-rware-tiny-2ag.yaml @@ -0,0 +1,6 @@ +task_name: puffer-tiny-2ag +task_config: + map_choice: 1 + num_agents: 2 + num_requested_shelves: 2 +# 1 : tiny_shelf_locations (32 shelves), 2 : small_shelf_locations (80), 3: medium_shelf_locations (144) \ No newline at end of file diff --git a/mava/systems/ppo/sebulba/ff_ippo.py b/mava/systems/ppo/sebulba/ff_ippo.py index 6f34c0b1a..369f58be8 100644 --- a/mava/systems/ppo/sebulba/ff_ippo.py +++ b/mava/systems/ppo/sebulba/ff_ippo.py @@ -450,7 +450,7 @@ def learner_setup( """Initialise learner_fn, network and learner state.""" # create temporory envoirnments. - env = environments.make_gym_env(config, config.arch.num_envs) + env = environments.sebulba_make(config, config.arch.num_envs) # Get number of agents and actions. action_space = env.single_action_space config.system.num_agents = len(action_space) @@ -569,7 +569,7 @@ def run_experiment(_config: DictConfig) -> float: # One key per device for evaluation. eval_act_fn = make_ff_eval_act_fn(apply_fns[0], config) evaluator, evaluator_envs = get_eval_fn( - environments.make_gym_env, eval_act_fn, config, np_rng, absolute_metric=False + environments.sebulba_make, eval_act_fn, config, np_rng, absolute_metric=False ) # Calculate total timesteps. @@ -626,7 +626,7 @@ def run_experiment(_config: DictConfig) -> float: args=( act_key, # We have to do this here, creating envs inside actor threads causes deadlocks - environments.make_gym_env(config, config.arch.num_envs), + environments.sebulba_make(config, config.arch.num_envs), config, pipe, params_source, @@ -700,7 +700,7 @@ def run_experiment(_config: DictConfig) -> float: if config.arch.absolute_metric: print(f"{Fore.BLUE}{Style.BRIGHT}Measuring absolute metric...{Style.RESET_ALL}") abs_metric_evaluator, abs_metric_evaluator_envs = get_eval_fn( - environments.make_gym_env, eval_act_fn, config, np_rng, absolute_metric=True + environments.sebulba_make, eval_act_fn, config, np_rng, absolute_metric=True ) key, eval_key = jax.random.split(key, 2) eval_metrics = abs_metric_evaluator(best_params_cpu, eval_key, {}) diff --git a/mava/utils/make_env.py b/mava/utils/make_env.py index 0a56367c8..2c8e32f62 100644 --- a/mava/utils/make_env.py +++ b/mava/utils/make_env.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Tuple +from typing import Tuple, Dict import gymnasium import gymnasium as gym @@ -23,6 +23,9 @@ import matrax from gigastep import ScenarioBuilder from jaxmarl.environments.smax import map_name_to_scenario +from pufferlib.ocean import Rware +import pufferlib.vector +from psutil import cpu_count from jumanji.environments.routing.cleaner.generator import ( RandomGenerator as CleanerRandomGenerator, ) @@ -56,6 +59,8 @@ SmacWrapper, SmaxWrapper, UoeWrapper, + PufferAutoResetWrapper, + PufferToJumanji, VectorConnectorWrapper, async_multiagent_worker, ) @@ -82,7 +87,9 @@ "LevelBasedForaging": UoeWrapper, "SMACLite": SmacWrapper, } - +_puffer_registry = { + "PufferRobotWarehouse" : Rware +} def add_extra_wrappers( train_env: MarlEnv, eval_env: MarlEnv, config: DictConfig @@ -249,9 +256,6 @@ def create_gym_env(config: DictConfig, add_global_state: bool = False) -> gymnas registered_name = f"{config.env.scenario.name}:{config.env.scenario.task_name}" env = gym.make(registered_name, disable_env_checker=True, **config.env.kwargs) wrapped_env = wrapper(env, config.env.use_shared_rewards, add_global_state) - if config.system.add_agent_id: - wrapped_env = GymAgentIDWrapper(wrapped_env) - wrapped_env = GymRecordEpisodeMetrics(wrapped_env) return wrapped_env envs = gymnasium.vector.AsyncVectorEnv( @@ -260,9 +264,61 @@ def create_gym_env(config: DictConfig, add_global_state: bool = False) -> gymnas ) envs = GymToJumanji(envs) + if config.system.add_agent_id: + envs = GymAgentIDWrapper(envs) + envs = GymRecordEpisodeMetrics(envs) + + return envs + +def make_puffer_env( + config: DictConfig, + num_env: int, + add_global_state: bool = False, +) -> GymToJumanji: + + env_creator = _puffer_registry[config.env.env_name] + + def create_puffer_env( *env_args, **env_kwargs) -> gymnasium.Env: + wrapped_env = PufferAutoResetWrapper(env_creator, *env_args, **env_kwargs) + return wrapped_env + # todo running with a signle cpu core is much faster for light envs + # the data transfer overhead makes using multiple cores not worth it for rware :thinking: + # todo: is using more than 1 actor bad? 
+ n_cpu_cores = 1#cpu_count(logical = False) + if n_cpu_cores >= num_env: + num_workers, num_parrallel_envs = num_env, 1 + else: + assert num_env % n_cpu_cores == 0, f"the numlber of envs({num_env}) must be divisable by the number of cpu cores ({n_cpu_cores})" + num_workers, num_parrallel_envs = n_cpu_cores, num_env // n_cpu_cores + + + envs = pufferlib.vector.make(create_puffer_env, + backend=pufferlib.vector.Multiprocessing, + num_envs=num_workers, env_kwargs = dict(config['env']['kwargs']) | dict(config['env']['scenario']['task_config']) | {"num_envs" : num_parrallel_envs} + ) + envs = PufferToJumanji(envs, num_envs=num_env) + if config.system.add_agent_id: + envs = GymAgentIDWrapper(envs) + envs = GymRecordEpisodeMetrics(envs) return envs +def sebulba_make( + config: DictConfig, + num_env: int, + add_global_state: bool = False, + ) -> GymToJumanji: + + env_name = config.env.env_name + + if env_name in _puffer_registry: + return make_puffer_env(config, num_env, add_global_state) + elif env_name in _gym_registry: + return make_gym_env(config, num_env, add_global_state) + else: + raise ValueError(f"{env_name} is not a supported environment.") + + def make(config: DictConfig, add_global_state: bool = False) -> Tuple[MarlEnv, MarlEnv]: """ diff --git a/mava/wrappers/__init__.py b/mava/wrappers/__init__.py index fc9dadb31..331f89037 100644 --- a/mava/wrappers/__init__.py +++ b/mava/wrappers/__init__.py @@ -22,6 +22,8 @@ GymToJumanji, SmacWrapper, UoeWrapper, + PufferAutoResetWrapper, + PufferToJumanji, async_multiagent_worker, ) from mava.wrappers.jaxmarl import MabraxWrapper, MPEWrapper, SmaxWrapper diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py index 9258bde6a..af7c97233 100644 --- a/mava/wrappers/gym.py +++ b/mava/wrappers/gym.py @@ -23,11 +23,13 @@ import gymnasium import gymnasium.vector.async_vector_env +from pufferlib import PufferEnv import numpy as np from gymnasium import spaces from gymnasium.spaces.utils import is_space_dtype_shape_equiv from gymnasium.vector.utils import write_to_shared_memory from numpy.typing import NDArray +from jax import tree from mava.types import Observation, ObservationGlobalState @@ -163,87 +165,6 @@ def get_action_mask(self, info: Dict) -> NDArray: return np.array(self._env.unwrapped.get_avail_actions()) -class GymRecordEpisodeMetrics(gymnasium.Wrapper): - """Record the episode returns and lengths.""" - - def __init__(self, env: gymnasium.Env): - super().__init__(env) - self._env = env - self.running_count_episode_return = 0.0 - self.running_count_episode_length = 0.0 - - def reset( - self, seed: Optional[int] = None, options: Optional[dict] = None - ) -> Tuple[NDArray, Dict]: - agents_view, info = self._env.reset(seed, options) - - # Reset the metrics - self.running_count_episode_return = 0.0 - self.running_count_episode_length = 0.0 - - # Create the metrics dict - metrics = { - "episode_return": self.running_count_episode_return, - "episode_length": self.running_count_episode_length, - "is_terminal_step": False, - } - - info["metrics"] = metrics - - return agents_view, info - - def step(self, actions: NDArray) -> Tuple[NDArray, NDArray, NDArray, NDArray, Dict]: - agents_view, reward, terminated, truncated, info = self._env.step(actions) - - self.running_count_episode_return += float(np.mean(reward)) - self.running_count_episode_length += 1 - - metrics = { - "episode_return": self.running_count_episode_return, - "episode_length": self.running_count_episode_length, - "is_terminal_step": np.logical_or(terminated, truncated).all().item(), - } - 
- info["metrics"] = metrics - - return agents_view, reward, terminated, truncated, info - - -class GymAgentIDWrapper(gymnasium.Wrapper): - """Add one hot agent IDs to observation.""" - - def __init__(self, env: gymnasium.Env): - super().__init__(env) - - self.agent_ids = np.eye(self.env.num_agents) - self.observation_space = self.modify_space(self.env.observation_space) - - def reset( - self, seed: Optional[int] = None, options: Optional[dict] = None - ) -> Tuple[NDArray, Dict]: - """Reset the environment.""" - obs, info = self.env.reset(seed, options) - obs = np.concatenate([self.agent_ids, obs], axis=1) - return obs, info - - def step(self, action: list) -> Tuple[NDArray, float, bool, bool, Dict]: - """Step the environment.""" - obs, reward, terminated, truncated, info = self.env.step(action) - obs = np.concatenate([self.agent_ids, obs], axis=1) - return obs, reward, terminated, truncated, info - - def modify_space(self, space: spaces.Space) -> spaces.Space: - if isinstance(space, spaces.Box): - new_shape = (space.shape[0], space.shape[1] + self.env.num_agents) - high = np.concatenate((space.high, np.ones_like(self.agent_ids)), axis=1) - low = np.concatenate((space.low, np.zeros_like(self.agent_ids)), axis=1) - return spaces.Box(low=low, high=high, shape=new_shape, dtype=space.dtype) - elif isinstance(space, spaces.Tuple): - return spaces.Tuple(self.modify_space(s) for s in space) - else: - raise ValueError(f"Space {type(space)} is not currently supported.") - - class GymToJumanji: """Converts from the Gym API to the Jumanji API.""" @@ -251,16 +172,18 @@ def __init__(self, env: gymnasium.vector.VectorEnv): self.env = env self.single_action_space = env.unwrapped.single_action_space self.single_observation_space = env.unwrapped.single_observation_space + + self.num_envs = self.env.num_envs + self.num_agents = len(env.unwrapped.single_action_space) def reset(self, seed: Optional[list[int]] = None, options: Optional[dict] = None) -> TimeStep: obs, info = self.env.reset(seed=seed, options=options) # type: ignore num_agents = len(self.env.single_action_space) # type: ignore - num_envs = self.env.num_envs - step_type = np.full(num_envs, StepType.FIRST) - rewards = np.zeros((num_envs, num_agents), dtype=float) - teminated = np.zeros(num_envs, dtype=float) + step_type = np.full(self.num_envs, StepType.FIRST) + rewards = np.zeros((self.num_envs, num_agents), dtype=float) + teminated = np.zeros(self.num_envs, dtype=float) timestep = self._create_timestep(obs, step_type, teminated, rewards, info) @@ -297,9 +220,6 @@ def _create_timestep( observation = self._format_observation(obs, info) # Filter out the masks and auxiliary data extras = {} - extras["episode_metrics"] = { - key: value for key, value in info["metrics"].items() if key[0] != "_" - } if "won_episode" in info: extras["won_episode"] = info["won_episode"] @@ -409,3 +329,233 @@ def async_multiagent_worker( # CCR001 pipe.send((None, False)) finally: env.close() + + + +class GymRecordEpisodeMetrics: + """Record the episode returns and lengths.""" + + def __init__(self, env: GymToJumanji): + self.env = env + + self.num_env = self.env.num_envs + self.num_agents = self.env.num_agents + + self.running_count_episode_return = np.zeros(env.num_envs) + self.running_count_episode_length = np.zeros(env.num_envs) + + self.episode_return = np.zeros(env.num_envs) + self.episode_length = np.zeros(env.num_envs) + + self.single_action_space = env.single_action_space + self.single_observation_space = env.single_observation_space + + def reset( + self, seed: 
Optional[int] = None, options: Optional[dict] = None + ) -> TimeStep: + timestep = self.env.reset(seed, options) + + # Reset the metrics + self.running_count_episode_return = np.zeros(self.env.num_envs) + self.running_count_episode_length = np.zeros(self.env.num_envs) + + self.episode_return = np.zeros(self.env.num_envs) + self.episode_length = np.zeros(self.env.num_envs) + + # Create the metrics dict + metrics = { + "episode_return": self.running_count_episode_return, + "episode_length": self.running_count_episode_length, + "is_terminal_step": np.full(self.env.num_envs, False), + } + + timestep.extras["episode_metrics"] = metrics + + return timestep + + def step(self, actions: NDArray) -> TimeStep: + timestep = self.env.step(actions) + + done = timestep.last() + not_done = 1 - done + + # Counting episode return and length. + new_episode_return = self.running_count_episode_return + np.mean(timestep.reward, axis= 1) + new_episode_length = self.running_count_episode_length + 1 + + # Previous episode return/length until done and then the next episode return. + self.episode_return = self.episode_return * not_done + new_episode_return * done + self.episode_length = self.episode_length * not_done + new_episode_length * done + + self.running_count_episode_return = new_episode_return * not_done + self.running_count_episode_length = new_episode_length * not_done + + metrics = { + "episode_return": self.episode_return, + "episode_length": self.episode_length, + "is_terminal_step": done, + } + + timestep.extras["episode_metrics"] = metrics + + return timestep + + def close(self): + self.env.close() + + +class GymAgentIDWrapper: + """Add one hot agent IDs to observation.""" + + def __init__(self, env: GymToJumanji): + self.env = env + self.num_envs = self.env.num_envs + self.num_agents = self.env.num_agents + + self.agent_ids = np.repeat(np.eye(env.num_agents)[np.newaxis, ...], repeats=env.num_envs, axis = 0) + self.single_observation_space = self.modify_space(env.single_observation_space) + self.single_action_space = env.single_action_space + + def reset( + self, seed: Optional[int] = None, options: Optional[dict] = None + ) -> TimeStep: + """Reset the environment.""" + timestep = self.env.reset(seed, options) + agents_view = np.concatenate([self.agent_ids, timestep.observation.agents_view], axis=2) + timestep.observation = timestep.observation._replace(agents_view = agents_view) + return timestep + + def step(self, action: list) -> TimeStep: + """Step the environment.""" + timestep = self.env.step(action) + agents_view = np.concatenate([self.agent_ids, timestep.observation.agents_view], axis=2) + timestep.observation = timestep.observation._replace(agents_view = agents_view) + return timestep + + def modify_space(self, space: spaces.Space) -> spaces.Space: + if isinstance(space, spaces.Box): + new_shape = (space.shape[0], space.shape[1] + self.env.num_agents) + high = np.concatenate((space.high, np.ones((self.env.num_agents,self.env.num_agents))), axis=1) + low = np.concatenate((space.low, np.zeros((self.env.num_agents,self.env.num_agents))), axis=1) + return spaces.Box(low=low, high=high, shape=new_shape, dtype=space.dtype) + elif isinstance(space, spaces.Tuple): + return spaces.Tuple(self.modify_space(s) for s in space) + else: + raise ValueError(f"Space {type(space)} is not currently supported.") + def close(self): + self.env.close() + + + +class PufferToJumanji: + def __init__(self, env, num_envs): + self.env = env + + self.num_envs = num_envs + self.num_agents = env.num_agents // num_envs + 
self.num_actions = env.single_action_space.n + + self.single_action_space = spaces.MultiDiscrete([self.num_actions] * self.num_agents) + + # Box(...) --> Box(N, ...) + single_obs = env.single_observation_space # type: ignore + shape = (self.num_agents, *single_obs.shape) + low = np.tile(single_obs.low, (self.num_agents, 1)) + high = np.tile(single_obs.high, (self.num_agents, 1)) + self.single_observation_space = spaces.Box(low=low, high=high, shape=shape, dtype=single_obs.dtype) + + self.fix_shape_copy = lambda x : x.reshape(self.num_envs, self.num_agents, *x.shape[1:]).copy() #copy to avoid pointer magic + + def reset(self, seed: Optional[list[int]] = None, options: Optional[dict] = None) -> TimeStep: + obs, info = self.env.reset() + obs = self.fix_shape_copy(obs) + + step_type = np.full(self.num_envs, StepType.FIRST) + rewards = np.zeros((self.num_envs, self.num_agents), dtype=float) + terminated = np.zeros(self.num_envs, dtype=float) + action_mask = np.ones((self.num_envs, self.num_agents, self.num_actions)) + + obs_data = {"agents_view": obs, "action_mask": action_mask} + Observation(**obs_data) + + return TimeStep( + step_type=step_type, + reward=rewards, + discount=1.0 - terminated, + observation=Observation(**obs_data), + extras={}, + ) + def step(self, action: list) -> TimeStep: + action = action.flatten() + obs, rewards, terminated, truncated, info = self.env.step(action) + obs, rewards, terminated, truncated = tree.map(self.fix_shape_copy, (obs, rewards, terminated, truncated)) + + terminated = np.any(terminated, axis = -1) # Agent termination flag to env termination flag + truncated = np.any(truncated, axis = -1) + + ep_done = np.logical_or(terminated, truncated) + step_type = np.where(ep_done, StepType.LAST, StepType.MID) + action_mask = np.ones((self.num_envs, self.num_agents, self.num_actions)) + + obs_data = {"agents_view": obs, "action_mask": action_mask} + + + return TimeStep( + step_type=step_type, + reward=rewards, + discount=1.0 - terminated, + observation=Observation(**obs_data), + extras={}, + ) + + def close(self) -> None: + self.env.close() + + +class PufferAutoResetWrapper(PufferEnv): + def __init__(self, env_class : PufferEnv, max_episode_steps: int = 0, *args, **kwargs): + """ + Generic wrapper for PufferLib environments to track the number of steps taken. + + Parameters: + - env_class: The class of the environment to wrap. + - max_steps: Maximum number of steps allowed in an episode (optional). + - *args, **kwargs: Arguments to initialize the environment. 
+        """
+        self.env = env_class(*args, **kwargs)
+        self.steps = 0 # Initialize step counter
+        self.max_steps = max_episode_steps # Set maximum steps if provided
+
+    def reset(self, seed: Optional[int] = None) -> Tuple[NDArray, Dict]:
+        self.steps = 0
+        return self.env.reset(seed = seed)
+
+    def step(self, actions: List) -> Tuple[NDArray, NDArray, NDArray, NDArray, Dict]:
+
+        if self.steps == 0:
+            self.env.terminals.fill(False)
+
+        self.steps += 1
+        observation, reward, terminated, truncated, _ = self.env.step(actions)
+        info = {"real_next_obs" : observation.copy()} #todo i dislike replacing the info
+
+        if np.logical_or(terminated, truncated).all() or self.steps == self.max_steps:
+            # The returned values are ignored when using puffer's vector envs
+            # Instead, the updates have to be made directly in the env
+            self.env.observations[:], _ = self.reset() # change values without changing the array reference
+            self.env.terminals.fill(True)
+
+        return self.env.observations, reward, self.env.terminals, truncated, info
+
+    def render(self, *args, **kwargs):
+        return self.env.render(*args, **kwargs)
+
+    def close(self):
+        self.env.close()
+
+    def __getattr__(self, name):
+        """
+        Forward any attributes or methods not explicitly defined
+        in this wrapper to the underlying environment.
+        """
+        return getattr(self.env, name)
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index e004a3c23..4329ed8f9 100644
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -26,3 +26,4 @@ smaclite @ git+https://github.com/uoe-agents/smaclite.git
 tensorboard_logger
 tensorflow_probability
 type_enforced # needed because gigastep is missing this dependency
+pufferlib
\ No newline at end of file

From 328d66c0e79ffb01de230e253afeb8810be91bfc Mon Sep 17 00:00:00 2001
From: Louay-Ben-nessir
Date: Sat, 4 Jan 2025 20:39:01 +0100
Subject: [PATCH 2/2] chore: minor changes to comments

---
 mava/utils/make_env.py | 44 +++++++++++++++++++++++++++---------------
 mava/wrappers/gym.py   |  9 +++++----
 2 files changed, 33 insertions(+), 20 deletions(-)

diff --git a/mava/utils/make_env.py b/mava/utils/make_env.py
index 2c8e32f62..746decb2d 100644
--- a/mava/utils/make_env.py
+++ b/mava/utils/make_env.py
@@ -275,32 +275,44 @@ def make_puffer_env(
     num_env: int,
     add_global_state: bool = False,
 ) -> GymToJumanji:
-
+    """Create and configure a Puffer environment wrapped for Jumanji."""
+
     env_creator = _puffer_registry[config.env.env_name]

-    def create_puffer_env( *env_args, **env_kwargs) -> gymnasium.Env:
-        wrapped_env = PufferAutoResetWrapper(env_creator, *env_args, **env_kwargs)
-        return wrapped_env
-    # todo running with a signle cpu core is much faster for light envs
-    # the data transfer overhead makes using multiple cores not worth it for rware :thinking:
-    # todo: is using more than 1 actor bad?
-    n_cpu_cores = 1#cpu_count(logical = False)
+    def create_puffer_env(*env_args, **env_kwargs) -> gymnasium.Env:
+        """Wraps the environment with a PufferAutoResetWrapper."""
+        return PufferAutoResetWrapper(env_creator, *env_args, **env_kwargs)
+
+    # Determine the number of CPU cores to use.
+    # todo: should we move this to config? Testing showed that running on multiple cores is slower due to the transfer overhead.
+    n_cpu_cores = 1  # Using a single CPU core for light environments.
     if n_cpu_cores >= num_env:
-        num_workers, num_parrallel_envs = num_env, 1
+        num_workers, num_parallel_envs = num_env, 1
     else:
-        assert num_env % n_cpu_cores == 0, f"the numlber of envs({num_env}) must be divisable by the number of cpu cores ({n_cpu_cores})"
-        num_workers, num_parrallel_envs = n_cpu_cores, num_env // n_cpu_cores
-
-
-    envs = pufferlib.vector.make(create_puffer_env,
-                                 backend=pufferlib.vector.Multiprocessing,
-                                 num_envs=num_workers, env_kwargs = dict(config['env']['kwargs']) | dict(config['env']['scenario']['task_config']) | {"num_envs" : num_parrallel_envs}
+        assert num_env % n_cpu_cores == 0, (
+            f"The number of environments ({num_env}) must be divisible by the number of CPU cores ({n_cpu_cores})."
+        )
+        num_workers, num_parallel_envs = n_cpu_cores, num_env // n_cpu_cores
+
+    # Create the vectorized environments.
+    env_kwargs = {
+        **config['env']['kwargs'],
+        **config['env']['scenario']['task_config'],
+        "num_envs": num_parallel_envs,
+    }
+    envs = pufferlib.vector.make(
+        create_puffer_env,
+        backend=pufferlib.vector.Multiprocessing,
+        num_envs=num_workers,
+        env_kwargs=env_kwargs,
     )
+    # Wrap environments for Jumanji compatibility.
     envs = PufferToJumanji(envs, num_envs=num_env)
     if config.system.add_agent_id:
         envs = GymAgentIDWrapper(envs)
     envs = GymRecordEpisodeMetrics(envs)
+
     return envs

 def sebulba_make(
diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py
index af7c97233..e77415c5b 100644
--- a/mava/wrappers/gym.py
+++ b/mava/wrappers/gym.py
@@ -531,17 +531,18 @@ def reset(self, seed: Optional[int] = None) -> Tuple[NDArray, Dict]:
         return self.env.reset(seed = seed)

     def step(self, actions: List) -> Tuple[NDArray, NDArray, NDArray, NDArray, Dict]:
-
+        # The returned values are ignored when using puffer's vector envs.
+        # Instead, the updates have to be made directly in the arrays stored inside the env,
+        # since everything is passed by reference.
+
         if self.steps == 0:
             self.env.terminals.fill(False)

         self.steps += 1
         observation, reward, terminated, truncated, _ = self.env.step(actions)
-        info = {"real_next_obs" : observation.copy()} #todo i dislike replacing the info
+        info = {"real_next_obs" : observation.copy()}

         if np.logical_or(terminated, truncated).all() or self.steps == self.max_steps:
-            # The returned values are ignored when using puffer's vector envs
-            # Instead, the updates have to be made directly in the env
             self.env.observations[:], _ = self.reset() # change values without changing the array reference
             self.env.terminals.fill(True)
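Usage note (not part of the patch series above): a minimal smoke-test sketch of how the new
factory is meant to be driven. It assumes `cfg` is the already-composed Hydra config for
ff_ippo_sebulba.yaml with `env: rware_puffer`, and that `make_env` is imported under the
`environments` alias as in ff_ippo.py; the helper name `rollout_smoke_test` and the step
count are illustrative only, not part of the changeset.

    import numpy as np
    from omegaconf import DictConfig

    from mava.utils import make_env as environments


    def rollout_smoke_test(cfg: DictConfig, num_steps: int = 8) -> None:
        # sebulba_make dispatches on cfg.env.env_name: "PufferRobotWarehouse" is routed
        # through make_puffer_env, while the gym-registry names go through make_gym_env.
        envs = environments.sebulba_make(cfg, num_env=cfg.arch.num_envs)

        timestep = envs.reset(seed=list(range(cfg.arch.num_envs)))
        for _ in range(num_steps):
            # One joint action per environment: shape (num_envs, num_agents).
            actions = np.stack(
                [envs.single_action_space.sample() for _ in range(cfg.arch.num_envs)]
            )
            timestep = envs.step(actions)

        # GymRecordEpisodeMetrics now wraps the whole vectorised stack, so episode
        # statistics are read from timestep.extras rather than a per-env info dict.
        print(timestep.extras["episode_metrics"]["episode_return"])
        envs.close()

The same entry point covers both backends, which is the point of the sebulba_make dispatcher:
training scripts only swap the `env:` default in the config.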