Feat: c envs support #1152

Draft · wants to merge 2 commits into base: develop
Changes from all commits
4 changes: 2 additions & 2 deletions mava/configs/arch/sebulba.yaml
@@ -2,7 +2,7 @@
architecture_name: sebulba

# --- Training ---
num_envs: 32 # number of environments per thread.
num_envs: 512 # number of environments per thread.

# --- Evaluation ---
evaluation_greedy: False # Evaluate the policy greedily. If True the policy will select
@@ -15,7 +15,7 @@ absolute_metric: True # Whether the absolute metric should be computed. For more
# on the absolute metric please see: https://arxiv.org/abs/2209.10485

# --- Sebulba devices config ---
n_threads_per_executor: 2 # num of different threads/env batches per actor
n_threads_per_executor: 1 # num of different threads/env batches per actor
actor_device_ids: [0] # ids of actor devices
learner_device_ids: [0] # ids of learner devices
rollout_queue_size : 5
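A minimal sketch of how these arch settings are assumed to combine (illustrative values only; the assumption is that each actor thread builds its own batch of num_envs environments, so the defaults move from many small batches of 32 envs across 2 threads to one large batch of 512 envs on a single thread):

num_envs = 512                  # arch.num_envs, per actor thread
n_threads_per_executor = 1      # actor threads per actor device
actor_device_ids = [0]          # a single actor device

total_envs = num_envs * n_threads_per_executor * len(actor_device_ids)
print(total_envs)  # 512 with the new defaults; the old defaults gave 32 * 2 = 64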
2 changes: 1 addition & 1 deletion mava/configs/default/ff_ippo_sebulba.yaml
@@ -3,7 +3,7 @@ defaults:
- arch: sebulba
- system: ppo/ff_ippo
- network: mlp # [mlp, continuous_mlp, cnn]
- env: lbf_gym # [rware_gym, lbf_gym, smaclite_gym]
- env: rware_puffer # [rware_gym, lbf_gym, smaclite_gym, rware_puffer]
- _self_

hydra:
24 changes: 24 additions & 0 deletions mava/configs/env/rware_puffer.yaml
@@ -0,0 +1,24 @@
# ---Environment Configs---
defaults:
  - _self_
  - scenario: puffer-rware-tiny-2ag

env_name: PufferRobotWarehouse # Used for logging purposes.


# Defines the metric that will be used to evaluate the performance of the agent.
# This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
eval_metric: episode_return

# Whether the environment observations encode implicit agent IDs. If True, the AgentID wrapper is not used.
# This should not be changed.
implicit_agent_id: False
# Whether or not to log the winrate of this environment. This should not be changed as not all
# environments have a winrate metric.
log_win_rate: False

# Whether or not to sum the returned rewards over all of the agents.
use_shared_rewards: True

kwargs:
  max_episode_steps: 500
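The use_shared_rewards flag above is assumed to sum the per-agent rewards and hand every agent the same team reward; a minimal sketch of that assumed behaviour:

import numpy as np

per_agent_rewards = np.array([0.0, 1.0])  # e.g. 2 agents
shared = np.full_like(per_agent_rewards, per_agent_rewards.sum())
print(shared)  # [1. 1.]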
6 changes: 6 additions & 0 deletions mava/configs/env/scenario/puffer-rware-tiny-2ag.yaml
@@ -0,0 +1,6 @@
task_name: puffer-tiny-2ag
task_config:
  # map_choice values: 1 = tiny_shelf_locations (32 shelves), 2 = small_shelf_locations (80), 3 = medium_shelf_locations (144)
  map_choice: 1
  num_agents: 2
  num_requested_shelves: 2
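A minimal sketch of how this scenario is assumed to feed into environment construction (it mirrors the env_kwargs merge in make_puffer_env further down in this PR; the literal values restate the YAML above, and num_envs is an arbitrary example):

config_env_kwargs = {"max_episode_steps": 500}                                  # env.kwargs
task_config = {"map_choice": 1, "num_agents": 2, "num_requested_shelves": 2}   # scenario.task_config

env_kwargs = {**config_env_kwargs, **task_config, "num_envs": 8}
# Forwarded via pufferlib.vector.make(..., env_kwargs=env_kwargs); whether the
# underlying Rware constructor accepts every key is an assumption here.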
8 changes: 4 additions & 4 deletions mava/systems/ppo/sebulba/ff_ippo.py
@@ -450,7 +450,7 @@ def learner_setup(
"""Initialise learner_fn, network and learner state."""

# Create temporary environments.
env = environments.make_gym_env(config, config.arch.num_envs)
env = environments.sebulba_make(config, config.arch.num_envs)
# Get number of agents and actions.
action_space = env.single_action_space
config.system.num_agents = len(action_space)
@@ -569,7 +569,7 @@ def run_experiment(_config: DictConfig) -> float:
# One key per device for evaluation.
eval_act_fn = make_ff_eval_act_fn(apply_fns[0], config)
evaluator, evaluator_envs = get_eval_fn(
environments.make_gym_env, eval_act_fn, config, np_rng, absolute_metric=False
environments.sebulba_make, eval_act_fn, config, np_rng, absolute_metric=False
)

# Calculate total timesteps.
@@ -626,7 +626,7 @@ def run_experiment(_config: DictConfig) -> float:
args=(
act_key,
# We have to do this here, creating envs inside actor threads causes deadlocks
environments.make_gym_env(config, config.arch.num_envs),
environments.sebulba_make(config, config.arch.num_envs),
config,
pipe,
params_source,
@@ -700,7 +700,7 @@ def run_experiment(_config: DictConfig) -> float:
if config.arch.absolute_metric:
print(f"{Fore.BLUE}{Style.BRIGHT}Measuring absolute metric...{Style.RESET_ALL}")
abs_metric_evaluator, abs_metric_evaluator_envs = get_eval_fn(
environments.make_gym_env, eval_act_fn, config, np_rng, absolute_metric=True
environments.sebulba_make, eval_act_fn, config, np_rng, absolute_metric=True
)
key, eval_key = jax.random.split(key, 2)
eval_metrics = abs_metric_evaluator(best_params_cpu, eval_key, {})
78 changes: 73 additions & 5 deletions mava/utils/make_env.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple
from typing import Dict, Tuple

import gymnasium
import gymnasium as gym
@@ -23,6 +23,9 @@
import matrax
from gigastep import ScenarioBuilder
from jaxmarl.environments.smax import map_name_to_scenario
from pufferlib.ocean import Rware
import pufferlib.vector
from psutil import cpu_count
from jumanji.environments.routing.cleaner.generator import (
RandomGenerator as CleanerRandomGenerator,
)
@@ -56,6 +59,8 @@
SmacWrapper,
SmaxWrapper,
UoeWrapper,
PufferAutoResetWrapper,
PufferToJumanji,
VectorConnectorWrapper,
async_multiagent_worker,
)
@@ -82,7 +87,9 @@
"LevelBasedForaging": UoeWrapper,
"SMACLite": SmacWrapper,
}

_puffer_registry = {
"PufferRobotWarehouse" : Rware
}

def add_extra_wrappers(
train_env: MarlEnv, eval_env: MarlEnv, config: DictConfig
@@ -249,9 +256,6 @@ def create_gym_env(config: DictConfig, add_global_state: bool = False) -> gymnas
registered_name = f"{config.env.scenario.name}:{config.env.scenario.task_name}"
env = gym.make(registered_name, disable_env_checker=True, **config.env.kwargs)
wrapped_env = wrapper(env, config.env.use_shared_rewards, add_global_state)
if config.system.add_agent_id:
wrapped_env = GymAgentIDWrapper(wrapped_env)
wrapped_env = GymRecordEpisodeMetrics(wrapped_env)
return wrapped_env

envs = gymnasium.vector.AsyncVectorEnv(
@@ -260,9 +264,73 @@ def create_gym_env(config: DictConfig, add_global_state: bool = False) -> gymnas
)

envs = GymToJumanji(envs)
if config.system.add_agent_id:
envs = GymAgentIDWrapper(envs)
envs = GymRecordEpisodeMetrics(envs)

return envs

def make_puffer_env(
    config: DictConfig,
    num_env: int,
    add_global_state: bool = False,
) -> GymToJumanji:
    """Create and configure a Puffer environment wrapped for Jumanji."""

    env_creator = _puffer_registry[config.env.env_name]

    def create_puffer_env(*env_args, **env_kwargs) -> gymnasium.Env:
        """Wraps the environment with a PufferAutoResetWrapper."""
        return PufferAutoResetWrapper(env_creator, *env_args, **env_kwargs)

    # Determine the number of CPU cores to use.
    # TODO: should we move this to config? Testing showed that running on multiple
    # cores is slower due to the transfer overhead.
    n_cpu_cores = 1  # Using a single CPU core for light environments.
    if n_cpu_cores >= num_env:
        num_workers, num_parallel_envs = num_env, 1
    else:
        assert num_env % n_cpu_cores == 0, (
            f"The number of environments ({num_env}) must be divisible by "
            f"the number of CPU cores ({n_cpu_cores})."
        )
        num_workers, num_parallel_envs = n_cpu_cores, num_env // n_cpu_cores

    # Create the vectorized environments.
    env_kwargs = {
        **config['env']['kwargs'],
        **config['env']['scenario']['task_config'],
        "num_envs": num_parallel_envs,
    }
    envs = pufferlib.vector.make(
        create_puffer_env,
        backend=pufferlib.vector.Multiprocessing,
        num_envs=num_workers,
        env_kwargs=env_kwargs,
    )

    # Wrap environments for Jumanji compatibility.
    envs = PufferToJumanji(envs, num_envs=num_env)
    if config.system.add_agent_id:
        envs = GymAgentIDWrapper(envs)
    envs = GymRecordEpisodeMetrics(envs)

    return envs

def sebulba_make(
    config: DictConfig,
    num_env: int,
    add_global_state: bool = False,
) -> GymToJumanji:
    """Create Sebulba environments, dispatching on the configured env name."""
    env_name = config.env.env_name

    if env_name in _puffer_registry:
        return make_puffer_env(config, num_env, add_global_state)
    elif env_name in _gym_registry:
        return make_gym_env(config, num_env, add_global_state)
    else:
        raise ValueError(f"{env_name} is not a supported environment.")



def make(config: DictConfig, add_global_state: bool = False) -> Tuple[MarlEnv, MarlEnv]:
"""
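For context, a small standalone sketch of the worker/env split used by make_puffer_env, with illustrative numbers (split_envs is a hypothetical helper, not part of the PR); sebulba_make routes to that factory whenever env_name appears in _puffer_registry and falls back to make_gym_env otherwise:

def split_envs(num_env: int, n_cpu_cores: int) -> tuple:
    # Mirrors the branch in make_puffer_env: few envs -> one env per worker,
    # many envs -> shard them evenly across the available cores.
    if n_cpu_cores >= num_env:
        return num_env, 1
    assert num_env % n_cpu_cores == 0
    return n_cpu_cores, num_env // n_cpu_cores

print(split_envs(512, 1))  # (1, 512): one Multiprocessing worker running 512 envs
print(split_envs(512, 4))  # (4, 128): hypothetical four-core split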
2 changes: 2 additions & 0 deletions mava/wrappers/__init__.py
@@ -22,6 +22,8 @@
GymToJumanji,
SmacWrapper,
UoeWrapper,
PufferAutoResetWrapper,
PufferToJumanji,
async_multiagent_worker,
)
from mava.wrappers.jaxmarl import MabraxWrapper, MPEWrapper, SmaxWrapper