From 7bfeb0b74dd9c229d2b97f256103547d99f95a3e Mon Sep 17 00:00:00 2001 From: Ming Zhou Date: Tue, 28 Feb 2023 12:36:47 +0800 Subject: [PATCH 01/24] add ppo --- malib/rl/common/trainer.py | 4 +- malib/rl/ppo/policy.py | 22 +++++++++++ malib/rl/ppo/trainer.py | 76 ++++++++++++++++++++++++++++++++++++++ malib/utils/data.py | 13 +++++++ 4 files changed, 113 insertions(+), 2 deletions(-) diff --git a/malib/rl/common/trainer.py b/malib/rl/common/trainer.py index 01bea92c..16897aa0 100644 --- a/malib/rl/common/trainer.py +++ b/malib/rl/common/trainer.py @@ -76,8 +76,8 @@ def counter(self): def step_counter(self): self._counter += 1 - def parameters(self): - return self.policy.parameters() + # def parameters(self): + # return self.policy.parameters() @abstractmethod def setup(self): diff --git a/malib/rl/ppo/policy.py b/malib/rl/ppo/policy.py index 47cb69ed..26e06d3e 100644 --- a/malib/rl/ppo/policy.py +++ b/malib/rl/ppo/policy.py @@ -21,3 +21,25 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. + +from typing import Dict, Any + +import torch + +from torch import nn +from gym import spaces +from malib.rl.a2c import A2CPolicy + + +class PPOPolicy(A2CPolicy): + def __init__( + self, + observation_space: spaces.Space, + action_space: spaces.Space, + model_config: Dict[str, Any], + custom_config: Dict[str, Any], + **kwargs + ): + super().__init__( + observation_space, action_space, model_config, custom_config, **kwargs + ) diff --git a/malib/rl/ppo/trainer.py b/malib/rl/ppo/trainer.py index 47cb69ed..74e0a270 100644 --- a/malib/rl/ppo/trainer.py +++ b/malib/rl/ppo/trainer.py @@ -21,3 +21,79 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. + +from typing import Sequence, List, Any, Dict + +import torch + +from torch import nn +from malib.utils.data import Postprocessor +from malib.utils.typing import AgentID +from malib.utils.tianshou_batch import Batch +from malib.rl.a2c import A2CTrainer + + +class PPOTrainer(A2CTrainer): + def train(self, batch: Batch) -> Dict[str, List[float]]: + repeats = self.training_config["repeats"] + ratio_clip = self.training_config["ratio_clip"] + dual_clip = self.training_config["dual_clip"] + vf_ratio = self.training_config["vf_ratio"] + ent_ratio = self.training_config["ent_ratio"] + use_adv_norm = self.training_config["use_adv_norm"] + adv_norm_eps = self.training_config["adv_norm_eps"] + use_grad_norm = self.training_config["use_grad_norm"] + use_value_clip = self.training_config["use_value_clip"] + + losses, clip_losses, vf_losses, ent_losses = 0.0, 0.0, 0.0, 0.0 + for step in range(repeats): + dist = self.policy.dist_fn.proba_distribution(batch.logits) + if use_adv_norm: + mean, std = batch.adv.mean(), batch.adv.std() + batch.adv = (batch.adv - mean) / (std + adv_norm_eps) + + ratio = (dist.log_prob(batch.act) - batch.logp_old).exp().float() + ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1) + surr1 = ratio * batch.adv + surr2 = ratio.clamp(1.0 - ratio_clip, 1.0 + ratio_clip) * batch.adv + + if dual_clip: + clip1 = torch.min(surr1, surr2) + clip2 = torch.max(clip1, dual_clip * batch.adv) + clip_loss = -torch.where(batch.adv < 0, clip2, clip1).mean() + else: + clip_loss = -torch.min(surr1, surr2).mean() + + value = self.policy.critic(batch.obs).flatten() + if use_value_clip: + v_clip = batch.state_value + (value - batch.state_value).clamp( + -ratio_clip, ratio_clip + ) + vf1 = (batch.returns - value).pow(2) + vf2 = (batch.returns - v_clip).pow(2) + vf_loss = torch.max(vf1, vf2).mean() + else: + vf_loss = (batch.returns - value).pow(2).mean() + + ent_loss = dist.entropy().mean() + loss = clip_loss + vf_ratio * vf_loss - ent_ratio * ent_loss + + self.optimizer.zero_grad() + loss.backward() + if use_grad_norm: # clip large gradient + nn.utils.clip_grad_norm_( + self.parameters, max_norm=self.training_config["grad_norm"] + ) + self.optimizer.step() + + clip_losses += clip_loss.item() / repeats + vf_losses += vf_loss.item() / repeats + ent_losses += ent_loss.item() / repeats + losses += loss.item() / repeats + + return { + "loss": losses, + "loss/clip": clip_losses, + "loss/vf": vf_losses, + "loss/ent": ent_losses, + } diff --git a/malib/utils/data.py b/malib/utils/data.py index 36b5fe8e..62552812 100644 --- a/malib/utils/data.py +++ b/malib/utils/data.py @@ -191,6 +191,19 @@ def compute_episodic_return( gamma: float = 0.99, gae_lambda: float = 0.95, ): + """Compute episodic return with GAE. + + Args: + batch (Dict[str, Any]): A dict of batch, including obs, rew, done at least + state_value (np.ndarray, optional): A batch of state value. Defaults to None. + next_state_value (np.ndarray, optional): A batch next state value. Defaults to None. + gamma (float, optional): Gamma. Defaults to 0.99. + gae_lambda (float, optional): GAE lambda. Defaults to 0.95. + + Returns: + a dict of batch + """ + if isinstance(batch["rew"], torch.Tensor): rew = batch["rew"].cpu().numpy() else: From b80f01d6902599b9df3b6a782640e4f31a002d63 Mon Sep 17 00:00:00 2001 From: Ming Zhou Date: Fri, 7 Apr 2023 14:40:58 +0800 Subject: [PATCH 02/24] tmp save --- examples/sarl/ppo_gym.py | 78 +++++++++++++++++++++++++++++++++++ malib/common/distributions.py | 4 +- malib/common/retrace.py | 0 malib/common/vtrace.py | 3 ++ malib/rl/ppo/__init__.py | 4 ++ malib/rl/ppo/config.py | 1 + malib/rl/ppo/trainer.py | 6 +-- malib/utils/statistic.py | 2 +- 8 files changed, 92 insertions(+), 6 deletions(-) create mode 100644 examples/sarl/ppo_gym.py create mode 100644 malib/common/retrace.py create mode 100644 malib/common/vtrace.py create mode 100644 malib/rl/ppo/config.py diff --git a/examples/sarl/ppo_gym.py b/examples/sarl/ppo_gym.py new file mode 100644 index 00000000..46e9a1f8 --- /dev/null +++ b/examples/sarl/ppo_gym.py @@ -0,0 +1,78 @@ +import os +import time + +from argparse import ArgumentParser + +from malib.agent import IndependentAgent +from malib.scenarios.marl_scenario import MARLScenario + +from malib.runner import run +from malib.rl.ppo import PPOPolicy, PPOTrainer, DEFAULT_CONFIG +from malib.rollout.envs.gym import env_desc_gen + + +if __name__ == "__main__": + parser = ArgumentParser("Use PPO solve Gym tasks.") + parser.add_argument("--log-dir", default="./logs/", help="Log directory.") + parser.add_argument("--env-id", default="CartPole-v1", help="gym environment id.") + parser.add_argument("--use-cuda", action="store_true") + + args = parser.parse_args() + + trainer_config = DEFAULT_CONFIG["training_config"].copy() + trainer_config["total_timesteps"] = int(1e6) + trainer_config["use_cuda"] = args.use_cuda + + training_config = { + "type": IndependentAgent, + "trainer_config": trainer_config, + "custom_config": {}, + } + + rollout_config = { + "fragment_length": 2000, # determine the size of sended data block + "max_step": 200, + "num_eval_episodes": 10, + "num_threads": 2, + "num_env_per_thread": 10, + "num_eval_threads": 1, + "use_subproc_env": False, + "batch_mode": "time_step", + "postprocessor_types": ["defaults"], + # every # rollout epoch run evaluation. + "eval_interval": 1, + "inference_server": "ray", # three kinds of inference server: `local`, `pipe` and `ray` + } + agent_mapping_func = lambda agent: agent + + algorithms = { + "default": ( + PPOPolicy, + PPOTrainer, + # model configuration, None for default + {}, + {"use_cuda": args.use_cuda}, + ) + } + + env_description = env_desc_gen(env_id=args.env_id, scenario_configs={}) + runtime_logdir = os.path.join(args.log_dir, f"sa_ppo_gym/{time.time()}") + + if not os.path.exists(runtime_logdir): + os.makedirs(runtime_logdir) + + scenario = MARLScenario( + name=f"ppo-sym-{args.env_id}", + log_dir=runtime_logdir, + algorithms=algorithms, + env_description=env_description, + training_config=training_config, + rollout_config=rollout_config, + agent_mapping_func=agent_mapping_func, + stopping_conditions={ + "training": {"max_iteration": int(1e10)}, + "rollout": {"max_iteration": 1000, "minimum_reward_improvement": 1.0}, + }, + ) + + run(scenario) diff --git a/malib/common/distributions.py b/malib/common/distributions.py index 132453e8..38d8b62b 100644 --- a/malib/common/distributions.py +++ b/malib/common/distributions.py @@ -270,7 +270,7 @@ def log_prob( ) # Squash correction (from original SAC implementation) # this comes from the fact that tanh is bijective and differentiable - log_prob -= torch.sum(torch.log(1 - actions**2 + self.epsilon), dim=1) + log_prob -= torch.sum(torch.log(1 - actions ** 2 + self.epsilon), dim=1) return log_prob def entropy(self) -> Optional[torch.Tensor]: @@ -627,7 +627,7 @@ def proba_distribution( """ # Stop gradient if we don't want to influence the features self._latent_sde = latent_sde if self.learn_features else latent_sde.detach() - variance = torch.mm(self._latent_sde**2, self.get_std(log_std) ** 2) + variance = torch.mm(self._latent_sde ** 2, self.get_std(log_std) ** 2) self.distribution = Normal(mean_actions, torch.sqrt(variance + self.epsilon)) return self diff --git a/malib/common/retrace.py b/malib/common/retrace.py new file mode 100644 index 00000000..e69de29b diff --git a/malib/common/vtrace.py b/malib/common/vtrace.py new file mode 100644 index 00000000..450f70cd --- /dev/null +++ b/malib/common/vtrace.py @@ -0,0 +1,3 @@ +""" +An implementation of V-Trace algorithm, to taclking the imbalance of data use for on-policy training. +""" diff --git a/malib/rl/ppo/__init__.py b/malib/rl/ppo/__init__.py index 47cb69ed..31b9a4d4 100644 --- a/malib/rl/ppo/__init__.py +++ b/malib/rl/ppo/__init__.py @@ -21,3 +21,7 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. + +from policy import PPOPolicy +from trainer import PPOTrainer +from config import DEFAULT_CONFIG diff --git a/malib/rl/ppo/config.py b/malib/rl/ppo/config.py new file mode 100644 index 00000000..bde9a653 --- /dev/null +++ b/malib/rl/ppo/config.py @@ -0,0 +1 @@ +DEFAULT_CONFIG = {} diff --git a/malib/rl/ppo/trainer.py b/malib/rl/ppo/trainer.py index 74e0a270..da161717 100644 --- a/malib/rl/ppo/trainer.py +++ b/malib/rl/ppo/trainer.py @@ -22,13 +22,12 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -from typing import Sequence, List, Any, Dict +from typing import List, Dict import torch from torch import nn -from malib.utils.data import Postprocessor -from malib.utils.typing import AgentID + from malib.utils.tianshou_batch import Batch from malib.rl.a2c import A2CTrainer @@ -46,6 +45,7 @@ def train(self, batch: Batch) -> Dict[str, List[float]]: use_value_clip = self.training_config["use_value_clip"] losses, clip_losses, vf_losses, ent_losses = 0.0, 0.0, 0.0, 0.0 + for step in range(repeats): dist = self.policy.dist_fn.proba_distribution(batch.logits) if use_adv_norm: diff --git a/malib/utils/statistic.py b/malib/utils/statistic.py index 6b050059..9ed7ebf7 100644 --- a/malib/utils/statistic.py +++ b/malib/utils/statistic.py @@ -35,7 +35,7 @@ def update(self, data_array: np.ndarray) -> None: new_mean = self.mean + delta * batch_count / total_count m_a = self.var * self.count m_b = batch_var * batch_count - m_2 = m_a + m_b + delta**2 * self.count * batch_count / total_count + m_2 = m_a + m_b + delta ** 2 * self.count * batch_count / total_count new_var = m_2 / total_count self.mean, self.var = new_mean, new_var From f32ce4a5212d86dcc4988ba2fb9a97cea0fdd043 Mon Sep 17 00:00:00 2001 From: Ming Zhou Date: Fri, 7 Apr 2023 15:05:13 +0800 Subject: [PATCH 03/24] reformatted --- malib/common/distributions.py | 4 ++-- malib/rl/ppo/trainer.py | 1 + malib/utils/statistic.py | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/malib/common/distributions.py b/malib/common/distributions.py index 38d8b62b..132453e8 100644 --- a/malib/common/distributions.py +++ b/malib/common/distributions.py @@ -270,7 +270,7 @@ def log_prob( ) # Squash correction (from original SAC implementation) # this comes from the fact that tanh is bijective and differentiable - log_prob -= torch.sum(torch.log(1 - actions ** 2 + self.epsilon), dim=1) + log_prob -= torch.sum(torch.log(1 - actions**2 + self.epsilon), dim=1) return log_prob def entropy(self) -> Optional[torch.Tensor]: @@ -627,7 +627,7 @@ def proba_distribution( """ # Stop gradient if we don't want to influence the features self._latent_sde = latent_sde if self.learn_features else latent_sde.detach() - variance = torch.mm(self._latent_sde ** 2, self.get_std(log_std) ** 2) + variance = torch.mm(self._latent_sde**2, self.get_std(log_std) ** 2) self.distribution = Normal(mean_actions, torch.sqrt(variance + self.epsilon)) return self diff --git a/malib/rl/ppo/trainer.py b/malib/rl/ppo/trainer.py index da161717..126946c8 100644 --- a/malib/rl/ppo/trainer.py +++ b/malib/rl/ppo/trainer.py @@ -48,6 +48,7 @@ def train(self, batch: Batch) -> Dict[str, List[float]]: for step in range(repeats): dist = self.policy.dist_fn.proba_distribution(batch.logits) + if use_adv_norm: mean, std = batch.adv.mean(), batch.adv.std() batch.adv = (batch.adv - mean) / (std + adv_norm_eps) diff --git a/malib/utils/statistic.py b/malib/utils/statistic.py index 9ed7ebf7..6b050059 100644 --- a/malib/utils/statistic.py +++ b/malib/utils/statistic.py @@ -35,7 +35,7 @@ def update(self, data_array: np.ndarray) -> None: new_mean = self.mean + delta * batch_count / total_count m_a = self.var * self.count m_b = batch_var * batch_count - m_2 = m_a + m_b + delta ** 2 * self.count * batch_count / total_count + m_2 = m_a + m_b + delta**2 * self.count * batch_count / total_count new_var = m_2 / total_count self.mean, self.var = new_mean, new_var From f76e73b39caba54d8f512d811abb62e0e6211a91 Mon Sep 17 00:00:00 2001 From: Ming Zhou Date: Fri, 7 Apr 2023 15:11:17 +0800 Subject: [PATCH 04/24] fix relative import error --- malib/rl/ppo/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/malib/rl/ppo/__init__.py b/malib/rl/ppo/__init__.py index 31b9a4d4..31ec9130 100644 --- a/malib/rl/ppo/__init__.py +++ b/malib/rl/ppo/__init__.py @@ -22,6 +22,6 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -from policy import PPOPolicy -from trainer import PPOTrainer -from config import DEFAULT_CONFIG +from .policy import PPOPolicy +from .trainer import PPOTrainer +from .config import DEFAULT_CONFIG From e736658e12c7099f1505e62fc8a3a085acfb77dc Mon Sep 17 00:00:00 2001 From: Ming Zhou Date: Fri, 7 Apr 2023 16:22:27 +0800 Subject: [PATCH 05/24] add an implementation of vtrace, not test --- malib/common/vtrace.py | 127 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 127 insertions(+) diff --git a/malib/common/vtrace.py b/malib/common/vtrace.py index 450f70cd..22cdbe07 100644 --- a/malib/common/vtrace.py +++ b/malib/common/vtrace.py @@ -1,3 +1,130 @@ """ An implementation of V-Trace algorithm, to taclking the imbalance of data use for on-policy training. """ + +from functools import reduce +from collections import namedtuple + +import torch + +from torch.nn import functional as F + +from malib.utils.data import to_torch + + +VTraceRet = namedtuple("VTraceReturn", "vs,adv") +VTraceFromLogitsReturns = namedtuple( + "VTraceFromLogitsReturns", + ["vs", "adv", "log_rhos", "behaviour_log_prob", "target_log_prob"], +) + + +def _acc_func(acc, item): + discount_t, c_t, delta_t = item + return delta_t + discount_t * c_t * acc + + +def from_importance_weights( + log_rhos: torch.Tensor, + discounts: torch.Tensor, + rewards: torch.Tensor, + values: torch.Tensor, + bootstrap_values: torch.Tensor, + clip_rho: float, + clip_pg_rho: float, +): + """Calculates V-trace values from log importance weights. + + Args: + log_rhos (torch.Tensor): A float tensor of shape [T, B, N_ACTIONS] representing the log importantce sampling weights, i.e., log[target_p(a) / behavior_p(a)]. + discounts (torch.Tensor): A float tensor of shape [T, B] with discounts encountered by following the behaviour policy. + rewards (torch.Tensor): A float tensor of shape [T, B] containing rewards generated by following the behavior policy. + values (torch.Tensor): A float tensor of shape [T, B] with the value function estimates wrt. the target policy. + bootstrap_values (torch.Tensor): A float of shape [B] with teh value function estimate at time T. + clip_rho (float): A float scalar with the clipping threshold for importance weights (rho) when calculating the baseline targets (Vs). + clip_pg_rho (float): A float scalar with the clipping threshold on rho_s in rho_s in \rho_s \delta log \pi(a|x) (r + \gamma v_{s+1} - V(x_s)). + """ + + log_rhos = to_torch(log_rhos, dtype=torch.float32) + discounts = to_torch(discounts, dtype=torch.float32) + rewards = to_torch(rewards, dtype=torch.float32) + values = to_torch(values, dtype=torch.float32) + bootstrap_values = to_torch(bootstrap_values, dtype=torch.float32) + + # shape assert + assert log_rhos.shape.__len__() == 3, log_rhos.shape + assert values.shape.__len__() == 2, values.shape + assert bootstrap_values.shape.__len__() == 1, bootstrap_values.shape + assert rewards.shape.__len__() == 2, rewards.shape + assert discounts.shape.__len__() == 2, discounts.shape + + rhos = torch.exp(log_rhos) + clipped_rho = torch.minimum(clip_rho, rhos) if clip_rho else rhos + cs = torch.minimum(1.0, rhos) + values_t_plus_1 = torch.concat([values[1:], bootstrap_values.unsqueeze(0)], dim=0) + deltas = clipped_rho * (rewards + discounts * values_t_plus_1 - values) + + sequences = (discounts, cs, deltas) + vs_minus_v_xs = reduce(_acc_func, sequences) + + assert vs_minus_v_xs.shape == values, (vs_minus_v_xs.shape, values.shape) + vs = vs_minus_v_xs + values + + # compute advantages + vs_t_plus_1 = torch.concat([vs[1:], bootstrap_values.unsqueeze(0)], axis=0) + clippped_pg_rho = torch.minimum(clip_pg_rho, rhos) if clip_pg_rho else rhos + advantages = clippped_pg_rho * (rewards + discounts * vs_t_plus_1 - values) + + return VTraceRet(vs=vs.detach(), adv=advantages.detach()) + + +def vtrace( + behavior_logits: torch.Tensor, + target_logits: torch.Tensor, + actions: torch.Tensor, + discounts: torch.Tensor, + rewards: torch.Tensor, + values: torch.Tensor, + bootstrap_values: torch.Tensor, + clip_rho: float = 1.0, + clip_pg_rho: float = 1.0, +): + """Calculates V-trace actor critic targets for softmax policies as introduced in + + "IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures" by Espeholt, Soyer, Munos et al. + + Args: + behavior_logits (torch.Tensor): A float tensor of shape [T, B, N_ACTIONS] with un-normalized log-probabilities parameterizing the softmax behaviour policy. + target_logits (torch.Tensor): A float tensor of shape [T, B, N_ACTIONS] with un-normalized log-probabilities parameterizing the softmax target policy. + actions (torch.Tensor): An int tensor of shape [T, B] of actions sampled from the behavior policy. + discounts (torch.Tensor): A float tensor of shape [T, B] with the discount encountered when following the behavior policy. + rewards (torch.Tensor): A float tensor of shape [T, B] with the rewards generated by following the behavior policy. + values (torch.Tensor): A float tensor of shape [T, B] with the value function estimates wrt. the target policy. + bootstrap_values (torch.Tensor): A float of shape [B] with the value function estimate at time T. + clip_rho (float, optional): A float scalar with the clipping threshold for importance weights (RHO) when calculating the baseline targets (Vs), i.e., \bar{rho} in the paper. Defaults to 1.. + clip_pg_rho (float, optional): A float scalar with the clipping threshold on rho_s in rho_s in \rho_s \delta log \pi(a|x) (r + \gamma v_{s+1} - V(x_s)). Defaults to 1. + """ + + behavior_logits = to_torch(behavior_logits, dtype=torch.float32) + target_logits = to_torch(target_logits, dtype=torch.float32) + actions = to_torch(actions, dtype=torch.int32) + + # shape checking + assert behavior_logits.shape.__len__() == 3, behavior_logits.shape + assert target_logits.shape.__len__() == 3, target_logits.shape + assert actions.shape.__len__() == 2, actions.shape + + # compute log probs for behavior and target policy + behavior_log_prob = F.cross_entropy(input=behavior_logits, target=actions) + target_log_prob = F.cross_entropy(input=target_logits, target=actions) + + log_rhos = target_log_prob - behavior_log_prob + vtrace_ret = from_importance_weights( + log_rhos, discounts, rewards, values, bootstrap_values, clip_rho, clip_pg_rho + ) + + return VTraceFromLogitsReturns( + log_rhos=log_rhos, + behaviour_log_prob=behavior_log_prob, + target_log_prob=target_log_prob ** vtrace_ret._asdict(), + ) From 70f91bb9c3026e7f6a9e9eefc4d863e45c47f396 Mon Sep 17 00:00:00 2001 From: Ming Zhou Date: Sun, 10 Sep 2023 21:49:59 +0800 Subject: [PATCH 06/24] tmp save: implement sync-agent pipeline --- docs/source/api/malib.common.rst | 16 +++++ docs/source/api/malib.rl.ppo.rst | 8 +++ examples/sarl/ppo_gym.py | 9 ++- malib/agent/agent_interface.py | 3 +- malib/agent/manager.py | 100 ++++++++++++++++++++++--------- malib/runner.py | 2 +- malib/scenarios/marl_scenario.py | 16 ++++- 7 files changed, 120 insertions(+), 34 deletions(-) diff --git a/docs/source/api/malib.common.rst b/docs/source/api/malib.common.rst index 46c0add4..042a71c2 100644 --- a/docs/source/api/malib.common.rst +++ b/docs/source/api/malib.common.rst @@ -33,6 +33,14 @@ malib.common.payoff\_manager module :undoc-members: :show-inheritance: +malib.common.retrace module +--------------------------- + +.. automodule:: malib.common.retrace + :members: + :undoc-members: + :show-inheritance: + malib.common.strategy\_spec module ---------------------------------- @@ -40,3 +48,11 @@ malib.common.strategy\_spec module :members: :undoc-members: :show-inheritance: + +malib.common.vtrace module +-------------------------- + +.. automodule:: malib.common.vtrace + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/api/malib.rl.ppo.rst b/docs/source/api/malib.rl.ppo.rst index 84dd29aa..c279a9d8 100644 --- a/docs/source/api/malib.rl.ppo.rst +++ b/docs/source/api/malib.rl.ppo.rst @@ -9,6 +9,14 @@ malib.rl.ppo package Submodules ---------- +malib.rl.ppo.config module +-------------------------- + +.. automodule:: malib.rl.ppo.config + :members: + :undoc-members: + :show-inheritance: + malib.rl.ppo.policy module -------------------------- diff --git a/examples/sarl/ppo_gym.py b/examples/sarl/ppo_gym.py index 46e9a1f8..34fd2806 100644 --- a/examples/sarl/ppo_gym.py +++ b/examples/sarl/ppo_gym.py @@ -24,7 +24,7 @@ trainer_config["use_cuda"] = args.use_cuda training_config = { - "type": IndependentAgent, + "learner_type": IndependentAgent, "trainer_config": trainer_config, "custom_config": {}, } @@ -43,13 +43,16 @@ "eval_interval": 1, "inference_server": "ray", # three kinds of inference server: `local`, `pipe` and `ray` } + + # one to one, no sharing, if sharing, implemented as: + # agent_mapping_func = lambda agent: "default" agent_mapping_func = lambda agent: agent algorithms = { "default": ( PPOPolicy, PPOTrainer, - # model configuration, None for default + # model configuration, None as default {}, {"use_cuda": args.use_cuda}, ) @@ -62,7 +65,7 @@ os.makedirs(runtime_logdir) scenario = MARLScenario( - name=f"ppo-sym-{args.env_id}", + name=f"ppo-gym-{args.env_id}", log_dir=runtime_logdir, algorithms=algorithms, env_description=env_description, diff --git a/malib/agent/agent_interface.py b/malib/agent/agent_interface.py index 411d0413..b4e43132 100644 --- a/malib/agent/agent_interface.py +++ b/malib/agent/agent_interface.py @@ -198,7 +198,7 @@ def add_policies(self, n: int) -> StrategySpec: """ for _ in range(n): - spec_pid = f"policy-{len(self._strategy_spec.policy_ids)}" + spec_pid = f"policy-{len(self._strategy_spec)}" self._strategy_spec.register_policy_id(policy_id=spec_pid) policy = self._strategy_spec.gen_policy() policy_id = f"{self._strategy_spec.id}/{spec_pid}" @@ -336,7 +336,6 @@ def train( if reset_state: self.reset() - # stopper = get_stopper(conditions=stopping_conditions) reader_info_dict: Dict[str, Tuple[str, Queue]] = {} assert len(self._active_tups) == 1, "the length of active tups can be only 1." diff --git a/malib/agent/manager.py b/malib/agent/manager.py index d8a0df42..a4f80c3b 100644 --- a/malib/agent/manager.py +++ b/malib/agent/manager.py @@ -22,11 +22,23 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -from concurrent.futures import ThreadPoolExecutor -from typing import Dict, Any, Callable, List, Tuple, Union, Set, Sequence, Type +from typing import ( + Dict, + Any, + Callable, + List, + Tuple, + Union, + Set, + Sequence, + Type, + Generator, +) from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor, Future, CancelledError import os +import traceback import ray from malib.utils.typing import AgentID @@ -66,7 +78,7 @@ def __init__( env_desc (Dict[str, Any]): The description for environment generation. interface_config (Dict[str, Any]): Configuration for agent training inferece construction, keys include \ `type` and `custom_config`, a dict. - agent_mapping_func (Callable[[AgentID], str]): The mapping function maps agent id to training interface id. + agent_mapping_func (Callable[[AgentID], str]): The mapping function maps environment agent id to training agent (`agent_interface`) id. training_config (Dict[str, Any]): Training configuration, for agent interface, keys include \ `type`, `trainer_config` and `custom_config`. log_dir (str): Directory for logging. @@ -91,9 +103,10 @@ def __init__( if not os.path.exists(log_dir): os.makedirs(log_dir) - agent_cls = training_config["type"] + agent_cls = training_config["learner_type"] # update num gpus resource_config["num_gpus"] = num_gpus + # XXX(ming): why we hard set max_concurrency to 10? agent_cls = agent_cls.as_remote(**resource_config).options(max_concurrency=10) interfaces: Dict[str, Union[AgentInterface, ray.ObjectRef]] = {} @@ -148,10 +161,22 @@ def agent_groups(self) -> Dict[str, Set[AgentID]]: @property def workers(self) -> List[RemoteInterface]: + """A list of learner instance + + Returns: + List[RemoteInterface]: A list of learner instance + """ + return list(self._interfaces.values()) @property def runtime_ids(self) -> Tuple[str]: + """Return a tuple of learner ids + + Returns: + Tuple[str]: A tuple of string as leqrner ids + """ + return self._runtime_ids def add_policies( @@ -172,11 +197,12 @@ def add_policies( assert isinstance(interface_ids, (List, Tuple, Set)), type(interface_ids) - ns = dict.fromkeys(interface_ids, n) if isinstance(n, int) else n + policy_nums = dict.fromkeys(interface_ids, n) if isinstance(n, int) else n + if self._remote_mode: strategy_spec_list: List[StrategySpec] = ray.get( [ - self._interfaces[k].add_policies.remote(n=ns[k]) + self._interfaces[k].add_policies.remote(n=policy_nums[k]) for k in interface_ids ] ) @@ -185,7 +211,8 @@ def add_policies( ) else: strategy_spec_dict = { - k: self._interfaces[k].add_policies(n=ns[k]) for k in interface_ids + k: self._interfaces[k].add_policies(n=policy_nums[k]) + for k in interface_ids } return strategy_spec_dict @@ -193,29 +220,48 @@ def add_policies( def run(self, data_request_identifiers: Dict[str, str]): """Start training thread without blocking""" - if self._remote_mode: - for rid, interface in self._interfaces.items(): - self.pending_tasks.append( - interface.train.remote( - data_request_identifiers[rid], - self._stopping_conditions["training"], - ) + for rid, interface in self._interfaces.items(): + if self._remote_mode: + task = interface.train.remote( + data_request_identifiers[rid], + self._stopping_conditions["training"], ) - else: - for rid, interface in self._interfaces.items(): - self.pending_tasks.append( - self._thread_pool.submit( - interface.train, - data_request_identifiers[rid], - self._stopping_conditions["training"], - ) + else: + task = self._thread_pool.submit( + interface.train, + data_request_identifiers[rid], + self._stopping_conditions["training"], ) + self.pending_tasks.append(task) + + def retrive_results(self) -> Generator: + """Return a generator of results + + Yields: + Generator: A generator for task results + """ - def retrive_results(self): - while len(self.pending_tasks) > 0: - dones, self.pending_tasks = ray.wait(self.pending_tasks) - for done in ray.get(dones): - yield done + if self._remote_mode: + while len(self.pending_tasks) > 0: + dones, self.pending_tasks = ray.wait(self.pending_tasks) + for done in ray.get(dones): + yield done + else: + for task in self.pending_tasks: + assert isinstance(task, Future) + try: + if task.done(): + yield task.result(timeout=10) + except TimeoutError: + Logger.error( + f"Retrieving results of training task is timeout: {traceback.format_exc()}" + ) + except CancelledError: + Logger.error( + f"Try to retrieve results of a cancelled task: {traceback.format_exc()}" + ) + except Exception: + Logger.error(traceback.format_exc()) def terminate(self) -> None: """Terminate all training actors.""" diff --git a/malib/runner.py b/malib/runner.py index b248265d..380efb8c 100644 --- a/malib/runner.py +++ b/malib/runner.py @@ -84,7 +84,7 @@ def run(scenario: Scenario, cluster_address: str = "auto"): scenario.parameter_server = parameter_server scenario.offline_dataset_server = offline_dataset_server - experiment_tag = f"malib-{scenario.name}-{time.time()}" + experiment_tag = f"malib-{scenario.name}-{time.strftime('%Y-%m-%d-%H%M%S')}" if isinstance(scenario, psro_scenario.PSROScenario): psro_scenario.execution_plan(experiment_tag, scenario) diff --git a/malib/scenarios/marl_scenario.py b/malib/scenarios/marl_scenario.py index bfe17811..3dc41b54 100644 --- a/malib/scenarios/marl_scenario.py +++ b/malib/scenarios/marl_scenario.py @@ -87,7 +87,21 @@ def execution_plan( scenario: Scenario, recall_resource: bool = True, verbose: bool = True, -): +) -> Dict[str, Any]: + """Execute multi-agent learning task. If there is not a `training_manager`/`rollout_manager` \ + registered to the given scenario, then a new `training_manager`/`rollout_mamanger` will be created in \ + remote mode. + + Args: + experiment_tag (str): Experiment identity + scenario (Scenario): Scenario instance + recall_resource (bool, optional): Recall resource or not. Defaults to True. + verbose (bool, optional): Enable verbose print. Defaults to True. + + Returns: + Dict[str, Any]: A dict of final returns, currently returns only a dict of strategy specs, e.g., {'strategy_specs': strategy_specs} + """ + if hasattr(scenario, "training_manager"): training_manager: TrainingManager = scenario.training_manager else: From f7aff9055ac67c98c95bdd92326b840e8b65dfb6 Mon Sep 17 00:00:00 2001 From: Ming Zhou Date: Sun, 8 Oct 2023 08:55:34 +0800 Subject: [PATCH 07/24] tmp save --- malib/agent/manager.py | 49 ++++++------ malib/backend/league.py | 22 ++++++ malib/common/rollout_config.py | 31 ++++++++ malib/common/task.py | 37 +++++++++ malib/common/training_config.py | 32 ++++++++ malib/rollout/envs/sc2/env.py | 7 +- malib/rollout/manager.py | 34 ++++++--- malib/rollout/rolloutworker.py | 37 +++++---- malib/scenarios/sarl_scenario.py | 126 +++++++++++++++++++++++++++++++ setup.py | 2 +- tests/agents/test_manager.py | 9 +-- 11 files changed, 331 insertions(+), 55 deletions(-) create mode 100644 malib/backend/league.py create mode 100644 malib/common/rollout_config.py create mode 100644 malib/common/task.py create mode 100644 malib/common/training_config.py create mode 100644 malib/scenarios/sarl_scenario.py diff --git a/malib/agent/manager.py b/malib/agent/manager.py index a4f80c3b..f96d77e4 100644 --- a/malib/agent/manager.py +++ b/malib/agent/manager.py @@ -48,6 +48,7 @@ from malib.agent.agent_interface import AgentInterface from malib.common.strategy_spec import StrategySpec from malib.common.manager import Manager +from malib.common.training_config import TrainingConfig DEFAULT_RESOURCE_CONFIG = dict( @@ -63,7 +64,7 @@ def __init__( algorithms: Dict[str, Any], env_desc: Dict[str, Any], agent_mapping_func: Callable[[AgentID], str], - training_config: Dict[str, Any], + training_config: Union[Dict[str, Any], TrainingConfig], log_dir: str, remote_mode: bool = True, resource_config: Dict[str, Any] = None, @@ -82,12 +83,13 @@ def __init__( training_config (Dict[str, Any]): Training configuration, for agent interface, keys include \ `type`, `trainer_config` and `custom_config`. log_dir (str): Directory for logging. - remote_mode (bool, Optional): Init agent interfaces as remote actor or not. Default is True. + remote_mode (bool, Optional): Init learners as remote actor or not. Default is True. """ super().__init__(verbose=verbose) resource_config = resource_config or DEFAULT_RESOURCE_CONFIG + training_config = TrainingConfig.from_raw(training_config) # interface config give the agent type used here and the group mapping if needed agent_groups = defaultdict(lambda: set()) @@ -96,27 +98,29 @@ def __init__( agent_groups[rid].add(agent) # FIXME(ming): resource configuration is not available now, will open in the next version - if training_config["trainer_config"].get("use_cuda", False): + if training_config.trainer_config.get("use_cuda", False): num_gpus = 1 / len(agent_groups) else: num_gpus = 0.0 if not os.path.exists(log_dir): os.makedirs(log_dir) - agent_cls = training_config["learner_type"] + learner_cls = training_config.learner_type # update num gpus resource_config["num_gpus"] = num_gpus # XXX(ming): why we hard set max_concurrency to 10? - agent_cls = agent_cls.as_remote(**resource_config).options(max_concurrency=10) - interfaces: Dict[str, Union[AgentInterface, ray.ObjectRef]] = {} + learner_cls = learner_cls.as_remote(**resource_config).options( + max_concurrency=10 + ) + learners: Dict[str, Union[AgentInterface, ray.ObjectRef]] = {} assert ( "training" in stopping_conditions ), f"Stopping conditions should contains `training` stoppong conditions: {stopping_conditions}" for rid, agents in agent_groups.items(): - handler = agent_cls.remote if remote_mode else agent_cls - interfaces[rid] = handler( + _cls = learner_cls.remote if remote_mode else learner_cls + learners[rid] = _cls( experiment_tag=experiment_tag, runtime_id=rid, log_dir=f"{log_dir}/learner_{rid}", @@ -124,14 +128,14 @@ def __init__( algorithms=algorithms, agent_mapping_func=agent_mapping_func, governed_agents=tuple(agents), - trainer_config=training_config["trainer_config"], - custom_config=training_config.get("custom_config"), + trainer_config=training_config.trainer_config, + custom_config=training_config.custom_config, verbose=verbose, ) # ensure all interfaces have been started up if remote_mode: - _ = ray.get([x.connect.remote() for x in interfaces.values()]) + _ = ray.get([x.connect.remote() for x in learners.values()]) self._agent_groups = agent_groups self._runtime_ids = tuple(self._agent_groups.keys()) @@ -140,13 +144,13 @@ def __init__( self._training_config = training_config self._log_dir = log_dir self._agent_mapping_func = agent_mapping_func - self._interfaces = interfaces + self._learners = learners self._remote_mode = remote_mode - self._thread_pool = ThreadPoolExecutor(max_workers=len(interfaces)) + self._thread_pool = ThreadPoolExecutor(max_workers=len(learners)) self._stopping_conditions = stopping_conditions Logger.info( - f"training manager launched, {len(self._interfaces)} learner(s) created" + f"training manager launched, {len(self._learners)} learner(s) created" ) @property @@ -167,7 +171,7 @@ def workers(self) -> List[RemoteInterface]: List[RemoteInterface]: A list of learner instance """ - return list(self._interfaces.values()) + return list(self._learners.values()) @property def runtime_ids(self) -> Tuple[str]: @@ -193,7 +197,7 @@ def add_policies( """ if interface_ids is None: - interface_ids = list(self._interfaces.keys()) + interface_ids = list(self._learners.keys()) assert isinstance(interface_ids, (List, Tuple, Set)), type(interface_ids) @@ -202,7 +206,7 @@ def add_policies( if self._remote_mode: strategy_spec_list: List[StrategySpec] = ray.get( [ - self._interfaces[k].add_policies.remote(n=policy_nums[k]) + self._learners[k].add_policies.remote(n=policy_nums[k]) for k in interface_ids ] ) @@ -211,16 +215,19 @@ def add_policies( ) else: strategy_spec_dict = { - k: self._interfaces[k].add_policies(n=policy_nums[k]) + k: self._learners[k].add_policies(n=policy_nums[k]) for k in interface_ids } return strategy_spec_dict + def submit(self, task: Any): + raise NotImplementedError + def run(self, data_request_identifiers: Dict[str, str]): """Start training thread without blocking""" - for rid, interface in self._interfaces.items(): + for rid, interface in self._learners.items(): if self._remote_mode: task = interface.train.remote( data_request_identifiers[rid], @@ -269,11 +276,11 @@ def terminate(self) -> None: super().terminate() if self._remote_mode: - for x in self._interfaces.values(): + for x in self._learners.values(): ray.kill(x) self._thread_pool.shutdown() - del self._interfaces + del self._learners def get_exp(self, policy_distribution): """Compute exploitability""" diff --git a/malib/backend/league.py b/malib/backend/league.py new file mode 100644 index 00000000..f3d50edd --- /dev/null +++ b/malib/backend/league.py @@ -0,0 +1,22 @@ +import traceback + +from malib.utils.logging import Logger +from malib.common.manager import Manager + + +class League: + def __init__(self, training_manager: Manager, rollout_manager: Manager) -> None: + self.training_manager = training_manager + self.rollout_manager = rollout_manager + + def get_results(self): + try: + while True: + # TODO(ming): check whether done + raise NotImplementedError + except KeyboardInterrupt: + Logger.info("Keyboard interruption was detected, recalling resources ...") + except RuntimeError: + Logger.error(traceback.format_exc()) + except Exception: + Logger.error(traceback.format_exc()) diff --git a/malib/common/rollout_config.py b/malib/common/rollout_config.py new file mode 100644 index 00000000..c238182f --- /dev/null +++ b/malib/common/rollout_config.py @@ -0,0 +1,31 @@ +from typing import Dict, Any, Union + +from dataclasses import dataclass, field + + +@dataclass +class RolloutConfig: + + inference_server_type: str + + + @classmethod + def from_raw(cls, config: Union["RolloutConfig", Dict[str, Any]]) -> "RolloutConfig": + """Cat dict-style configuration to RolloutConfig instance + + Args: + config (Dict[str, Any]): A dict + + Raises: + RuntimeError: Unexpected config type + + Returns: + RolloutConfig: A rollout config instance + """ + + if isinstance(config, cls): + return config + elif isinstance(config, Dict): + return cls(**config) + else: + raise RuntimeError(f"Unexpected rollout config type: {type(config)}") \ No newline at end of file diff --git a/malib/common/task.py b/malib/common/task.py new file mode 100644 index 00000000..e06f654f --- /dev/null +++ b/malib/common/task.py @@ -0,0 +1,37 @@ +from typing import Dict, Union, Any, List + +from enum import IntEnum +from dataclasses import dataclass, field + +from malib.utils.typing import AgentID + + +class TaskType(IntEnum): + ROLLOUT = 0 + EVALUATION = 1 + OPTIMIZATION = 2 + + +@dataclass +class RolloutTask: + + task_type: int + active_agents: List[AgentID] + strategy_specs: Dict[str, Any] = field(default_factory=dict()) + + @classmethod + def from_raw(cls, dict_style: Union[Dict[str, Any], "RolloutTask"], **kwargs) -> "RolloutTask": + if isinstance(dict_style, Dict): + return cls(**dict_style, **kwargs) + elif isinstance(dict_style, cls): + return dict_style + else: + raise TypeError(f"Unexpected type: {type(dict_style)}") + + +@dataclass +class OptimizationTask: + + @classmethod + def from_raw(cls, dict_style: Union[Dict[str, Any], "OptimizationTask"], **kwargs) -> "OptimizationTask": + raise NotImplementedError diff --git a/malib/common/training_config.py b/malib/common/training_config.py new file mode 100644 index 00000000..93357cd4 --- /dev/null +++ b/malib/common/training_config.py @@ -0,0 +1,32 @@ +from typing import Dict, Any, Union + +from dataclasses import dataclass, field + + +@dataclass +class TrainingConfig: + + trainer_config: Dict[str, Any] + learner_type: str + custom_config: Dict[str, Any] = field(default_factory=dict()) + + @classmethod + def from_raw(cls, config: Union["TrainingConfig", Dict[str, Any]]) -> "TrainingConfig": + """Cat dict-style configuration to TrainingConfig instance + + Args: + config (Dict[str, Any]): A dict + + Raises: + RuntimeError: Unexpected config type + + Returns: + TrainingConfig: A training config instance + """ + + if isinstance(config, Dict): + return cls(**config) + elif isinstance(config, cls): + return config + else: + raise RuntimeError(f"Unexpected training config type: {type(config)}") diff --git a/malib/rollout/envs/sc2/env.py b/malib/rollout/envs/sc2/env.py index 2100d1e2..bb26e261 100644 --- a/malib/rollout/envs/sc2/env.py +++ b/malib/rollout/envs/sc2/env.py @@ -4,7 +4,12 @@ import gym from gym import spaces -from smac.env import StarCraft2Env as sc_env +from malib.utils.logging import Logger + +try: + from smac.env import StarCraft2Env as sc_env +except ImportError: + Logger.warning("Unable to import smac") from malib.rollout.envs.env import Environment, GroupWrapper from malib.utils.typing import AgentID diff --git a/malib/rollout/manager.py b/malib/rollout/manager.py index 717d5b1f..c0819bab 100644 --- a/malib/rollout/manager.py +++ b/malib/rollout/manager.py @@ -26,7 +26,7 @@ will be assigned with rollout tasks sent from the `CoordinatorServer`. """ -from typing import Dict, Tuple, Any, Callable, Set, List +from typing import Dict, Tuple, Any, Callable, Set, List, Union from collections import defaultdict import traceback @@ -35,6 +35,8 @@ from ray.util import ActorPool +from malib.utils.logging import Logger +from malib.common.task import TaskType, RolloutTask from malib.common.manager import Manager from malib.remote.interface import RemoteInterface from malib.common.strategy_spec import StrategySpec @@ -102,7 +104,7 @@ def __init__( worker_cls = rollout_worker_cls.as_remote(num_cpus=0, num_gpus=0) workers = [] - for i in range(num_worker): + for _ in range(num_worker): workers.append( worker_cls.options(max_concurrency=100).remote( experiment_tag=experiment_tag, @@ -124,6 +126,7 @@ def __init__( for agent in env_desc["possible_agents"]: rid = agent_mapping_func(agent) agent_groups[rid].add(agent) + self._runtime_ids = tuple(agent_groups.keys()) self._agent_groups = dict(agent_groups) self.experiment_tag = experiment_tag @@ -131,6 +134,7 @@ def __init__( assert ( "rollout" in stopping_conditions ), f"Stopping conditions should contain `rollout`: {stopping_conditions}" + self.stopping_conditions = stopping_conditions @property @@ -163,16 +167,24 @@ def workers(self) -> List[RemoteInterface]: return self._workers - def simulate(self, task_list): - """Parse simulation task and dispatch it to available workers""" + def submit(self, task: Union[Dict[str, Any], List[Dict[str, Any]]], task_type: Any): + """Submit a task to workers - for task in task_list: - self._actor_pool.submit( - lambda actor, task: actor.simulate.remote(runtime_strategy_specs=task), - task, - ) + Args: + task (Union[Dict[str, Any], List[Dict[str, Any]]]): Task description or a list of task description + task_type (Any): Task type, should be an instance from TaskType + """ + + if isinstance(task, List): + task = [RolloutTask.from_raw(e) for e in task] + else: + task = [RolloutTask.from_raw(task)] + + for _task in task: + validate_strategy_specs(_task.strategy_specs) + self._actor_pool.submit(lambda actor, _task: actor.rollout.remote(_task, stopping_conditions)) - def rollout(self, task_list: List[Dict[str, Any]]) -> None: + def _rollout(self, task_list: List[Dict[str, Any]]) -> None: """Start rollout task without blocking. Args: @@ -217,7 +229,7 @@ def retrive_results(self): while self._actor_pool.has_next(): yield self._actor_pool.get_next() except Exception as e: - print(traceback.format_exc()) + Logger.error(traceback.format_exc()) raise e def terminate(self): diff --git a/malib/rollout/rolloutworker.py b/malib/rollout/rolloutworker.py index 819063d7..5f8a9d19 100644 --- a/malib/rollout/rolloutworker.py +++ b/malib/rollout/rolloutworker.py @@ -46,6 +46,7 @@ from malib.utils.stopping_conditions import get_stopper from malib.utils.monitor import write_to_tensorboard from malib.common.strategy_spec import StrategySpec +from malib.common.task import RolloutTask, TaskType from malib.remote.interface import RemoteInterface from malib.rollout.inference.ray.server import ( RayInferenceWorkerSet as RayInferenceServer, @@ -392,37 +393,41 @@ def rollout( self, runtime_strategy_specs: Dict[str, StrategySpec], stopping_conditions: Dict[str, Any], - data_entrypoints: Dict[str, str], - trainable_agents: List[AgentID] = None, + data_entrypoints: Dict[str, str] = None, + active_agents: List[AgentID] = None, ): - """Run rollout procedure, collect data until meets the stopping conditions. + """Rollout, collecting training data when `data_entrypoints` is given, until meets the stopping conditions. The `active_agents` should be None or a none-empty list to specify active agents if rollout is not serve for evaluation. - NOTE: the data collection will be triggered only for trainable agents. + NOTE: the data collection will be triggered only for active agents. Args: runtime_strategy_specs (Dict[str, StrategySpec]): A dict of strategy spec, mapping from runtime id to `StrategySpec`. stopping_conditions (Dict[str, Any]): A dict of stopping conditions. - data_entrypoints (Dict[str, str]): Mapping from runtimeids to dataentrypoint names. - trainable_agents (List[AgentID], optional): A list of environment agent id. Defaults to None, which means all environment agents will be trainable. + data_entrypoints (Dict[str, str], optional): Mapping from runtimeids to dataentrypoint names. None for evaluation. + active_agents (List[AgentID], optional): A list of environment agent id. Defaults to None, which means all environment agents will be trainable. Empty list for evaluation mode. """ stopper = get_stopper(stopping_conditions) - trainable_agents = trainable_agents or self.env_agents - queue_info_dict: Dict[str, Tuple[str, Queue]] = { - rid: None for rid in self.runtime_agent_ids - } - for rid, identifier in data_entrypoints.items(): - queue_id, queue = ray.get( - self.dataset_server.start_producer_pipe.remote(name=identifier) - ) - queue_info_dict[rid] = (queue_id, queue) + active_agents = active_agents or self.env_agents + + if data_entrypoints is not None: + queue_info_dict: Dict[str, Tuple[str, Queue]] = { + rid: None for rid in self.runtime_agent_ids + } + for rid, identifier in data_entrypoints.items(): + queue_id, queue = ray.get( + self.dataset_server.start_producer_pipe.remote(name=identifier) + ) + queue_info_dict[rid] = (queue_id, queue) + else: + queue_info_dict = None rollout_config = self.rollout_config.copy() rollout_config.update( { "flag": "rollout", "strategy_specs": runtime_strategy_specs, - "trainable_agents": trainable_agents, + "active_agents": active_agents, "agent_group": self.agent_group, } ) diff --git a/malib/scenarios/sarl_scenario.py b/malib/scenarios/sarl_scenario.py new file mode 100644 index 00000000..cfa36430 --- /dev/null +++ b/malib/scenarios/sarl_scenario.py @@ -0,0 +1,126 @@ +# MIT License + +# Copyright (c) 2021 MARL @ SJTU + +# Author: Ming Zhou + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from types import LambdaType +from typing import Dict, Any + +from concurrent.futures import ThreadPoolExecutor +from malib.scenarios import Scenario + +from malib.utils.logging import Logger +from malib.backend.league import League +from malib.agent.manager import TrainingManager +from malib.rollout.manager import RolloutWorkerManager, TaskType + + +class SARLScenario(Scenario): + def __init__( + self, + name: str, + log_dir: str, + env_desc: Dict[str, Any], + algorithms: Dict[str, Any], + training_config: Dict[str, Any], + rollout_config: Dict[str, Any], + stopping_conditions: Dict[str, Any], + dataset_config: Dict[str, Any], + parameter_server_config: Dict[str, Any], + resource_config: Dict[str, Any] = None, + ): + super().__init__( + name, + log_dir, + env_desc, + algorithms, + lambda agent: "default", + training_config, + rollout_config, + stopping_conditions, + dataset_config, + parameter_server_config, + ) + self.num_policy_each_interface = 1 + self.resource_config = resource_config or {"training": None, "rollout": None} + + +def execution_plan(experiment_tag: str, scenario: SARLScenario, verbose: bool = True): + training_manager = TrainingManager( + experiment_tag=experiment_tag, + stopping_conditions=scenario.stopping_conditions, + algorithms=scenario.algorithms, + env_desc=scenario.env_desc, + agent_mapping_func=scenario.agent_mapping_func, + training_config=scenario.training_config, + log_dir=scenario.log_dir, + remote_mode=True, + resource_config=scenario.resource_config["training"], + verbose=verbose, + ) + + rollout_manager = RolloutWorkerManager( + experiment_tag=experiment_tag, + stopping_conditions=scenario.stopping_conditions, + num_worker=scenario.num_worker, + agent_mapping_func=scenario.agent_mapping_func, + rollout_config=scenario.rollout_config, + env_desc=scenario.env_desc, + log_dir=scenario.log_dir, + resource_config=scenario.resource_config["rollout"], + verbose=verbose, + ) + + league = League(rollout_manager, training_manager) + + strategy_specs = training_manager.add_policies(n=1) + Logger.info( + f"Training manager was inistialized with a strategy spec:\n{strategy_specs}" + ) + + data_entrypoints = {rid: rid for rid in training_manager.runtime_ids} + + assert len(data_entrypoints) == 1, "Support single agent only!" + + training_manager.submit({"data_request_identifiers": data_entrypoints}) + + rollout_task = { + "num_workers": 1, + "runtime_strategy_specs": strategy_specs, + "data_entrypoints": None, + "rollout_config": scenario.rollout_config, + "active_agents": None + } + evaluation_task = { + "num_workers": 1, + "runtime_strategy_specs": strategy_specs, + "rollout_config": getattr( + scenario, "evaluation_config", scenario.rollout_config + ), + } + + rollout_manager.submit(rollout_task, task_type=TaskType.ROLLOUT) + rollout_manager.submit(evaluation_task, task_type=TaskType.EVALUATION) + + results = league.get_results() + + return results diff --git a/setup.py b/setup.py index 629101e9..8352ca7b 100644 --- a/setup.py +++ b/setup.py @@ -40,7 +40,7 @@ python_requires=">=3.7", install_requires=[ "wrapt", - "ray==1.13.0", + "ray>=1.13.0", "pickle5", "torch", "tensorboardX", diff --git a/tests/agents/test_manager.py b/tests/agents/test_manager.py index 7c9f577a..1ada504d 100644 --- a/tests/agents/test_manager.py +++ b/tests/agents/test_manager.py @@ -36,7 +36,6 @@ from malib.utils.typing import AgentID from malib.agent import IndependentAgent from malib.agent.manager import TrainingManager -from malib.scenarios.marl_scenario import MARLScenario from malib.rl.random import RandomPolicy, RandomTrainer, DEFAULT_CONFIG @@ -64,14 +63,14 @@ def agent_mapping_one_to_one( @pytest.mark.parametrize("algorithms", [default_algorithms()]) @pytest.mark.parametrize("env_desc", [generate_gym_desc("CartPole-v1")]) @pytest.mark.parametrize( - "training_type,custom_training_config", [(IndependentAgent, {})] + "learner_type,custom_training_config", [(IndependentAgent, {})] ) class TestTrainingManager: def test_policy_add( self, algorithms: Dict[str, Any], env_desc: Dict[str, Any], - training_type: Type, + learner_type: Type, custom_training_config: Dict[str, Any], remote_mode: bool = True, ): @@ -91,7 +90,7 @@ def test_policy_add( parameter_server, offline_dataset_server = start_servers() training_config = { - "type": training_type, + "learner_type": learner_type, "trainer_config": DEFAULT_CONFIG["training_config"], "custom_config": custom_training_config, } @@ -119,7 +118,7 @@ def test_policy_add( target_agent_groups, ) # check agent interfaces - agent_interfaces = self.training_manager._interfaces + agent_interfaces = self.training_manager._learners assert set(agent_interfaces.keys()) == set(target_agent_groups.keys()), ( agent_interfaces.keys(), target_agent_groups.keys(), From a09ed546d3ff334510184537a590b95d627d7884 Mon Sep 17 00:00:00 2001 From: Ming Zhou Date: Sun, 8 Oct 2023 20:03:05 +0800 Subject: [PATCH 08/24] tmp save --- install.sh => conda_install.sh | 0 malib/agent/agent_interface.py | 26 +++++++--------- malib/agent/manager.py | 45 ++++++++++++++++----------- malib/backend/league.py | 25 +++++++++++++-- malib/common/task.py | 23 ++++++++++++++ malib/common/training_config.py | 1 + malib/rollout/inference/ray/server.py | 1 + malib/scenarios/sarl_scenario.py | 13 ++++++-- malib/utils/zero_copy.py | 32 +++++++++++++++++++ setup.py | 1 + 10 files changed, 129 insertions(+), 38 deletions(-) rename install.sh => conda_install.sh (100%) create mode 100644 malib/utils/zero_copy.py diff --git a/install.sh b/conda_install.sh similarity index 100% rename from install.sh rename to conda_install.sh diff --git a/malib/agent/agent_interface.py b/malib/agent/agent_interface.py index b4e43132..012fc213 100644 --- a/malib/agent/agent_interface.py +++ b/malib/agent/agent_interface.py @@ -32,7 +32,7 @@ import time import traceback - +import gym import torch import ray @@ -60,7 +60,8 @@ def __init__( experiment_tag: str, runtime_id: str, log_dir: str, - env_desc: Dict[str, Any], + observation_space: gym.Space, + action_space: gym.Space, algorithms: Dict[str, Tuple[Type, Type, Dict, Dict]], agent_mapping_func: Callable[[AgentID], str], governed_agents: Tuple[AgentID], @@ -75,11 +76,12 @@ def __init__( experiment_tag (str): Experiment tag. runtime_id (str): Assigned runtime id, should be an element of the agent mapping results. log_dir (str): The directory for logging. - env_desc (Dict[str, Any]): A dict that describes the environment property. + observation_space (gym.Space): Observation space. + action_space (gym.Space): Action space. algorithms (Dict[str, Tuple[Type, Type, Dict]]): A dict that describes the algorithm candidates. Each is \ a tuple of `policy_cls`, `trainer_cls`, `model_config` and `custom_config`. agent_mapping_func (Callable[[AgentID], str]): A function that defines the rule of agent groupping. - governed_agents (Tuple[AgentID]): A tuple that records which agents is related to this training procedures. \ + governed_agents (Tuple[AgentID]): A tuple that records which agents is related to this learner. \ Note that it should be a subset of the original set of environment agents. trainer_config (Dict[str, Any]): Trainer configuration. custom_config (Dict[str, Any], optional): A dict of custom configuration. Defaults to None. @@ -88,15 +90,10 @@ def __init__( """ if verbose: - print("\tAssigned GPUs: {}".format(ray.get_gpu_ids())) + Logger.info("\tAssigned GPUs: {}".format(ray.get_gpu_ids())) local_buffer_config = local_buffer_config or {} device = torch.device("cuda" if ray.get_gpu_ids() else "cpu") - # a strategy spec dict, mapping from algorithm - obs_spaces = env_desc["observation_spaces"] - act_spaces = env_desc["action_spaces"] - selected_observation_space = obs_spaces[governed_agents[0]] - selected_action_space = act_spaces[governed_agents[0]] # initialize a strategy spec for policy maintainance. strategy_spec = StrategySpec( @@ -107,8 +104,8 @@ def __init__( "experiment_tag": experiment_tag, # for policy initialize "kwargs": { - "observation_space": selected_observation_space, - "action_space": selected_action_space, + "observation_space": observation_space, + "action_space": action_space, "model_config": algorithms["default"][2], "custom_config": algorithms["default"][3], "kwargs": {}, @@ -118,7 +115,6 @@ def __init__( self._runtime_id = runtime_id self._device = device - self._env_desc = env_desc self._algorithms = algorithms self._governed_agents = governed_agents self._strategy_spec = strategy_spec @@ -323,7 +319,7 @@ def train( data_request_identifier: str, reset_state: bool = True, ) -> Dict[str, Any]: - """Executes training task and returns the final interface state. + """Executes a optimization task and returns the final interface state. Args: stopping_conditions (Dict[str, Any]): Control the training stepping. @@ -333,6 +329,8 @@ def train( Dict[str, Any]: A dict that describes the final state. """ + # XXX(ming): why we need to reset the state here? I think it is not necessary as + # an optimization task should be independent with other tasks. if reset_state: self.reset() diff --git a/malib/agent/manager.py b/malib/agent/manager.py index f96d77e4..82ec1423 100644 --- a/malib/agent/manager.py +++ b/malib/agent/manager.py @@ -40,6 +40,7 @@ import os import traceback import ray +from malib.common.task import OptimizationTask from malib.utils.typing import AgentID from malib.utils.logging import Logger @@ -56,6 +57,11 @@ ) +def validate_spaces(agent_groups: Dict[str, Set[AgentID]], env_desc: Dict[str, Any]): + # TODO(ming): check whether the agents in the group share the same observation space and action space + raise NotImplementedError + + class TrainingManager(Manager): def __init__( self, @@ -97,6 +103,8 @@ def __init__( rid = agent_mapping_func(agent) agent_groups[rid].add(agent) + validate_spaces(agent_groups, env_desc) + # FIXME(ming): resource configuration is not available now, will open in the next version if training_config.trainer_config.get("use_cuda", False): num_gpus = 1 / len(agent_groups) @@ -108,7 +116,6 @@ def __init__( learner_cls = training_config.learner_type # update num gpus resource_config["num_gpus"] = num_gpus - # XXX(ming): why we hard set max_concurrency to 10? learner_cls = learner_cls.as_remote(**resource_config).options( max_concurrency=10 ) @@ -221,31 +228,33 @@ def add_policies( return strategy_spec_dict - def submit(self, task: Any): - raise NotImplementedError + def submit(self, task: OptimizationTask): + """Submit a training task, the manager will distribute it to the corresponding learners. - def run(self, data_request_identifiers: Dict[str, str]): - """Start training thread without blocking""" + Args: + task (OptimizationTask): A task description. + """ - for rid, interface in self._learners.items(): - if self._remote_mode: - task = interface.train.remote( - data_request_identifiers[rid], - self._stopping_conditions["training"], + # retrieve learners with active agents + for aid in task.active_agents: + rid = self._agent_mapping_func(aid) + if rid not in self._learners: + raise RuntimeError( + f"Agent {aid} is not registered in training manager" ) else: - task = self._thread_pool.submit( - interface.train, - data_request_identifiers[rid], - self._stopping_conditions["training"], - ) - self.pending_tasks.append(task) + learner = self._learners[rid] + if self._remote_mode: + ray_task = learner.train.remote(task) + self.pending_tasks.append(ray_task) + else: + raise NotImplementedError def retrive_results(self) -> Generator: - """Return a generator of results + """Return a generator of results. Yields: - Generator: A generator for task results + Generator: A generator for task results. """ if self._remote_mode: diff --git a/malib/backend/league.py b/malib/backend/league.py index f3d50edd..66206539 100644 --- a/malib/backend/league.py +++ b/malib/backend/league.py @@ -1,3 +1,5 @@ +from typing import Any, Dict + import traceback from malib.utils.logging import Logger @@ -9,14 +11,31 @@ def __init__(self, training_manager: Manager, rollout_manager: Manager) -> None: self.training_manager = training_manager self.rollout_manager = rollout_manager - def get_results(self): + def get_results(self) -> Dict[str, Dict[str, Any]]: + """Retrieve results from rollout and training manager. + + Returns: + Dict[str, Dict[str, Any]]: A dict of results, which contains rollout and training results. + """ + + rollout_results = [] + training_results = [] + try: while True: - # TODO(ming): check whether done - raise NotImplementedError + for result in self.rollout_manager.get_results(): + rollout_results.append(result) + for result in self.training_manager.get_results(): + training_results.append(result) except KeyboardInterrupt: Logger.info("Keyboard interruption was detected, recalling resources ...") except RuntimeError: Logger.error(traceback.format_exc()) except Exception: Logger.error(traceback.format_exc()) + + return {"rollout": rollout_results, "training": training_results} + + def terminate(self): + self.training_manager.terminate() + self.rollout_manager.terminate() diff --git a/malib/common/task.py b/malib/common/task.py index e06f654f..5e31051e 100644 --- a/malib/common/task.py +++ b/malib/common/task.py @@ -32,6 +32,29 @@ def from_raw(cls, dict_style: Union[Dict[str, Any], "RolloutTask"], **kwargs) -> @dataclass class OptimizationTask: + data_entrypoints: Dict[str, str] + """a mapping defines the data request identifier and the data entrypoint.""" + + stop_conditions: Dict[str, Any] + """stopping conditions for optimization task, e.g., max iteration, max time, etc.""" + + strategy_specs: Dict[str, Any] = field(default_factory=dict()) + """a dict of strategy specs, which defines the strategy spec for each agent.""" + + active_agents: List[AgentID] = field(default_factory=list) + """a list of active agents, which defines the agents that will be trained in this optimization task. None for all""" + @classmethod def from_raw(cls, dict_style: Union[Dict[str, Any], "OptimizationTask"], **kwargs) -> "OptimizationTask": + """Construct a OptimizationTask object from a dict or a existing OptimizationTask instance. + + Args: + dict_style (Union[Dict[str, Any], "OptimizationTask"]): A dict or a OptimizationTask instance. + + Raises: + NotImplementedError: _description_ + + Returns: + OptimizationTask: A OptimizationTask instance. + """ raise NotImplementedError diff --git a/malib/common/training_config.py b/malib/common/training_config.py index 93357cd4..023a9934 100644 --- a/malib/common/training_config.py +++ b/malib/common/training_config.py @@ -3,6 +3,7 @@ from dataclasses import dataclass, field +# TODO(ming): rename it as LearnerConfig @dataclass class TrainingConfig: diff --git a/malib/rollout/inference/ray/server.py b/malib/rollout/inference/ray/server.py index 247bee10..bc9a2487 100644 --- a/malib/rollout/inference/ray/server.py +++ b/malib/rollout/inference/ray/server.py @@ -33,6 +33,7 @@ import pickle as pkl import ray import gym +import torch from malib import settings from malib.remote.interface import RemoteInterface diff --git a/malib/scenarios/sarl_scenario.py b/malib/scenarios/sarl_scenario.py index cfa36430..470f6d5d 100644 --- a/malib/scenarios/sarl_scenario.py +++ b/malib/scenarios/sarl_scenario.py @@ -22,10 +22,9 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -from types import LambdaType from typing import Dict, Any +from malib.common.task import OptimizationTask -from concurrent.futures import ThreadPoolExecutor from malib.scenarios import Scenario from malib.utils.logging import Logger @@ -65,6 +64,7 @@ def __init__( def execution_plan(experiment_tag: str, scenario: SARLScenario, verbose: bool = True): + # TODO(ming): simplify the initialization of training and rollout manager with a scenario instance as input training_manager = TrainingManager( experiment_tag=experiment_tag, stopping_conditions=scenario.stopping_conditions, @@ -101,7 +101,12 @@ def execution_plan(experiment_tag: str, scenario: SARLScenario, verbose: bool = assert len(data_entrypoints) == 1, "Support single agent only!" - training_manager.submit({"data_request_identifiers": data_entrypoints}) + optimization_task = OptimizationTask( + active_agents=None, + data_entrypoints=data_entrypoints, + stop_conditions=scenario.stopping_conditions["training"], + ) + training_manager.submit(optimization_task) rollout_task = { "num_workers": 1, @@ -123,4 +128,6 @@ def execution_plan(experiment_tag: str, scenario: SARLScenario, verbose: bool = results = league.get_results() + league.terminate() + return results diff --git a/malib/utils/zero_copy.py b/malib/utils/zero_copy.py new file mode 100644 index 00000000..fbfee92d --- /dev/null +++ b/malib/utils/zero_copy.py @@ -0,0 +1,32 @@ +# https://github.com/project-codeflare/zero-copy-model-loading + +import torch +import ray + +from malib.utils.logging import Logger + +try: + import zerocopy +except ImportError: + Logger.warning("No package named zerocopy, please install it first.") + + +# the following code piece can load a BertModel in 0.004s +tmp = torch.nn.Module() +ref = ray.put(zerocopy.extract_tensors(tmp)) +model_graph, tensors = ray.get(ref) +zerocopy.replace_tensors(model_graph, tensors) + +# zero-copy method: stateless task +@ray.remote +def run_model(model_and_tensors, model_input): + model_graph, tensors = model_and_tensors + zerocopy.replace_tensors(model_graph, tensors) + with torch.inference_mode(): + return model_graph(**model_input) + + +model_result = ray.get(run_model.remote(ref, model_input)) + +async def get_model_result(model_ref, model_input): + return await zerocopy.call_model.remote(model_ref, [], model_input) \ No newline at end of file diff --git a/setup.py b/setup.py index 8352ca7b..43969866 100644 --- a/setup.py +++ b/setup.py @@ -64,6 +64,7 @@ "pygame==2.1.0", "pettingzoo", "networkx", + "zerocopy>=0.1.0" ], extras_require={ "dev": [ From 0ccbf8aa7c4fc498cc8a436add4ac8f1894844f4 Mon Sep 17 00:00:00 2001 From: Ming Zhou Date: Mon, 9 Oct 2023 19:10:09 +0800 Subject: [PATCH 09/24] tmp save --- malib/agent/agent_interface.py | 40 ++++++-------- malib/agent/indepdent_agent.py | 1 - malib/agent/manager.py | 16 ++++-- malib/backend/data_loader.py | 69 ++++++++++++++++++++++++ malib/backend/league.py | 36 +++++++++++-- malib/common/rollout_config.py | 8 +-- malib/common/task.py | 12 +++-- malib/common/training_config.py | 5 +- malib/rl/pg/trainer.py | 1 - malib/rl/random/policy.py | 1 - malib/rollout/envs/pettingzoo/env.py | 1 - malib/rollout/envs/sc2/env.py | 2 - malib/rollout/envs/vector_env.py | 1 - malib/rollout/manager.py | 4 +- malib/rollout/rolloutworker.py | 2 +- malib/scenarios/sarl_scenario.py | 2 +- malib/utils/data.py | 1 - malib/utils/general.py | 20 +++++++ malib/utils/zero_copy.py | 4 +- setup.py | 2 +- tests/backend/test_dataset_server.py | 1 - tests/malib_utils/test_payoff_manager.py | 1 - tests/rl/test_dqn.py | 1 - tests/rl/test_policy.py | 1 - 24 files changed, 174 insertions(+), 58 deletions(-) create mode 100644 malib/backend/data_loader.py diff --git a/malib/agent/agent_interface.py b/malib/agent/agent_interface.py index 012fc213..3cac506f 100644 --- a/malib/agent/agent_interface.py +++ b/malib/agent/agent_interface.py @@ -41,6 +41,7 @@ from malib import settings from malib.backend.offline_dataset_server import OfflineDataset +from malib.backend.data_loader import RLDataLoader from malib.backend.parameter_server import ParameterServer from malib.utils.typing import AgentID from malib.utils.logging import Logger @@ -69,6 +70,7 @@ def __init__( custom_config: Dict[str, Any] = None, local_buffer_config: Dict = None, verbose: bool = True, + dataloader: RLDataLoader = None, ): """Construct agent interface for training. @@ -130,7 +132,9 @@ def __init__( self._offline_dataset: OfflineDataset = None self._parameter_server: ParameterServer = None + self._dataloader = dataloader or self.create_dataloader() self._active_tups = deque() + self.verbose = verbose @property @@ -153,6 +157,18 @@ def device(self) -> Union[str, torch.DeviceObjType]: return self._device + def create_dataloader(self) -> RLDataLoader: + """Create a data loader instance. + + Raises: + NotImplementedError: Raise if this method is not implemented. + + Returns: + RLDataLoader: A data loader instance. + """ + + raise NotImplementedError + def connect( self, max_tries: int = 10, @@ -215,30 +231,6 @@ def add_policies(self, n: int) -> StrategySpec: return self._strategy_spec - def get_algorithm(self, key: str) -> Any: # pragma: no cover - """Return a copy of algorithm configuration with given key, if not exist, raise KeyError. - - Args: - key (str): Algorithm configuration reference key. - - Raises: - KeyError: No such an algorithm configuration relates to the give key. - - Returns: - Any: Algorithm configuration, mabe a dict. - """ - - return copy.deepcopy(self._algorithms[key]) - - def get_algorthms(self) -> Dict[str, Any]: # pragma: no_cover - """Return a copy of full algorithm configurations. - - Returns: - Dict[str, Any]: Full algorithm configurations. - """ - - return copy.deepcopy(self._algorithms) - def push(self): """Push local weights to remote server""" diff --git a/malib/agent/indepdent_agent.py b/malib/agent/indepdent_agent.py index 0cb55cad..82352b80 100644 --- a/malib/agent/indepdent_agent.py +++ b/malib/agent/indepdent_agent.py @@ -64,7 +64,6 @@ def multiagent_post_process( Dict[AgentID, Tuple[Batch, List[int]]], Tuple[Batch, List[int]] ], ) -> Dict[str, Any]: - if not isinstance(batch_info, Tuple): raise TypeError( "IndependentAgent support only a tuple of batch info as input." diff --git a/malib/agent/manager.py b/malib/agent/manager.py index 82ec1423..4fd211c3 100644 --- a/malib/agent/manager.py +++ b/malib/agent/manager.py @@ -144,6 +144,8 @@ def __init__( if remote_mode: _ = ray.get([x.connect.remote() for x in learners.values()]) + # TODO(ming): collect data entrypoints from learners + self._agent_groups = agent_groups self._runtime_ids = tuple(self._agent_groups.keys()) self._experiment_tag = experiment_tag @@ -170,6 +172,16 @@ def agent_groups(self) -> Dict[str, Set[AgentID]]: return self._agent_groups + @property + def get_data_entrypoints(self) -> Dict[str, str]: + """Return a dict of data entrypoints, maps from runtime ids to data entrypoints. + + Returns: + Dict[str, str]: A dict of data entrypoints. + """ + + return {rid: rid for rid in self._runtime_ids} + @property def workers(self) -> List[RemoteInterface]: """A list of learner instance @@ -239,9 +251,7 @@ def submit(self, task: OptimizationTask): for aid in task.active_agents: rid = self._agent_mapping_func(aid) if rid not in self._learners: - raise RuntimeError( - f"Agent {aid} is not registered in training manager" - ) + raise RuntimeError(f"Agent {aid} is not registered in training manager") else: learner = self._learners[rid] if self._remote_mode: diff --git a/malib/backend/data_loader.py b/malib/backend/data_loader.py new file mode 100644 index 00000000..17a0e56c --- /dev/null +++ b/malib/backend/data_loader.py @@ -0,0 +1,69 @@ +from typing import Dict, Any, Union + +import pyarrow.flight as flight + +from malib.utils.logging import Logger +from malib.utils.general import find_free_port + + +class FlightServer(flight.FlightServerBase): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def list_flights( + self, context: flight.ServerCallContext, criteria: bytes + ) -> flight.FlightInfo: + info = flight.FlightInfo(...) + yield info + + def do_action(self, context: flight.ServerCallContext, action: flight.Action): + """Execute a custom action. This method should return an iterator, or it should be a generator. Applications should override this method to implement their own behavior. The default method raises a NotImplementedError. + + Args: + context (_type_): _description_ + action (_type_): _description_ + + Raises: + NotImplementedError: _description_ + """ + + raise NotImplementedError + + def do_exchange( + self, + context: flight.ServerCallContext, + descriptor: flight.FlightDescriptor, + reader: flight.MetadataRecordBatchReader, + writer: flight.MetadataRecordBatchWriter, + ): + raise NotImplementedError + + def do_put( + self, + context: flight.ServerCallContext, + descriptor: flight.FlightDescriptor, + reader: flight.MetadataRecordBatchReader, + writer: flight.FlightMetadataWriter, + ): + """Write data to a flight.""" + + +class DataCommunicator: + def __init__(self, flight_server_address: str) -> None: + self.flight_conn: flight.FlightClient = flight.connect(flight_server_address) + + def send(self, data): + self.flight_conn.do_put(data) + + def get(self, batch_size: int): + raise NotImplementedError + + def close(self): + self.flight_conn.close() + + +if __name__ == "__main__": + port = find_free_port() + flight_server = FlightServer(f"grpc://0.0.0.0:{port}") + Logger.info(f"Flight server listening on {port}") + flight_server.serve() diff --git a/malib/backend/league.py b/malib/backend/league.py index 66206539..b52d08d8 100644 --- a/malib/backend/league.py +++ b/malib/backend/league.py @@ -1,6 +1,10 @@ -from typing import Any, Dict - +from typing import Any, Dict, List +from concurrent import futures +import threading import traceback +import ray + +from readerwriterlock import rwlock from malib.utils.logging import Logger from malib.common.manager import Manager @@ -10,6 +14,30 @@ class League: def __init__(self, training_manager: Manager, rollout_manager: Manager) -> None: self.training_manager = training_manager self.rollout_manager = rollout_manager + self.flight_servers = [] + self.rw_lock = rwlock.RWLockFair() + self.event = threading.Event() + self.thread_pool = futures.ThreadPoolExecutor() + + def register_flight_server(self, flight_server_address: str): + raise NotImplementedError + + def list_flight_servers(self) -> List[str]: + raise NotImplementedError + + def _flight_server_check(self): + while not self.event.is_set(): + with self.rw_lock.gen_rlock(): + for flight_server in self.flight_servers: + if not ray.util.check_connection(flight_server): + self.flight_servers.remove(flight_server) + self.event.wait(10) + + def list_learners(self): + return self.training_manager.workers() + + def list_rollout_workers(self): + return self.rollout_manager.workers() def get_results(self) -> Dict[str, Dict[str, Any]]: """Retrieve results from rollout and training manager. @@ -35,7 +63,9 @@ def get_results(self) -> Dict[str, Dict[str, Any]]: Logger.error(traceback.format_exc()) return {"rollout": rollout_results, "training": training_results} - + def terminate(self): + self.event.set() + self.thread_pool.shutdown() self.training_manager.terminate() self.rollout_manager.terminate() diff --git a/malib/common/rollout_config.py b/malib/common/rollout_config.py index c238182f..4b4191d7 100644 --- a/malib/common/rollout_config.py +++ b/malib/common/rollout_config.py @@ -5,12 +5,12 @@ @dataclass class RolloutConfig: - inference_server_type: str - @classmethod - def from_raw(cls, config: Union["RolloutConfig", Dict[str, Any]]) -> "RolloutConfig": + def from_raw( + cls, config: Union["RolloutConfig", Dict[str, Any]] + ) -> "RolloutConfig": """Cat dict-style configuration to RolloutConfig instance Args: @@ -28,4 +28,4 @@ def from_raw(cls, config: Union["RolloutConfig", Dict[str, Any]]) -> "RolloutCon elif isinstance(config, Dict): return cls(**config) else: - raise RuntimeError(f"Unexpected rollout config type: {type(config)}") \ No newline at end of file + raise RuntimeError(f"Unexpected rollout config type: {type(config)}") diff --git a/malib/common/task.py b/malib/common/task.py index 5e31051e..007b6fb2 100644 --- a/malib/common/task.py +++ b/malib/common/task.py @@ -14,13 +14,14 @@ class TaskType(IntEnum): @dataclass class RolloutTask: - task_type: int active_agents: List[AgentID] strategy_specs: Dict[str, Any] = field(default_factory=dict()) - + @classmethod - def from_raw(cls, dict_style: Union[Dict[str, Any], "RolloutTask"], **kwargs) -> "RolloutTask": + def from_raw( + cls, dict_style: Union[Dict[str, Any], "RolloutTask"], **kwargs + ) -> "RolloutTask": if isinstance(dict_style, Dict): return cls(**dict_style, **kwargs) elif isinstance(dict_style, cls): @@ -31,7 +32,6 @@ def from_raw(cls, dict_style: Union[Dict[str, Any], "RolloutTask"], **kwargs) -> @dataclass class OptimizationTask: - data_entrypoints: Dict[str, str] """a mapping defines the data request identifier and the data entrypoint.""" @@ -45,7 +45,9 @@ class OptimizationTask: """a list of active agents, which defines the agents that will be trained in this optimization task. None for all""" @classmethod - def from_raw(cls, dict_style: Union[Dict[str, Any], "OptimizationTask"], **kwargs) -> "OptimizationTask": + def from_raw( + cls, dict_style: Union[Dict[str, Any], "OptimizationTask"], **kwargs + ) -> "OptimizationTask": """Construct a OptimizationTask object from a dict or a existing OptimizationTask instance. Args: diff --git a/malib/common/training_config.py b/malib/common/training_config.py index 023a9934..4fc1ae0e 100644 --- a/malib/common/training_config.py +++ b/malib/common/training_config.py @@ -6,13 +6,14 @@ # TODO(ming): rename it as LearnerConfig @dataclass class TrainingConfig: - trainer_config: Dict[str, Any] learner_type: str custom_config: Dict[str, Any] = field(default_factory=dict()) @classmethod - def from_raw(cls, config: Union["TrainingConfig", Dict[str, Any]]) -> "TrainingConfig": + def from_raw( + cls, config: Union["TrainingConfig", Dict[str, Any]] + ) -> "TrainingConfig": """Cat dict-style configuration to TrainingConfig instance Args: diff --git a/malib/rl/pg/trainer.py b/malib/rl/pg/trainer.py index 95602351..1a8d4091 100644 --- a/malib/rl/pg/trainer.py +++ b/malib/rl/pg/trainer.py @@ -45,7 +45,6 @@ def setup(self): self.ret_rms = None def post_process(self, batch: Batch, agent_filter: Sequence[AgentID]) -> Batch: - # v_s_ = np.full(indices.shape, self.ret_rms.mean) unnormalized_returns, _ = Postprocessor.compute_episodic_return( batch, gamma=self.training_config["gamma"], gae_lambda=1.0 diff --git a/malib/rl/random/policy.py b/malib/rl/random/policy.py index 5ea9a209..cadea858 100644 --- a/malib/rl/random/policy.py +++ b/malib/rl/random/policy.py @@ -14,7 +14,6 @@ def __init__( custom_config: Dict[str, Any], **kwargs ): - super().__init__( observation_space, action_space, model_config, custom_config, **kwargs ) diff --git a/malib/rollout/envs/pettingzoo/env.py b/malib/rollout/envs/pettingzoo/env.py index 56e53f1d..d9a63cd6 100644 --- a/malib/rollout/envs/pettingzoo/env.py +++ b/malib/rollout/envs/pettingzoo/env.py @@ -83,7 +83,6 @@ def time_step( Dict[AgentID, bool], Dict[AgentID, Any], ]: - if not self.parallel_simulate: self.env.step(actions[self.env.agent_selection]) rewards = self.env.rewards.copy() diff --git a/malib/rollout/envs/sc2/env.py b/malib/rollout/envs/sc2/env.py index bb26e261..c51c207a 100644 --- a/malib/rollout/envs/sc2/env.py +++ b/malib/rollout/envs/sc2/env.py @@ -154,7 +154,6 @@ def close(self): def StatedSC2(**config): - env = SC2Env(**config) class Wrapped(GroupWrapper): @@ -184,7 +183,6 @@ def group_rule(self, agent_id: AgentID) -> str: if __name__ == "__main__": - env = SC2Env(env_id="3m") state, obs = env.reset() diff --git a/malib/rollout/envs/vector_env.py b/malib/rollout/envs/vector_env.py index f887eed3..e5ebafe2 100644 --- a/malib/rollout/envs/vector_env.py +++ b/malib/rollout/envs/vector_env.py @@ -236,7 +236,6 @@ def is_terminated(self): def action_adapter( self, policy_outputs: List[Dict[str, Dict[AgentID, Any]]] ) -> List[Dict[AgentID, Any]]: - # since activ_envs maybe updated after self.step, so we should use keys # in self.active_envs res = defaultdict(list) diff --git a/malib/rollout/manager.py b/malib/rollout/manager.py index c0819bab..e6b39471 100644 --- a/malib/rollout/manager.py +++ b/malib/rollout/manager.py @@ -182,7 +182,9 @@ def submit(self, task: Union[Dict[str, Any], List[Dict[str, Any]]], task_type: A for _task in task: validate_strategy_specs(_task.strategy_specs) - self._actor_pool.submit(lambda actor, _task: actor.rollout.remote(_task, stopping_conditions)) + self._actor_pool.submit( + lambda actor, _task: actor.rollout.remote(_task, stopping_conditions) + ) def _rollout(self, task_list: List[Dict[str, Any]]) -> None: """Start rollout task without blocking. diff --git a/malib/rollout/rolloutworker.py b/malib/rollout/rolloutworker.py index 5f8a9d19..8c406827 100644 --- a/malib/rollout/rolloutworker.py +++ b/malib/rollout/rolloutworker.py @@ -396,7 +396,7 @@ def rollout( data_entrypoints: Dict[str, str] = None, active_agents: List[AgentID] = None, ): - """Rollout, collecting training data when `data_entrypoints` is given, until meets the stopping conditions. The `active_agents` should be None or a none-empty list to specify active agents if rollout is not serve for evaluation. + """Rollout, collecting training data when `data_entrypoints` is given, until meets the stopping conditions. The `active_agents` should be None or a none-empty list to specify active agents if rollout is not serve for evaluation. NOTE: the data collection will be triggered only for active agents. diff --git a/malib/scenarios/sarl_scenario.py b/malib/scenarios/sarl_scenario.py index 470f6d5d..d5f0ebd7 100644 --- a/malib/scenarios/sarl_scenario.py +++ b/malib/scenarios/sarl_scenario.py @@ -113,7 +113,7 @@ def execution_plan(experiment_tag: str, scenario: SARLScenario, verbose: bool = "runtime_strategy_specs": strategy_specs, "data_entrypoints": None, "rollout_config": scenario.rollout_config, - "active_agents": None + "active_agents": None, } evaluation_task = { "num_workers": 1, diff --git a/malib/utils/data.py b/malib/utils/data.py index 62552812..2ba06c84 100644 --- a/malib/utils/data.py +++ b/malib/utils/data.py @@ -176,7 +176,6 @@ def gae_return( gamma: float = 0.99, gae_lambda: float = 0.95, ): - adv = _gae_return( state_value, next_state_value, reward, done, gamma, gae_lambda ) diff --git a/malib/utils/general.py b/malib/utils/general.py index b7c9e250..2bf25ada 100644 --- a/malib/utils/general.py +++ b/malib/utils/general.py @@ -341,6 +341,26 @@ def frozen_data(data): return _hash +import socket + + +def find_free_port(random_port: int = 8000) -> int: + """Find a free port. + + Args: + random_port (int, optional): Given a random port. Defaults to 8000. + + Returns: + int: Port number + """ + + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + if s.connect_ex(("localhost", random_port)) == 0: + return find_free_port(random_port + 1) + else: + return random_port + + # =============== below refer to: https://docs.ray.io/en/releases-1.9.1/_modules/ray/util/ml_utils/dict.html#merge_dicts def merge_dicts(d1: dict, d2: dict) -> dict: """ diff --git a/malib/utils/zero_copy.py b/malib/utils/zero_copy.py index fbfee92d..5524374c 100644 --- a/malib/utils/zero_copy.py +++ b/malib/utils/zero_copy.py @@ -17,6 +17,7 @@ model_graph, tensors = ray.get(ref) zerocopy.replace_tensors(model_graph, tensors) + # zero-copy method: stateless task @ray.remote def run_model(model_and_tensors, model_input): @@ -28,5 +29,6 @@ def run_model(model_and_tensors, model_input): model_result = ray.get(run_model.remote(ref, model_input)) + async def get_model_result(model_ref, model_input): - return await zerocopy.call_model.remote(model_ref, [], model_input) \ No newline at end of file + return await zerocopy.call_model.remote(model_ref, [], model_input) diff --git a/setup.py b/setup.py index 43969866..cd0f3190 100644 --- a/setup.py +++ b/setup.py @@ -64,7 +64,7 @@ "pygame==2.1.0", "pettingzoo", "networkx", - "zerocopy>=0.1.0" + "zerocopy>=0.1.0", ], extras_require={ "dev": [ diff --git a/tests/backend/test_dataset_server.py b/tests/backend/test_dataset_server.py index a3ed916f..d51a4697 100644 --- a/tests/backend/test_dataset_server.py +++ b/tests/backend/test_dataset_server.py @@ -126,7 +126,6 @@ def test_datatable_read_and_write(read_size: int, write_size: int): def test_offline_dataset(): - if not ray.is_initialized(): ray.init() diff --git a/tests/malib_utils/test_payoff_manager.py b/tests/malib_utils/test_payoff_manager.py index 1b92becb..ddd93979 100644 --- a/tests/malib_utils/test_payoff_manager.py +++ b/tests/malib_utils/test_payoff_manager.py @@ -44,7 +44,6 @@ def test_default_solver(solve_method: str): @pytest.mark.parametrize("n_player", [2, 4]) def test_payoff_table(n_player: int): - agents = [f"player_{i}" for i in range(n_player)] # start from one policy each player shape = [0] * n_player diff --git a/tests/rl/test_dqn.py b/tests/rl/test_dqn.py index a5ac2438..08eca69c 100644 --- a/tests/rl/test_dqn.py +++ b/tests/rl/test_dqn.py @@ -82,7 +82,6 @@ def test_dqn_policy( def test_dqn_trainer(): - num_agents = 4 batch_size = 64 diff --git a/tests/rl/test_policy.py b/tests/rl/test_policy.py index 8bc42bea..33f3eb9a 100644 --- a/tests/rl/test_policy.py +++ b/tests/rl/test_policy.py @@ -71,7 +71,6 @@ def test_interface_calling( model_config: Dict[str, Any], custom_config: Dict[str, Any], ): - policy_caller = partial( FakePolicy, observation_space, action_space, model_config, custom_config ) From 967a49326740b36a670ea060e8abd706afd7083c Mon Sep 17 00:00:00 2001 From: Ming Zhou Date: Mon, 23 Oct 2023 20:08:05 +0800 Subject: [PATCH 10/24] tmp save --- Makefile | 4 + malib/backend/data_loader.py | 69 ---------------- malib/backend/dataset_server/__init__.py | 0 malib/backend/dataset_server/data_loader.py | 64 +++++++++++++++ malib/backend/dataset_server/data_pb2.py | 30 +++++++ malib/backend/dataset_server/data_pb2.pyi | 21 +++++ malib/backend/dataset_server/data_pb2_grpc.py | 79 +++++++++++++++++++ malib/backend/dataset_server/feature.py | 30 +++++++ malib/backend/dataset_server/service.py | 22 ++++++ malib/backend/dataset_server/utils.py | 12 +++ malib/backend/protos/data.proto | 16 ++++ setup.py | 1 + 12 files changed, 279 insertions(+), 69 deletions(-) delete mode 100644 malib/backend/data_loader.py create mode 100644 malib/backend/dataset_server/__init__.py create mode 100644 malib/backend/dataset_server/data_loader.py create mode 100644 malib/backend/dataset_server/data_pb2.py create mode 100644 malib/backend/dataset_server/data_pb2.pyi create mode 100644 malib/backend/dataset_server/data_pb2_grpc.py create mode 100644 malib/backend/dataset_server/feature.py create mode 100644 malib/backend/dataset_server/service.py create mode 100644 malib/backend/dataset_server/utils.py create mode 100644 malib/backend/protos/data.proto diff --git a/Makefile b/Makefile index 5e55e3aa..ec212690 100644 --- a/Makefile +++ b/Makefile @@ -62,3 +62,7 @@ coverage-view: test-verbose: pytest --cov-config=.coveragerc --cov=malib --cov-report html --doctest-modules tests -v -s rm -f .coverage.* + +.PHONY: compile +compile: + python -m grpc_tools.protoc -I malib/backend/protos --python_out=malib/backend/dataset_server --pyi_out=malib/backend/dataset_server --grpc_python_out=malib/backend/dataset_server malib/backend/protos/data.proto \ No newline at end of file diff --git a/malib/backend/data_loader.py b/malib/backend/data_loader.py deleted file mode 100644 index 17a0e56c..00000000 --- a/malib/backend/data_loader.py +++ /dev/null @@ -1,69 +0,0 @@ -from typing import Dict, Any, Union - -import pyarrow.flight as flight - -from malib.utils.logging import Logger -from malib.utils.general import find_free_port - - -class FlightServer(flight.FlightServerBase): - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - - def list_flights( - self, context: flight.ServerCallContext, criteria: bytes - ) -> flight.FlightInfo: - info = flight.FlightInfo(...) - yield info - - def do_action(self, context: flight.ServerCallContext, action: flight.Action): - """Execute a custom action. This method should return an iterator, or it should be a generator. Applications should override this method to implement their own behavior. The default method raises a NotImplementedError. - - Args: - context (_type_): _description_ - action (_type_): _description_ - - Raises: - NotImplementedError: _description_ - """ - - raise NotImplementedError - - def do_exchange( - self, - context: flight.ServerCallContext, - descriptor: flight.FlightDescriptor, - reader: flight.MetadataRecordBatchReader, - writer: flight.MetadataRecordBatchWriter, - ): - raise NotImplementedError - - def do_put( - self, - context: flight.ServerCallContext, - descriptor: flight.FlightDescriptor, - reader: flight.MetadataRecordBatchReader, - writer: flight.FlightMetadataWriter, - ): - """Write data to a flight.""" - - -class DataCommunicator: - def __init__(self, flight_server_address: str) -> None: - self.flight_conn: flight.FlightClient = flight.connect(flight_server_address) - - def send(self, data): - self.flight_conn.do_put(data) - - def get(self, batch_size: int): - raise NotImplementedError - - def close(self): - self.flight_conn.close() - - -if __name__ == "__main__": - port = find_free_port() - flight_server = FlightServer(f"grpc://0.0.0.0:{port}") - Logger.info(f"Flight server listening on {port}") - flight_server.serve() diff --git a/malib/backend/dataset_server/__init__.py b/malib/backend/dataset_server/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/malib/backend/dataset_server/data_loader.py b/malib/backend/dataset_server/data_loader.py new file mode 100644 index 00000000..72b23651 --- /dev/null +++ b/malib/backend/dataset_server/data_loader.py @@ -0,0 +1,64 @@ +from typing import Type, Any + +import grpc + +from concurrent import futures +from torch.utils.data import DataLoader, Dataset + +from malib.utils.general import find_free_port + +from .service import DatasetServer +from . import data_pb2_grpc +from .feature import BaseFeature + + +class EmptyError(Exception): + pass + + +class DynamicDataset(Dataset): + def __init__( + self, + grpc_thread_num_workers: int, + max_message_length: int, + feature_handler_caller: Type, + ) -> None: + super().__init__() + + # start a service as thread + self.thread_pool = futures.ThreadPoolExecutor(max_workers=2) + self.thread_pool.submit( + self._start_servicer, + grpc_thread_num_workers, + max_message_length, + find_free_port(), + ) + self.feature_handler: BaseFeature = feature_handler_caller() + + def _start_servicer( + self, max_workers: int, max_message_length: int, grpc_port: int + ): + server = grpc.server( + futures.ThreadPoolExecutor(max_workers=max_workers), + options=[ + ("grpc.max_send_message_length", max_message_length), + ("grpc.max_receive_message_length", max_message_length), + ], + ) + servicer = DatasetServer(self.feature_handler) + data_pb2_grpc.add_SendDataServicer_to_server(servicer, server) + + server.add_insecure_port(f"[::]:{grpc_port}") + server.start() + + def __len__(self): + return self.feature_handler_caller.block_size + + def __getitem__(self, index) -> Any: + if index >= len(self): + raise IndexError + + if len(self.feature_handler) == 0: + raise EmptyError(f"No available data for sampling") + + return self.feature_handler.safe_get(index) diff --git a/malib/backend/dataset_server/data_pb2.py b/malib/backend/dataset_server/data_pb2.py new file mode 100644 index 00000000..68bac728 --- /dev/null +++ b/malib/backend/dataset_server/data_pb2.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: data.proto +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder + +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\ndata.proto\x12\x04\x64\x61ta"%\n\x04\x44\x61ta\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\x12\x0f\n\x07message\x18\x02 \x01(\t"\x18\n\x05Reply\x12\x0f\n\x07message\x18\x01 \x01(\t20\n\x08SendData\x12$\n\x07\x63ollect\x12\n.data.Data\x1a\x0b.data.Reply"\x00\x62\x06proto3' +) + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "data_pb2", _globals) +if _descriptor._USE_C_DESCRIPTORS == False: + DESCRIPTOR._options = None + _globals["_DATA"]._serialized_start = 20 + _globals["_DATA"]._serialized_end = 57 + _globals["_REPLY"]._serialized_start = 59 + _globals["_REPLY"]._serialized_end = 83 + _globals["_SENDDATA"]._serialized_start = 85 + _globals["_SENDDATA"]._serialized_end = 133 +# @@protoc_insertion_point(module_scope) diff --git a/malib/backend/dataset_server/data_pb2.pyi b/malib/backend/dataset_server/data_pb2.pyi new file mode 100644 index 00000000..c992b28e --- /dev/null +++ b/malib/backend/dataset_server/data_pb2.pyi @@ -0,0 +1,21 @@ +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from typing import ClassVar as _ClassVar, Optional as _Optional + +DESCRIPTOR: _descriptor.FileDescriptor + +class Data(_message.Message): + __slots__ = ["data", "message"] + DATA_FIELD_NUMBER: _ClassVar[int] + MESSAGE_FIELD_NUMBER: _ClassVar[int] + data: bytes + message: str + def __init__( + self, data: _Optional[bytes] = ..., message: _Optional[str] = ... + ) -> None: ... + +class Reply(_message.Message): + __slots__ = ["message"] + MESSAGE_FIELD_NUMBER: _ClassVar[int] + message: str + def __init__(self, message: _Optional[str] = ...) -> None: ... diff --git a/malib/backend/dataset_server/data_pb2_grpc.py b/malib/backend/dataset_server/data_pb2_grpc.py new file mode 100644 index 00000000..d5e49398 --- /dev/null +++ b/malib/backend/dataset_server/data_pb2_grpc.py @@ -0,0 +1,79 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + +import data_pb2 as data__pb2 + + +class SendDataStub(object): + """Missing associated documentation comment in .proto file.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.collect = channel.unary_unary( + "/data.SendData/collect", + request_serializer=data__pb2.Data.SerializeToString, + response_deserializer=data__pb2.Reply.FromString, + ) + + +class SendDataServicer(object): + """Missing associated documentation comment in .proto file.""" + + def collect(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + +def add_SendDataServicer_to_server(servicer, server): + rpc_method_handlers = { + "collect": grpc.unary_unary_rpc_method_handler( + servicer.collect, + request_deserializer=data__pb2.Data.FromString, + response_serializer=data__pb2.Reply.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + "data.SendData", rpc_method_handlers + ) + server.add_generic_rpc_handlers((generic_handler,)) + + +# This class is part of an EXPERIMENTAL API. +class SendData(object): + """Missing associated documentation comment in .proto file.""" + + @staticmethod + def collect( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/data.SendData/collect", + data__pb2.Data.SerializeToString, + data__pb2.Reply.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) diff --git a/malib/backend/dataset_server/feature.py b/malib/backend/dataset_server/feature.py new file mode 100644 index 00000000..3276e96d --- /dev/null +++ b/malib/backend/dataset_server/feature.py @@ -0,0 +1,30 @@ +from typing import Any +from readerwriterlock import rwlock + + +class BaseFeature: + def __init__(self) -> None: + self.rw_lock = rwlock.RWLockFair() + self._readable_index = [] + self._writable_index = [] + + @property + def block_size(self) -> int: + raise NotImplementedError + + def __len__(self): + return len(self._readable_index) + + def _get(self, index: int): + raise NotImplementedError + + def safe_get(self, index: int): + with self.rw_lock.gen_rlock(): + return self._get(index) + + def _write(self, data: Any): + raise NotImplementedError + + def safe_put(self, data: Any): + with self.rw_lock.gen_wlock(): + self._write(data) diff --git a/malib/backend/dataset_server/service.py b/malib/backend/dataset_server/service.py new file mode 100644 index 00000000..cc00f5d6 --- /dev/null +++ b/malib/backend/dataset_server/service.py @@ -0,0 +1,22 @@ +import pickle +import traceback + +from . import data_pb2_grpc +from . import data_pb2 +from . import feature + + +class DatasetServer(data_pb2_grpc.SendDataServicer): + def __init__(self, feature_handler: feature.BaseFeature) -> None: + super().__init__() + self.feature_handler = feature_handler + + def collect(self, request, context): + try: + data = pickle.loads(request.data) + self.feature_handler.safe_put(data) + message = "success" + except Exception as e: + message = traceback.format_exc() + + return data_pb2.Reply(message=message) diff --git a/malib/backend/dataset_server/utils.py b/malib/backend/dataset_server/utils.py new file mode 100644 index 00000000..2920e9be --- /dev/null +++ b/malib/backend/dataset_server/utils.py @@ -0,0 +1,12 @@ +import grpc + +from . import data_pb2 +from . import data_pb2_grpc + + +def send_data(host: str, port: int, data: bytes): + with grpc.insecure_channel(f"{host}:{port}") as channel: + stub = data_pb2_grpc.SendDataStub(channel) + reply = stub.collect(data_pb2.Data(data=data)) + + return reply.message diff --git a/malib/backend/protos/data.proto b/malib/backend/protos/data.proto new file mode 100644 index 00000000..ff32d444 --- /dev/null +++ b/malib/backend/protos/data.proto @@ -0,0 +1,16 @@ +syntax = "proto3"; + +package data; + +message Data { + bytes data = 1; + string message = 2; +} + +message Reply { + string message = 1; +} + +service SendData { + rpc collect (Data) returns (Reply) {} +} diff --git a/setup.py b/setup.py index cd0f3190..f46f707f 100644 --- a/setup.py +++ b/setup.py @@ -65,6 +65,7 @@ "pettingzoo", "networkx", "zerocopy>=0.1.0", + "grpcio-tools>=1.59.0", ], extras_require={ "dev": [ From 19ed9ab1123536bdcda3095056d898db111b18ba Mon Sep 17 00:00:00 2001 From: Ming Zhou Date: Tue, 24 Oct 2023 20:30:48 +0800 Subject: [PATCH 11/24] tmp save --- malib/backend/dataset_server/data_loader.py | 12 ++-- malib/backend/dataset_server/data_pb2.py | 2 +- malib/backend/dataset_server/data_pb2_grpc.py | 14 ++--- malib/backend/dataset_server/service.py | 12 +++- malib/backend/dataset_server/utils.py | 10 ++- malib/backend/protos/data.proto | 2 +- malib/rollout/inference/model_server.py | 62 +++++++++++++++++++ 7 files changed, 96 insertions(+), 18 deletions(-) create mode 100644 malib/rollout/inference/model_server.py diff --git a/malib/backend/dataset_server/data_loader.py b/malib/backend/dataset_server/data_loader.py index 72b23651..f151190d 100644 --- a/malib/backend/dataset_server/data_loader.py +++ b/malib/backend/dataset_server/data_loader.py @@ -1,5 +1,6 @@ from typing import Type, Any +import threading import grpc from concurrent import futures @@ -26,14 +27,12 @@ def __init__( super().__init__() # start a service as thread - self.thread_pool = futures.ThreadPoolExecutor(max_workers=2) - self.thread_pool.submit( - self._start_servicer, + self.feature_handler: BaseFeature = feature_handler_caller() + self.server = self._start_servicer( grpc_thread_num_workers, max_message_length, find_free_port(), ) - self.feature_handler: BaseFeature = feature_handler_caller() def _start_servicer( self, max_workers: int, max_message_length: int, grpc_port: int @@ -51,6 +50,8 @@ def _start_servicer( server.add_insecure_port(f"[::]:{grpc_port}") server.start() + return server + def __len__(self): return self.feature_handler_caller.block_size @@ -62,3 +63,6 @@ def __getitem__(self, index) -> Any: raise EmptyError(f"No available data for sampling") return self.feature_handler.safe_get(index) + + def close(self): + self.server.stop() diff --git a/malib/backend/dataset_server/data_pb2.py b/malib/backend/dataset_server/data_pb2.py index 68bac728..22cf1c3c 100644 --- a/malib/backend/dataset_server/data_pb2.py +++ b/malib/backend/dataset_server/data_pb2.py @@ -13,7 +13,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\ndata.proto\x12\x04\x64\x61ta"%\n\x04\x44\x61ta\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\x12\x0f\n\x07message\x18\x02 \x01(\t"\x18\n\x05Reply\x12\x0f\n\x07message\x18\x01 \x01(\t20\n\x08SendData\x12$\n\x07\x63ollect\x12\n.data.Data\x1a\x0b.data.Reply"\x00\x62\x06proto3' + b'\n\ndata.proto\x12\x04\x64\x61ta"%\n\x04\x44\x61ta\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\x12\x0f\n\x07message\x18\x02 \x01(\t"\x18\n\x05Reply\x12\x0f\n\x07message\x18\x01 \x01(\t20\n\x08SendData\x12$\n\x07\x43ollect\x12\n.data.Data\x1a\x0b.data.Reply"\x00\x62\x06proto3' ) _globals = globals() diff --git a/malib/backend/dataset_server/data_pb2_grpc.py b/malib/backend/dataset_server/data_pb2_grpc.py index d5e49398..1a3cad02 100644 --- a/malib/backend/dataset_server/data_pb2_grpc.py +++ b/malib/backend/dataset_server/data_pb2_grpc.py @@ -14,8 +14,8 @@ def __init__(self, channel): Args: channel: A grpc.Channel. """ - self.collect = channel.unary_unary( - "/data.SendData/collect", + self.Collect = channel.unary_unary( + "/data.SendData/Collect", request_serializer=data__pb2.Data.SerializeToString, response_deserializer=data__pb2.Reply.FromString, ) @@ -24,7 +24,7 @@ def __init__(self, channel): class SendDataServicer(object): """Missing associated documentation comment in .proto file.""" - def collect(self, request, context): + def Collect(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details("Method not implemented!") @@ -33,8 +33,8 @@ def collect(self, request, context): def add_SendDataServicer_to_server(servicer, server): rpc_method_handlers = { - "collect": grpc.unary_unary_rpc_method_handler( - servicer.collect, + "Collect": grpc.unary_unary_rpc_method_handler( + servicer.Collect, request_deserializer=data__pb2.Data.FromString, response_serializer=data__pb2.Reply.SerializeToString, ), @@ -50,7 +50,7 @@ class SendData(object): """Missing associated documentation comment in .proto file.""" @staticmethod - def collect( + def Collect( request, target, options=(), @@ -65,7 +65,7 @@ def collect( return grpc.experimental.unary_unary( request, target, - "/data.SendData/collect", + "/data.SendData/Collect", data__pb2.Data.SerializeToString, data__pb2.Reply.FromString, options, diff --git a/malib/backend/dataset_server/service.py b/malib/backend/dataset_server/service.py index cc00f5d6..d2af656c 100644 --- a/malib/backend/dataset_server/service.py +++ b/malib/backend/dataset_server/service.py @@ -1,5 +1,6 @@ -import pickle +import threading import traceback +import pickle from . import data_pb2_grpc from . import data_pb2 @@ -7,11 +8,16 @@ class DatasetServer(data_pb2_grpc.SendDataServicer): - def __init__(self, feature_handler: feature.BaseFeature) -> None: + def __init__( + self, + feature_handler: feature.BaseFeature, + service_event: threading.Event = None, + ) -> None: super().__init__() self.feature_handler = feature_handler + self.service_event = service_event - def collect(self, request, context): + def Collect(self, request, context): try: data = pickle.loads(request.data) self.feature_handler.safe_put(data) diff --git a/malib/backend/dataset_server/utils.py b/malib/backend/dataset_server/utils.py index 2920e9be..190d2360 100644 --- a/malib/backend/dataset_server/utils.py +++ b/malib/backend/dataset_server/utils.py @@ -1,12 +1,18 @@ +from typing import Any, Union + +import pickle import grpc from . import data_pb2 from . import data_pb2_grpc -def send_data(host: str, port: int, data: bytes): +def send_data(host: str, port: int, data: Any): + if not isinstance(data, bytes): + data = pickle.dumps(data) + with grpc.insecure_channel(f"{host}:{port}") as channel: stub = data_pb2_grpc.SendDataStub(channel) - reply = stub.collect(data_pb2.Data(data=data)) + reply = stub.Collect(data_pb2.Data(data=data)) return reply.message diff --git a/malib/backend/protos/data.proto b/malib/backend/protos/data.proto index ff32d444..7ac5e143 100644 --- a/malib/backend/protos/data.proto +++ b/malib/backend/protos/data.proto @@ -12,5 +12,5 @@ message Reply { } service SendData { - rpc collect (Data) returns (Reply) {} + rpc Collect (Data) returns (Reply) {} } diff --git a/malib/rollout/inference/model_server.py b/malib/rollout/inference/model_server.py new file mode 100644 index 00000000..f5fbb274 --- /dev/null +++ b/malib/rollout/inference/model_server.py @@ -0,0 +1,62 @@ +from typing import Dict, Any +from concurrent import futures + +import threading + +from readerwriterlock import rwlock +from torch import nn + +import torch +import ray + + +def load_state_dict(client, timeout=10): + if isinstance(client, ray.ObjectRef): + return ray.get(client.get_state_dict.remote(), timeout=10) + else: + raise NotImplementedError + + +class ModelClient: + def __init__( + self, entry_point: str, model_cls: nn.Module, model_args: Dict[str, Any] + ): + # TODO(ming): init server from entry point + cluster_type, name_or_address = entry_point.split(":") + + if "ray" in cluster_type: + self.client = ray.get_actor(name_or_address) + else: + raise NotImplementedError + + self.cluster_type = cluster_type + self.server_address = name_or_address + self.thread_pool = futures.ThreadPoolExecutor(max_workers=10) + + self.event = threading.Event() + self.thread_pool.submit(self._model_update, self.event) + self.model: nn.Module = model_cls(**model_args).cpu() + self.model.share_memory() + + def __call__(self, *args: Any, **kwds: Any) -> Any: + with torch.inference_mode(): + return self.model(*args, **kwds) + + def _model_update(self, event: threading.Event): + while not event.is_set(): + # TODO(ming): update model from remote server + try: + state_dict = load_state_dict(self.client) + + event.wait(0.5) + except TimeoutError: + # TODO(ming): count or reconnect + event.wait(1) + except RuntimeError: + pass + except KeyboardInterrupt: + break + + def shutdown(self): + self.event.set() + self.thread_pool.shutdown() From c01e4634036a31c5326d347e87b26a1aa9cc5b7c Mon Sep 17 00:00:00 2001 From: Ming Zhou Date: Fri, 27 Oct 2023 19:39:12 +0800 Subject: [PATCH 12/24] tmp save --- malib/agent/manager.py | 24 +- malib/backend/dataset_server/__init__.py | 0 malib/backend/dataset_server/utils.py | 17 +- malib/common/task.py | 6 +- malib/rl/common/policy.py | 16 + malib/rl/pg/policy.py | 23 +- malib/rollout/__init__.py | 6 +- .../inference/{ray/server.py => client.py} | 42 +- .../{ray/client.py => env_runner.py} | 88 ++-- .../{model_server.py => model_client.py} | 9 +- malib/rollout/inference/ray/__init__.py | 26 -- malib/rollout/inference/utils.py | 6 +- malib/rollout/manager.py | 54 +-- malib/rollout/pb_rolloutworker.py | 40 +- malib/rollout/rolloutworker.py | 260 ++---------- malib/scenarios/sarl_scenario.py | 13 +- malib/scenarios/scenario.py | 56 ++- tests/rollout/test_env_runner.py | 42 ++ tests/rollout/test_ray_inference.py | 397 ------------------ 19 files changed, 257 insertions(+), 868 deletions(-) delete mode 100644 malib/backend/dataset_server/__init__.py rename malib/rollout/inference/{ray/server.py => client.py} (78%) rename malib/rollout/inference/{ray/client.py => env_runner.py} (78%) rename malib/rollout/inference/{model_server.py => model_client.py} (88%) delete mode 100644 malib/rollout/inference/ray/__init__.py create mode 100644 tests/rollout/test_env_runner.py delete mode 100644 tests/rollout/test_ray_inference.py diff --git a/malib/agent/manager.py b/malib/agent/manager.py index 4fd211c3..7eca9f3a 100644 --- a/malib/agent/manager.py +++ b/malib/agent/manager.py @@ -57,11 +57,6 @@ ) -def validate_spaces(agent_groups: Dict[str, Set[AgentID]], env_desc: Dict[str, Any]): - # TODO(ming): check whether the agents in the group share the same observation space and action space - raise NotImplementedError - - class TrainingManager(Manager): def __init__( self, @@ -70,6 +65,7 @@ def __init__( algorithms: Dict[str, Any], env_desc: Dict[str, Any], agent_mapping_func: Callable[[AgentID], str], + group_info: Dict[str, Any], training_config: Union[Dict[str, Any], TrainingConfig], log_dir: str, remote_mode: bool = True, @@ -98,16 +94,10 @@ def __init__( training_config = TrainingConfig.from_raw(training_config) # interface config give the agent type used here and the group mapping if needed - agent_groups = defaultdict(lambda: set()) - for agent in env_desc["possible_agents"]: - rid = agent_mapping_func(agent) - agent_groups[rid].add(agent) - - validate_spaces(agent_groups, env_desc) # FIXME(ming): resource configuration is not available now, will open in the next version if training_config.trainer_config.get("use_cuda", False): - num_gpus = 1 / len(agent_groups) + num_gpus = 1 / len(group_info["agent_groups"]) else: num_gpus = 0.0 if not os.path.exists(log_dir): @@ -125,7 +115,7 @@ def __init__( "training" in stopping_conditions ), f"Stopping conditions should contains `training` stoppong conditions: {stopping_conditions}" - for rid, agents in agent_groups.items(): + for rid, agents in group_info["agent_groups"].items(): _cls = learner_cls.remote if remote_mode else learner_cls learners[rid] = _cls( experiment_tag=experiment_tag, @@ -145,8 +135,7 @@ def __init__( _ = ray.get([x.connect.remote() for x in learners.values()]) # TODO(ming): collect data entrypoints from learners - - self._agent_groups = agent_groups + self._group_info = group_info self._runtime_ids = tuple(self._agent_groups.keys()) self._experiment_tag = experiment_tag self._env_description = env_desc @@ -170,7 +159,7 @@ def agent_groups(self) -> Dict[str, Set[AgentID]]: Dict[str, Set[AgentID]]: A dict of agent set. """ - return self._agent_groups + return self._group_info["agent_groups"] @property def get_data_entrypoints(self) -> Dict[str, str]: @@ -202,6 +191,9 @@ def runtime_ids(self) -> Tuple[str]: return self._runtime_ids + def get_data_entrypoint_mapping(self) -> Dict[AgentID, str]: + raise NotImplementedError + def add_policies( self, interface_ids: Sequence[str] = None, n: Union[int, Dict[str, int]] = 1 ) -> Dict[str, Type[StrategySpec]]: diff --git a/malib/backend/dataset_server/__init__.py b/malib/backend/dataset_server/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/malib/backend/dataset_server/utils.py b/malib/backend/dataset_server/utils.py index 190d2360..a78e6517 100644 --- a/malib/backend/dataset_server/utils.py +++ b/malib/backend/dataset_server/utils.py @@ -2,17 +2,26 @@ import pickle import grpc +import sys +import os + +sys.path.append(os.path.dirname(__file__)) from . import data_pb2 from . import data_pb2_grpc -def send_data(host: str, port: int, data: Any): +def send_data(data: Any, host: str = None, port: int = None, entrypoint: str = None): if not isinstance(data, bytes): data = pickle.dumps(data) - with grpc.insecure_channel(f"{host}:{port}") as channel: - stub = data_pb2_grpc.SendDataStub(channel) - reply = stub.Collect(data_pb2.Data(data=data)) + if host is not None: + with grpc.insecure_channel(f"{host}:{port}") as channel: + stub = data_pb2_grpc.SendDataStub(channel) + reply = stub.Collect(data_pb2.Data(data=data)) + else: + with grpc.insecure_channel(entrypoint) as channel: + stub = data_pb2_grpc.SendDataStub(channel) + reply = stub.Collect(data_pb2.Data(data=data)) return reply.message diff --git a/malib/common/task.py b/malib/common/task.py index 007b6fb2..e59df234 100644 --- a/malib/common/task.py +++ b/malib/common/task.py @@ -15,8 +15,9 @@ class TaskType(IntEnum): @dataclass class RolloutTask: task_type: int - active_agents: List[AgentID] strategy_specs: Dict[str, Any] = field(default_factory=dict()) + stopping_conditions: Dict[str, Any] = field(default_factory=dict()) + data_entrypoint_mapping: Dict[str, Any] = field(default_factory=dict()) @classmethod def from_raw( @@ -32,9 +33,6 @@ def from_raw( @dataclass class OptimizationTask: - data_entrypoints: Dict[str, str] - """a mapping defines the data request identifier and the data entrypoint.""" - stop_conditions: Dict[str, Any] """stopping conditions for optimization task, e.g., max iteration, max time, etc.""" diff --git a/malib/rl/common/policy.py b/malib/rl/common/policy.py index 0f673f9a..f3525308 100644 --- a/malib/rl/common/policy.py +++ b/malib/rl/common/policy.py @@ -28,6 +28,7 @@ import torch import torch.nn as nn +import gym from gym import spaces @@ -100,6 +101,21 @@ def __init__( use_sde=custom_config.get("use_sde", False), dist_kwargs=custom_config.get("dist_kwargs", None), ) + if kwargs.get("model_client"): + self.model = kwargs["model_client"] + else: + self.model = self.create_model() + + def create_model(self): + raise NotImplementedError + + @property + def action_space(self) -> gym.Space: + return self._action_space + + @property + def observation_space(self) -> gym.Space: + return self._observation_space @property def model_config(self): diff --git a/malib/rl/pg/policy.py b/malib/rl/pg/policy.py index 583eccbd..8903ff73 100644 --- a/malib/rl/pg/policy.py +++ b/malib/rl/pg/policy.py @@ -70,36 +70,39 @@ def __init__( observation_space, action_space, model_config, custom_config, **kwargs ) + def create_model(self): # update model preprocess_net config here action_shape = ( - (action_space.n,) if len(action_space.shape) == 0 else action_space.shape + (self.action_space.n,) + if len(self.action_space.shape) == 0 + else self.action_space.shape ) preprocess_net: nn.Module = net.make_net( - observation_space, + self.observation_space, self.device, - model_config["preprocess_net"].get("net_type", None), - **model_config["preprocess_net"]["config"] + self.model_config["preprocess_net"].get("net_type", None), + **self.model_config["preprocess_net"]["config"] ) - if isinstance(action_space, spaces.Discrete): + if isinstance(self.action_space, spaces.Discrete): self.actor = discrete.Actor( preprocess_net=preprocess_net, action_shape=action_shape, - hidden_sizes=model_config["hidden_sizes"], + hidden_sizes=self.model_config["hidden_sizes"], softmax_output=False, device=self.device, ) - elif isinstance(action_space, spaces.Box): + elif isinstance(self.action_space, spaces.Box): self.actor = continuous.Actor( preprocess_net=preprocess_net, action_shape=action_shape, - hidden_sizes=model_config["hidden_sizes"], - max_action=custom_config.get("max_action", 1.0), + hidden_sizes=self.model_config["hidden_sizes"], + max_action=self.custom_config.get("max_action", 1.0), device=self.device, ) else: raise TypeError( - "Unexpected action space type: {}".format(type(action_space)) + "Unexpected action space type: {}".format(type(self.action_space)) ) self.register_state(self.actor, "actor") diff --git a/malib/rollout/__init__.py b/malib/rollout/__init__.py index 8d802638..0e02ff91 100644 --- a/malib/rollout/__init__.py +++ b/malib/rollout/__init__.py @@ -23,8 +23,8 @@ # SOFTWARE. from .pb_rolloutworker import RolloutWorker -from .inference.ray.client import RayInferenceClient as InferenceClient -from .inference.ray.server import RayInferenceWorkerSet as InferenceWorkerSet +from .inference.env_runner import EnvRunner +from .inference.client import InferenceClient -__all__ = ["RolloutWorker", "InferenceClient", "InferenceWorkerSet"] +__all__ = ["RolloutWorker", "EnvRunner", "InferenceClient"] diff --git a/malib/rollout/inference/ray/server.py b/malib/rollout/inference/client.py similarity index 78% rename from malib/rollout/inference/ray/server.py rename to malib/rollout/inference/client.py index bc9a2487..fa28c7f5 100644 --- a/malib/rollout/inference/ray/server.py +++ b/malib/rollout/inference/client.py @@ -31,9 +31,7 @@ import os import pickle as pkl -import ray import gym -import torch from malib import settings from malib.remote.interface import RemoteInterface @@ -42,20 +40,17 @@ from malib.utils.episode import Episode from malib.common.strategy_spec import StrategySpec from malib.rl.common.policy import Policy -from malib.backend.parameter_server import ParameterServer -ClientHandler = namedtuple("ClientHandler", "sender,recver,runtime_config,rnn_states") +Connection = namedtuple("Connection", "sender,recver,runtime_config,rnn_states") -class RayInferenceWorkerSet(RemoteInterface): +class InferenceClient(RemoteInterface): def __init__( self, agent_id: AgentID, observation_space: gym.Space, action_space: gym.Space, - parameter_server: ParameterServer, - governed_agents: List[AgentID], ) -> None: """Create ray-based inference server. @@ -63,26 +58,22 @@ def __init__( agent_id (AgentID): Runtime agent id, not environment agent id. observation_space (gym.Space): Observation space related to the governed environment agents. action_space (gym.Space): Action space related to the governed environment agents. - parameter_server (ParameterServer): Parameter server. - governed_agents (List[AgentID]): A list of environment agents. """ self.runtime_agent_id = agent_id self.observation_space = observation_space self.action_space = action_space - self.parameter_server = parameter_server self.thread_pool = ThreadPoolExecutor() - self.governed_agents = governed_agents self.policies: Dict[str, Policy] = {} self.strategy_spec_dict: Dict[str, StrategySpec] = {} def shutdown(self): self.thread_pool.shutdown(wait=True) - for _handler in self.clients.values(): + for _handler in self.connections.values(): _handler.sender.shutdown(True) _handler.recver.shutdown(True) - self.clients: Dict[int, ClientHandler] = {} + self.connections: Dict[int, Connection] = {} def save(self, model_dir: str) -> None: if not os.path.exists(model_dir): @@ -100,12 +91,6 @@ def compute_action( strategy_specs: Dict[AgentID, StrategySpec] = runtime_config["strategy_specs"] return_dataframes: List[DataFrame] = [] - # check policy - self._update_policies( - runtime_config["strategy_specs"][self.runtime_agent_id], - self.runtime_agent_id, - ) - assert len(dataframes) > 0 for dataframe in dataframes: @@ -132,16 +117,6 @@ def compute_action( rets = {} - with timer.time_avg("policy_update"): - info = ray.get( - self.parameter_server.get_weights.remote( - spec_id=spec.id, - spec_policy_id=spec_policy_id, - ) - ) - if info["weights"] is not None: - self.policies[policy_id].load_state_dict(info["weights"]) - with timer.time_avg("compute_action"): ( rets[Episode.ACTION], @@ -170,19 +145,12 @@ def compute_action( continue else: rets[k] = v.reshape(batch_size, -1) + return_dataframes.append( DataFrame(identifier=agent_id, data=rets, meta_data=dataframe.meta_data) ) - # print(f"timer information: {timer.todict()}") return return_dataframes - def _update_policies(self, strategy_spec: StrategySpec, agent_id: AgentID): - for strategy_spec_pid in strategy_spec.policy_ids: - policy_id = f"{strategy_spec.id}/{strategy_spec_pid}" - if policy_id not in self.policies: - policy = strategy_spec.gen_policy(device="cpu") - self.policies[policy_id] = policy - def _get_initial_states(self, client_id, observation, policy: Policy, identifier): if ( diff --git a/malib/rollout/inference/ray/client.py b/malib/rollout/inference/env_runner.py similarity index 78% rename from malib/rollout/inference/ray/client.py rename to malib/rollout/inference/env_runner.py index aa143498..1b23fa02 100644 --- a/malib/rollout/inference/ray/client.py +++ b/malib/rollout/inference/env_runner.py @@ -23,7 +23,7 @@ # SOFTWARE. from argparse import Namespace -from typing import Any, List, Dict, Tuple +from typing import Any, List, Dict, Tuple, Set from types import LambdaType from collections import defaultdict @@ -31,36 +31,38 @@ import time import traceback +import pickle import ray -from ray.util.queue import Queue from ray.actor import ActorHandle -from malib.utils.logging import Logger - from malib.utils.typing import AgentID, DataFrame, BehaviorMode -from malib.utils.episode import Episode, NewEpisodeDict, NewEpisodeList +from malib.utils.episode import NewEpisodeList from malib.utils.preprocessor import Preprocessor, get_preprocessor from malib.utils.timing import Timing from malib.remote.interface import RemoteInterface from malib.rollout.envs.vector_env import VectorEnv, SubprocVecEnv -from malib.rollout.inference.ray.server import RayInferenceWorkerSet +from malib.rollout.inference.client import InferenceClient from malib.rollout.inference.utils import process_env_rets, process_policy_outputs +from malib.backend.dataset_server.utils import send_data + +class EnvRunner(RemoteInterface): + def __repr__(self) -> str: + return f"" -class RayInferenceClient(RemoteInterface): def __init__( self, env_desc: Dict[str, Any], - dataset_server: ray.ObjectRef, max_env_num: int, + agent_groups: Dict[str, Set], use_subproc_env: bool = False, batch_mode: str = "time_step", postprocessor_types: Dict = None, training_agent_mapping: LambdaType = None, custom_config: Dict[str, Any] = {}, ): - """Construct an inference client. + """Construct an inference client, one for each agent. Args: env_desc (Dict[str, Any]): Environment description @@ -73,7 +75,6 @@ def __init__( custom_config (Dict[str, Any], optional): Custom configuration. Defaults to an empty dict. """ - self.dataset_server = dataset_server self.use_subproc_env = use_subproc_env self.batch_mode = batch_mode self.postprocessor_types = postprocessor_types or ["defaults"] @@ -82,15 +83,8 @@ def __init__( self.training_agent_mapping = training_agent_mapping or (lambda agent: agent) self.max_env_num = max_env_num self.custom_configs = custom_config - - agent_group = defaultdict(lambda: []) - runtime_agent_ids = [] - for agent in env_desc["possible_agents"]: - runtime_id = training_agent_mapping(agent) - agent_group[runtime_id].append(agent) - runtime_agent_ids.append(runtime_id) - self.runtime_agent_ids = set(runtime_agent_ids) - self.agent_group = dict(agent_group) + self.runtime_agent_ids = list(agent_groups.keys()) + self.agent_groups = agent_groups obs_spaces = env_desc["observation_spaces"] act_spaces = env_desc["action_spaces"] @@ -121,9 +115,9 @@ def close(self): def run( self, - agent_interfaces: Dict[AgentID, RayInferenceWorkerSet], + inference_clients: Dict[AgentID, InferenceClient], rollout_config: Dict[str, Any], - dataset_writer_info_dict: Dict[str, Tuple[str, Queue]] = None, + data_entrypoint_mapping: Dict[AgentID, str] = None, ) -> Dict[str, Any]: """Executes environment runner to collect training data or run purely simulation/evaluation. @@ -131,7 +125,7 @@ def run( Only simulation/evaluation tasks return evaluation information. Args: - agent_interfaces (Dict[AgentID, InferenceWorkerSet]): A dict of agent interface servers. + inference_clients (Dict[AgentID, InferenceClient]): A dict of agent interface servers. rollout_config (Dict[str, Any]): Rollout configuration. dataset_writer_info_dict (Dict[str, Tuple[str, Queue]], optional): Dataset writer info dict. Defaults to None. @@ -148,20 +142,13 @@ def run( "strategy_specs": rollout_config["strategy_specs"], } - if task_type == "rollout": - assert ( - dataset_writer_info_dict is not None - ), "rollout task has no available dataset writer" - server_runtime_config["behavior_mode"] = BehaviorMode.EXPLORATION - elif task_type in ["evaluation", "simulation"]: - server_runtime_config["behavior_mode"] = BehaviorMode.EXPLOITATION - eval_results, performance = env_runner( self, - agent_interfaces, + inference_clients, + self.preprocessor, rollout_config, server_runtime_config, - dwriter_info_dict=dataset_writer_info_dict, + data_entrypoint_mapping, ) res = performance.copy() @@ -171,11 +158,12 @@ def run( def env_runner( - client: RayInferenceClient, - servers: Dict[str, RayInferenceWorkerSet], + client: InferenceClient, + agents: Dict[str, InferenceClient], + preprocessors: Dict[str, Preprocessor], rollout_config: Dict[str, Any], server_runtime_config: Dict[str, Any], - dwriter_info_dict: Dict[str, Tuple[str, Queue]] = None, + data_entrypoint_mapping: Dict[AgentID, str], ) -> Tuple[List[Dict[str, Any]], Dict[str, float]]: """The main logic of environment stepping, also for data collections. @@ -198,11 +186,11 @@ def env_runner( """ # check whether remote server or not - evaluate_on = server_runtime_config["behavior_mode"] == BehaviorMode.EXPLOITATION - remote_actor = isinstance(list(servers.values())[0], ActorHandle) + evaluate_on = rollout_config["behavior_mode"] == BehaviorMode.EXPLOITATION + remote_actor = isinstance(list(agents.values())[0], ActorHandle) try: - if dwriter_info_dict is not None: + if data_entrypoint_mapping is not None: episodes = NewEpisodeList( num=client.env.num_envs, agents=client.env.possible_agents ) @@ -217,9 +205,10 @@ def env_runner( env_dones, processed_env_ret, dataframes = process_env_rets( env_rets=env_rets, - preprocessor=server_runtime_config["preprocessor"], + preprocessors=preprocessors, preset_meta_data={"evaluate": evaluate_on}, ) + # env ret is key first, not agent first: state, obs if episodes is not None: episodes.record( @@ -240,20 +229,20 @@ def env_runner( if remote_actor: policy_outputs: Dict[str, List[DataFrame]] = { rid: ray.get( - server.compute_action.remote( + agent.compute_action.remote( grouped_data_frames[rid], runtime_config=server_runtime_config, ) ) - for rid, server in servers.items() + for rid, agent in agents.items() } else: policy_outputs: Dict[str, List[DataFrame]] = { - rid: server.compute_action( + rid: agent.compute_action( grouped_data_frames[rid], runtime_config=server_runtime_config, ) - for rid, server in servers.items() + for rid, agent in agents.items() } with client.timer.time_avg("process_policy_output"): @@ -284,18 +273,11 @@ def env_runner( cnt += 1 - if dwriter_info_dict is not None: + if data_entrypoint_mapping is not None: # episode_id: agent_id: dict_data episodes = episodes.to_numpy() - for rid, writer_info in dwriter_info_dict.items(): - # get agents from agent group - agents = client.agent_group[rid] - batches = [] - # FIXME(ming): multi-agent is wrong! - for episode in episodes: - agent_buffer = [episode[aid] for aid in agents] - batches.append(agent_buffer) - writer_info[-1].put_nowait_batch(batches) + for entrypoint in data_entrypoint_mapping.values(): + send_data(pickle.dumps(episodes), entrypoint) end = time.time() rollout_info = client.env.collect_info() except Exception as e: diff --git a/malib/rollout/inference/model_server.py b/malib/rollout/inference/model_client.py similarity index 88% rename from malib/rollout/inference/model_server.py rename to malib/rollout/inference/model_client.py index f5fbb274..5391e0a1 100644 --- a/malib/rollout/inference/model_server.py +++ b/malib/rollout/inference/model_client.py @@ -9,6 +9,8 @@ import torch import ray +from malib.utils.typing import AgentID, DataFrame + def load_state_dict(client, timeout=10): if isinstance(client, ray.ObjectRef): @@ -21,7 +23,6 @@ class ModelClient: def __init__( self, entry_point: str, model_cls: nn.Module, model_args: Dict[str, Any] ): - # TODO(ming): init server from entry point cluster_type, name_or_address = entry_point.split(":") if "ray" in cluster_type: @@ -42,6 +43,12 @@ def __call__(self, *args: Any, **kwds: Any) -> Any: with torch.inference_mode(): return self.model(*args, **kwds) + def actor(self, *args, **kwargs): + return self.model.actor(*args, **kwargs) + + def critic(self, *args, **kwargs): + return self.model.critic(*args, **kwargs) + def _model_update(self, event: threading.Event): while not event.is_set(): # TODO(ming): update model from remote server diff --git a/malib/rollout/inference/ray/__init__.py b/malib/rollout/inference/ray/__init__.py deleted file mode 100644 index 26ac2ced..00000000 --- a/malib/rollout/inference/ray/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -# MIT License - -# Copyright (c) 2021 MARL @ SJTU - -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. - -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -from .client import RayInferenceClient -from .server import RayInferenceWorkerSet - -__all__ = ["RayInferenceClient", "RayInferenceWorkerSet"] diff --git a/malib/rollout/inference/utils.py b/malib/rollout/inference/utils.py index 588f0538..c6aa1ed3 100644 --- a/malib/rollout/inference/utils.py +++ b/malib/rollout/inference/utils.py @@ -37,7 +37,7 @@ def process_env_rets( env_rets: List[Tuple["states", "observations", "rewards", "dones", "infos"]], - preprocessor: Dict[AgentID, Preprocessor], + preprocessors: Dict[AgentID, Preprocessor], preset_meta_data: Dict[str, Any], ): """Process environment returns, generally, for the observation transformation. @@ -70,7 +70,7 @@ def process_env_rets( agents = list(ret[1].keys()) processed_obs = { - agent: preprocessor[agent].transform(raw_obs) + agent: preprocessors[agent].transform(raw_obs) for agent, raw_obs in ret[1].items() } @@ -85,7 +85,7 @@ def process_env_rets( agent_state_list[agent].append(_state) env_rets_to_save[Episode.CUR_STATE] = ret[0] - original_obs_space = list(preprocessor.values())[0].original_space + original_obs_space = list(preprocessors.values())[0].original_space if ( isinstance(original_obs_space, spaces.Dict) and "action_mask" in original_obs_space diff --git a/malib/rollout/manager.py b/malib/rollout/manager.py index e6b39471..1a4ddbd7 100644 --- a/malib/rollout/manager.py +++ b/malib/rollout/manager.py @@ -27,7 +27,6 @@ """ from typing import Dict, Tuple, Any, Callable, Set, List, Union -from collections import defaultdict import traceback import ray @@ -36,7 +35,7 @@ from ray.util import ActorPool from malib.utils.logging import Logger -from malib.common.task import TaskType, RolloutTask +from malib.common.task import RolloutTask from malib.common.manager import Manager from malib.remote.interface import RemoteInterface from malib.common.strategy_spec import StrategySpec @@ -79,6 +78,7 @@ def __init__( stopping_conditions: Dict[str, Any], num_worker: int, agent_mapping_func: Callable, + group_info: Dict[str, Any], rollout_config: Dict[str, Any], env_desc: Dict[str, Any], log_dir: str, @@ -110,6 +110,7 @@ def __init__( experiment_tag=experiment_tag, env_desc=env_desc, agent_mapping_func=agent_mapping_func, + agent_groups=group_info["agent_groups"], rollout_config=rollout_config, log_dir=log_dir, rollout_callback=None, @@ -121,14 +122,8 @@ def __init__( self._workers: List[ray.actor] = workers self._actor_pool = ActorPool(self._workers) - - agent_groups = defaultdict(lambda: set()) - for agent in env_desc["possible_agents"]: - rid = agent_mapping_func(agent) - agent_groups[rid].add(agent) - - self._runtime_ids = tuple(agent_groups.keys()) - self._agent_groups = dict(agent_groups) + self._runtime_ids = tuple(group_info["agent_groups"].keys()) + self._group_info = group_info self.experiment_tag = experiment_tag assert ( @@ -155,7 +150,7 @@ def agent_groups(self) -> Dict[str, Set]: Dict[str, Set]: A dict of set. """ - return self._agent_groups + return self._group_info["agent_groups"] @property def workers(self) -> List[RemoteInterface]: @@ -167,7 +162,7 @@ def workers(self) -> List[RemoteInterface]: return self._workers - def submit(self, task: Union[Dict[str, Any], List[Dict[str, Any]]], task_type: Any): + def submit(self, task: Union[Dict[str, Any], List[Dict[str, Any]]]): """Submit a task to workers Args: @@ -182,40 +177,7 @@ def submit(self, task: Union[Dict[str, Any], List[Dict[str, Any]]], task_type: A for _task in task: validate_strategy_specs(_task.strategy_specs) - self._actor_pool.submit( - lambda actor, _task: actor.rollout.remote(_task, stopping_conditions) - ) - - def _rollout(self, task_list: List[Dict[str, Any]]) -> None: - """Start rollout task without blocking. - - Args: - task_list (List[Dict[str, Any]]): A list of task dict, keys include: - - `strategy_specs`: a dict of strategy specs, mapping from runtime ids to specs. - - `trainable_agents`: a list of trainable agents. - - """ - - # validate all strategy specs here - for task in task_list: - validate_strategy_specs(task["strategy_specs"]) - - while self._actor_pool.has_next(): - try: - self._actor_pool.get_next(timeout=0) - except TimeoutError: - pass - - for task in task_list: - self._actor_pool.submit( - lambda actor, task: actor.rollout.remote( - runtime_strategy_specs=task["strategy_specs"], - stopping_conditions=self.stopping_conditions["rollout"], - trainable_agents=task["trainable_agents"], - data_entrypoints=task["data_entrypoints"], - ), - task, - ) + self._actor_pool.submit(lambda actor, _task: actor.rollout.remote(_task)) def retrive_results(self): """Retrieve task results diff --git a/malib/rollout/pb_rolloutworker.py b/malib/rollout/pb_rolloutworker.py index c37426ef..852dc68d 100644 --- a/malib/rollout/pb_rolloutworker.py +++ b/malib/rollout/pb_rolloutworker.py @@ -25,7 +25,6 @@ from typing import Dict, Any from malib.rollout.rolloutworker import RolloutWorker, parse_rollout_info -from malib.common.strategy_spec import StrategySpec from malib.utils.logging import Logger @@ -36,7 +35,7 @@ def step_rollout( self, eval_step: bool, rollout_config: Dict[str, Any], - dataset_writer_info_dict: Dict[str, Any], + data_entrypoint_mapping: Dict[str, Any], ): tasks = [rollout_config for _ in range(self.rollout_config["num_threads"])] @@ -53,11 +52,11 @@ def step_rollout( rets = [ x - for x in self.actor_pool.map( + for x in self.env_runner_pool.map( lambda a, task: a.run.remote( - agent_interfaces=self.agent_interfaces, + inference_clients=self.inference_clients, rollout_config=task, - dataset_writer_info_dict=dataset_writer_info_dict, + data_entrypoint_mapping=data_entrypoint_mapping, ), tasks, ) @@ -67,34 +66,3 @@ def step_rollout( parsed_results = parse_rollout_info(rets) Logger.debug(f"parsed results: {parsed_results}") return parsed_results - - def step_simulation( - self, - runtime_strategy_specs: Dict[str, StrategySpec], - runtime_config_template: Dict[str, Any], - ) -> Dict[str, Any]: - """Step simulation task with a given list of strategy spec dicts. - - Args: - runtime_strategy_specs (Dict[str, StrategySpec]): A strategy spec dicts. - runtime_config_template (Dict[str, Any]): Runtime configuration template. - - Returns: - Dict[str, Any]: Evaluation results, a dict. - """ - - task = runtime_config_template.copy() - task["strategy_specs"] = runtime_strategy_specs - - # we should keep dimension as tasks. - rets = [ - parse_rollout_info([x]) - for x in self.actor_pool.map( - lambda a, task: a.run.remote( - agent_interfaces=self.agent_interfaces, rollout_config=task - ), - [task], - ) - ][0] - - return rets diff --git a/malib/rollout/rolloutworker.py b/malib/rollout/rolloutworker.py index 8c406827..6daff6e9 100644 --- a/malib/rollout/rolloutworker.py +++ b/malib/rollout/rolloutworker.py @@ -22,7 +22,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -from typing import Dict, Any, List, Callable, Sequence, Tuple +from typing import Dict, Any, List, Callable, Sequence, Tuple, Set from abc import abstractmethod from collections import defaultdict @@ -37,21 +37,18 @@ import numpy as np from ray.util import ActorPool -from ray.util.queue import Queue from torch.utils import tensorboard from malib import settings -from malib.utils.typing import AgentID +from malib.utils.typing import AgentID, BehaviorMode from malib.utils.logging import Logger from malib.utils.stopping_conditions import get_stopper from malib.utils.monitor import write_to_tensorboard from malib.common.strategy_spec import StrategySpec -from malib.common.task import RolloutTask, TaskType +from malib.common.task import RolloutTask from malib.remote.interface import RemoteInterface -from malib.rollout.inference.ray.server import ( - RayInferenceWorkerSet as RayInferenceServer, -) -from malib.rollout.inference.ray.client import RayInferenceClient +from malib.rollout.inference.client import InferenceClient +from malib.rollout.inference.env_runner import EnvRunner PARAMETER_GET_TIMEOUT = 3 @@ -116,34 +113,6 @@ def log(message: str): logger.log(settings.LOG_LEVEL, f"(rollout worker) {message}") -def validate_agent_group( - agent_group: Dict[str, List[AgentID]], - full_keys: List[AgentID], - observation_spaces: Dict[AgentID, gym.Space], - action_spaces: Dict[AgentID, gym.Space], -) -> None: - """Validate agent group, check spaces. - - Args: - agent_group (Dict[str, List[AgentID]]): A dict, mapping from runtime ids to lists of agent ids. - full_keys (List[AgentID]): A list of original environment agent ids. - observation_spaces (Dict[AgentID, gym.Space]): Agent observation space dict. - action_spaces (Dict[AgentID, gym.Space]): Agent action space dict. - - Raises: - RuntimeError: Agents in a same group should share the same observation space and action space. - NotImplementedError: _description_ - """ - for agents in agent_group.values(): - select_obs_space = observation_spaces[agents[0]] - select_act_space = action_spaces[agents[0]] - for agent in agents[1:]: - assert type(select_obs_space) == type(observation_spaces[agent]) - assert select_obs_space.shape == observation_spaces[agent].shape - assert type(select_act_space) == type(action_spaces[agent]) - assert select_act_space.shape == action_spaces[agent].shape - - def default_rollout_callback(coordinator: ray.ObjectRef, results: Dict[str, Any]): pass @@ -180,6 +149,7 @@ def __init__( experiment_tag: str, env_desc: Dict[str, Any], agent_mapping_func: Callable, + agent_groups: Dict[str, Set], rollout_config: Dict[str, Any], log_dir: str, rollout_callback: Callable[[ray.ObjectRef, Dict[str, Any]], Any] = None, @@ -187,9 +157,7 @@ def __init__( resource_config: Dict[str, Any] = None, verbose: bool = True, ): - """Create a instance for simulations, rollout and evaluation. This base class initializes \ - all necessary servers and workers for rollouts. Including remote agent interfaces, \ - workers for simultaions. + """Construct a rollout worker, consuming rollout/evaluation tasks. Args: env_desc (Dict[str, Any]): The environment description. @@ -212,62 +180,29 @@ def __init__( self.worker_indentifier = f"rolloutworker_{os.getpid()}" # map agents - agent_group = defaultdict(lambda: []) - runtime_agent_ids = [] - for agent in env_desc["possible_agents"]: - runtime_id = agent_mapping_func(agent) - agent_group[runtime_id].append(agent) - runtime_agent_ids.append(runtime_id) - runtime_agent_ids = set(runtime_agent_ids) - agent_group = dict(agent_group) resource_config = resource_config or DEFAULT_RESOURCE_CONFIG - # valid agent group - validate_agent_group( - agent_group=agent_group, - full_keys=env_desc["possible_agents"], - observation_spaces=env_desc["observation_spaces"], - action_spaces=env_desc["action_spaces"], - ) - self.env_description = env_desc self.env_agents = env_desc["possible_agents"] - self.runtime_agent_ids = runtime_agent_ids - self.agent_group = agent_group + self.runtime_agent_ids = list(agent_groups.keys()) + self.agent_groups = agent_groups self.rollout_config: Dict[str, Any] = rollout_config validate_runtime_configs(self.rollout_config) - self.coordinator = None - self.dataset_server = None - self.parameter_server = None - - self.init_servers() - - if rollout_config["inference_server"] == "local": - self.inference_server_cls = None - self.inference_client_cls = RayInferenceClient.as_remote( - **resource_config["inference_client"] - ) - elif rollout_config["inference_server"] == "ray": - self.inference_client_cls = RayInferenceClient.as_remote( - **resource_config["inference_client"] - ) - self.inference_server_cls = RayInferenceServer.as_remote( - **resource_config["inference_server"] - ).options(max_concurrency=100) - - else: - raise ValueError( - "unexpected inference server type: {}".format( - rollout_config["inference_server"] - ) - ) + self.inference_client_cls = InferenceClient.as_remote( + **resource_config["inference_client"] + ) + self.env_runner_cls = EnvRunner.as_remote( + **resource_config["inference_server"] + ).options(max_concurrency=100) - self.agent_interfaces = self.init_agent_interfaces(env_desc, runtime_agent_ids) - self.actor_pool: ActorPool = self.init_actor_pool( + self.env_runner_pool: ActorPool = self.init_env_runner_pool( env_desc, rollout_config, agent_mapping_func ) + self.inference_clients: Dict[ + AgentID, ray.ObjectRef + ] = self.create_inference_clients() self.log_dir = log_dir self.rollout_callback = rollout_callback or default_rollout_callback @@ -276,48 +211,10 @@ def __init__( self.experiment_tag = experiment_tag self.verbose = verbose - def init_agent_interfaces( - self, env_desc: Dict[str, Any], runtime_ids: Sequence[AgentID] - ) -> Dict[AgentID, Any]: - """Initialize agent interfaces which is a dict of `InterfaceWorkerSet`. The keys in the \ - dict is generated from the given agent mapping function. - - Args: - env_desc (Dict[str, Any]): Environment description. - runtime_ids (Sequence[AgentID]): Available runtime ids, generated with agent mapping function. - - Returns: - Dict[AgentID, Any]: A dict of `InferenceWorkerSet`, mapping from `runtime_ids` to `ray.ObjectRef(s)` - """ - - # interact with environment - if self.inference_server_cls is None: - return None + def create_inference_clients(self) -> Dict[AgentID, ray.ObjectRef]: + raise NotImplementedError - obs_spaces = env_desc["observation_spaces"] - act_spaces = env_desc["action_spaces"] - - runtime_obs_spaces = {} - runtime_act_spaces = {} - - for rid, agents in self.agent_group.items(): - runtime_obs_spaces[rid] = obs_spaces[agents[0]] - runtime_act_spaces[rid] = act_spaces[agents[0]] - - agent_interfaces = { - runtime_id: self.inference_server_cls.remote( - agent_id=runtime_id, - observation_space=runtime_obs_spaces[runtime_id], - action_space=runtime_act_spaces[runtime_id], - parameter_server=self.parameter_server, - governed_agents=self.agent_group[runtime_id], - ) - for runtime_id in runtime_ids - } - - return agent_interfaces - - def init_actor_pool( + def init_env_runner_pool( self, env_desc: Dict[str, Any], rollout_config: Dict[str, Any], @@ -344,12 +241,12 @@ def init_actor_pool( num_env_per_thread = rollout_config["num_env_per_thread"] num_eval_threads = rollout_config["num_eval_threads"] - actor_pool = ActorPool( + env_runner_pool = ActorPool( [ - self.inference_client_cls.remote( + self.env_runner_cls.remote( env_desc, - ray.get_actor(settings.OFFLINE_DATASET_ACTOR), max_env_num=num_env_per_thread, + agent_groups=self.agent_groups, use_subproc_env=rollout_config["use_subproc_env"], batch_mode=rollout_config["batch_mode"], postprocessor_types=rollout_config["postprocessor_types"], @@ -358,77 +255,28 @@ def init_actor_pool( for _ in range(num_threads + num_eval_threads) ] ) - return actor_pool - - def init_servers(self): - """Connect to data servers. + return env_runner_pool - Raises: - RuntimeError: Runtime errors. - """ - - retries = 100 - while True: - try: - if self.parameter_server is None: - self.parameter_server = ray.get_actor( - settings.PARAMETER_SERVER_ACTOR - ) - - if self.dataset_server is None: - self.dataset_server = ray.get_actor(settings.OFFLINE_DATASET_ACTOR) - break - except Exception as e: - retries -= 1 - if retries == 0: - raise RuntimeError(traceback.format_exc()) - else: - logger.log( - logging.WARNING, - f"waiting for coordinator server initialization ... {self.worker_indentifier}", - ) - time.sleep(1) - - def rollout( - self, - runtime_strategy_specs: Dict[str, StrategySpec], - stopping_conditions: Dict[str, Any], - data_entrypoints: Dict[str, str] = None, - active_agents: List[AgentID] = None, - ): + def rollout(self, task: RolloutTask): """Rollout, collecting training data when `data_entrypoints` is given, until meets the stopping conditions. The `active_agents` should be None or a none-empty list to specify active agents if rollout is not serve for evaluation. NOTE: the data collection will be triggered only for active agents. Args: - runtime_strategy_specs (Dict[str, StrategySpec]): A dict of strategy spec, mapping from runtime id to `StrategySpec`. - stopping_conditions (Dict[str, Any]): A dict of stopping conditions. - data_entrypoints (Dict[str, str], optional): Mapping from runtimeids to dataentrypoint names. None for evaluation. - active_agents (List[AgentID], optional): A list of environment agent id. Defaults to None, which means all environment agents will be trainable. Empty list for evaluation mode. + task: None """ - stopper = get_stopper(stopping_conditions) + stopper = get_stopper(task.stopping_conditions) active_agents = active_agents or self.env_agents - - if data_entrypoints is not None: - queue_info_dict: Dict[str, Tuple[str, Queue]] = { - rid: None for rid in self.runtime_agent_ids - } - for rid, identifier in data_entrypoints.items(): - queue_id, queue = ray.get( - self.dataset_server.start_producer_pipe.remote(name=identifier) - ) - queue_info_dict[rid] = (queue_id, queue) - else: - queue_info_dict = None + runtime_strategy_specs = task.strategy_specs + data_entrypoint_mapping = task.data_entrypoint_mapping rollout_config = self.rollout_config.copy() rollout_config.update( { "flag": "rollout", "strategy_specs": runtime_strategy_specs, - "active_agents": active_agents, - "agent_group": self.agent_group, + "behavior_mode": BehaviorMode.EXPLORATION, } ) total_timesteps = 0 @@ -443,10 +291,11 @@ def rollout( self.set_running(True) start_time = time.time() - # TODO(ming): share the stopping conditions here while self.is_running(): eval_step = (epoch + 1) % self.rollout_config["eval_interval"] == 0 - results = self.step_rollout(eval_step, rollout_config, queue_info_dict) + results = self.step_rollout( + eval_step, rollout_config, data_entrypoint_mapping + ) total_timesteps += results["total_timesteps"] eval_results = results.get("evaluation", None) @@ -480,29 +329,12 @@ def rollout( self.rollout_callback(self.coordinator, results) return results - def simulate(self, runtime_strategy_specs: Dict[str, StrategySpec]): - """Handling simulation task.""" - - runtime_config_template = self.rollout_config.copy() - runtime_config_template.update( - { - "flag": "simulation", - } - ) - - results: Dict[str, Any] = self.step_simulation( - runtime_strategy_specs, runtime_config_template - ) - - self.simulate_callback(self.coordinator, results) - return results - @abstractmethod def step_rollout( self, eval_step: bool, rollout_config: Dict[str, Any], - dataset_writer_info_dict: Dict[str, Any], + data_entrypoint_mapping: Dict[AgentID, str], ) -> List[Dict[str, Any]]: """The logic function to run rollout. Users must implment this method. @@ -521,6 +353,7 @@ def step_rollout( - `agent_group`: a dict that maps runtime agents to a list of environment agents, which describes the envrionment agents \ governed by what runtime agent interface. - `fragment_length`: the maximum of collected data frames. + data_entrypoint_mapping: ... Raises: NotImplementedError: _description_ @@ -529,25 +362,6 @@ def step_rollout( List[Dict[str, Any]]: Evaluation results, could be empty. """ - @abstractmethod - def step_simulation( - self, - runtime_strategy_specs: Dict[str, StrategySpec], - rollout_config: Dict[str, Any], - ) -> Dict[str, Any]: - """Logic function for running simulation of a list of strategy spec dict. - - Args: - runtime_strategy_specs (Dict[str, StrategySpec]): A strategy spec dict. - rollout_config (Dict[str, Any]): Runtime configuration template. - - Raises: - NotImplementedError: Not implemented error. - - Returns: - Dict[str, Any]: A evaluation results. - """ - def assign_episode_id(self): return f"eps-{self.worker_indentifier}-{time.time()}" diff --git a/malib/scenarios/sarl_scenario.py b/malib/scenarios/sarl_scenario.py index d5f0ebd7..64a6da1f 100644 --- a/malib/scenarios/sarl_scenario.py +++ b/malib/scenarios/sarl_scenario.py @@ -71,6 +71,7 @@ def execution_plan(experiment_tag: str, scenario: SARLScenario, verbose: bool = algorithms=scenario.algorithms, env_desc=scenario.env_desc, agent_mapping_func=scenario.agent_mapping_func, + group_info=scenario.group_info, training_config=scenario.training_config, log_dir=scenario.log_dir, remote_mode=True, @@ -83,6 +84,7 @@ def execution_plan(experiment_tag: str, scenario: SARLScenario, verbose: bool = stopping_conditions=scenario.stopping_conditions, num_worker=scenario.num_worker, agent_mapping_func=scenario.agent_mapping_func, + group_info=scenario.group_info, rollout_config=scenario.rollout_config, env_desc=scenario.env_desc, log_dir=scenario.log_dir, @@ -97,13 +99,8 @@ def execution_plan(experiment_tag: str, scenario: SARLScenario, verbose: bool = f"Training manager was inistialized with a strategy spec:\n{strategy_specs}" ) - data_entrypoints = {rid: rid for rid in training_manager.runtime_ids} - - assert len(data_entrypoints) == 1, "Support single agent only!" - optimization_task = OptimizationTask( active_agents=None, - data_entrypoints=data_entrypoints, stop_conditions=scenario.stopping_conditions["training"], ) training_manager.submit(optimization_task) @@ -111,7 +108,7 @@ def execution_plan(experiment_tag: str, scenario: SARLScenario, verbose: bool = rollout_task = { "num_workers": 1, "runtime_strategy_specs": strategy_specs, - "data_entrypoints": None, + "data_entrypoints": training_manager.get_data_entrypoint_mapping(), "rollout_config": scenario.rollout_config, "active_agents": None, } @@ -123,8 +120,8 @@ def execution_plan(experiment_tag: str, scenario: SARLScenario, verbose: bool = ), } - rollout_manager.submit(rollout_task, task_type=TaskType.ROLLOUT) - rollout_manager.submit(evaluation_task, task_type=TaskType.EVALUATION) + rollout_manager.submit(rollout_task) + rollout_manager.submit(evaluation_task) results = league.get_results() diff --git a/malib/scenarios/scenario.py b/malib/scenarios/scenario.py index 6a25c256..f009a085 100644 --- a/malib/scenarios/scenario.py +++ b/malib/scenarios/scenario.py @@ -22,13 +22,50 @@ from abc import ABC, abstractmethod from types import LambdaType -from typing import Callable, Union, Dict, Any +from typing import Callable, Union, Dict, Any, Set, List from copy import deepcopy +from collections import defaultdict + +import gym + +from malib.utils.typing import AgentID DEFAULT_STOPPING_CONDITIONS = {} +def validate_spaces(agent_groups: Dict[str, Set[AgentID]], env_desc: Dict[str, Any]): + # TODO(ming): check whether the agents in the group share the same observation space and action space + raise NotImplementedError + + +def validate_agent_group( + agent_group: Dict[str, List[AgentID]], + observation_spaces: Dict[AgentID, gym.Space], + action_spaces: Dict[AgentID, gym.Space], +) -> None: + """Validate agent group, check spaces. + + Args: + agent_group (Dict[str, List[AgentID]]): A dict, mapping from runtime ids to lists of agent ids. + full_keys (List[AgentID]): A list of original environment agent ids. + observation_spaces (Dict[AgentID, gym.Space]): Agent observation space dict. + action_spaces (Dict[AgentID, gym.Space]): Agent action space dict. + + Raises: + RuntimeError: Agents in a same group should share the same observation space and action space. + NotImplementedError: _description_ + """ + for agents in agent_group.values(): + select_obs_space = observation_spaces[agents[0]] + select_act_space = action_spaces[agents[0]] + for agent in agents[1:]: + assert type(select_obs_space) == type(observation_spaces[agent]) + assert select_obs_space.shape == observation_spaces[agent].shape + assert type(select_act_space) == type(action_spaces[agent]) + assert select_act_space.shape == action_spaces[agent].shape + + class Scenario(ABC): @abstractmethod def __init__( @@ -49,6 +86,23 @@ def __init__( self.env_desc = env_desc self.algorithms = algorithms self.agent_mapping_func = agent_mapping_func + # then generate grouping information here + agent_groups = defaultdict(lambda: set()) + grouped_obs_space = {} + grouped_act_space = {} + for agent in env_desc["possible_agents"]: + rid = agent_mapping_func(agent) + agent_groups[rid].add(agent) + grouped_obs_space[rid] = env_desc["observation_spaces"][agent] + grouped_act_space[rid] = env_desc["action_spaces"][agent] + self.group_info = { + "observation_space": grouped_obs_space, + "action_space": grouped_act_space, + "agent_groups": agent_groups, + } + validate_agent_group( + agent_groups, env_desc["observation_spaces"], env_desc["action_spaces"] + ) self.training_config = training_config self.rollout_config = rollout_config self.stopping_conditions = stopping_conditions or DEFAULT_STOPPING_CONDITIONS diff --git a/tests/rollout/test_env_runner.py b/tests/rollout/test_env_runner.py new file mode 100644 index 00000000..7085bd40 --- /dev/null +++ b/tests/rollout/test_env_runner.py @@ -0,0 +1,42 @@ +from typing import List, Dict, Any + +import pytest + +from malib.utils.typing import BehaviorMode +from malib.common.strategy_spec import StrategySpec +from malib.rollout.inference import env_runner +from malib.rollout.inference.client import InferenceClient +from malib.rollout.envs import mdp + + +@pytest.mark.parametrize( + "env_desc,max_env_num", + [ + [mdp.env_desc_gen(env_id="multi_round_nmdp"), 1], + ], +) +def test_env_runner(env_desc: Dict[str, Any], max_env_num: int): + agent_groups = dict(zip(env_desc["possible_agents"], env_desc["possible_agents"])) + runner = env_runner.EnvRunner(env_desc, max_env_num, agent_groups) + + agents = env_desc["possible_agents"] + observation_spaces = env_desc["observation_spaces"] + action_spaces = env_desc["action_spaces"] + + inference_remote_cls = InferenceClient.as_remote(num_cpus=1) + rollout_config = { + "flag": "evaluation", + "strategy_specs": { + agent: StrategySpec(agent, ["policy-0"], meta_data={}) for agent in agents + }, + "behavior_mode": BehaviorMode.EXPLOITATION, + } + + infer_clients = { + agent: inference_remote_cls.remote( + agent, observation_spaces[agent], action_spaces[agent] + ) + for agent in agents + } + + runner.run(infer_clients, rollout_config) diff --git a/tests/rollout/test_ray_inference.py b/tests/rollout/test_ray_inference.py deleted file mode 100644 index 478b6144..00000000 --- a/tests/rollout/test_ray_inference.py +++ /dev/null @@ -1,397 +0,0 @@ -# MIT License - -# Copyright (c) 2021 MARL @ SJTU - -# Author: Ming Zhou - -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. - -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -from typing import Callable, Dict, Any, List, Tuple -from argparse import Namespace -from collections import defaultdict - -import pytest -import ray - -from malib.agent.agent_interface import AgentInterface -from malib.agent.manager import TrainingManager -from malib.backend.parameter_server import ParameterServer - -# from malib.rollout.envs.dummy_env import env_desc_gen -from malib.runner import start_servers -from malib.rollout.envs.gym import env_desc_gen as gym_env_desc_gen -from malib.rollout.envs.open_spiel import env_desc_gen as open_spiel_env_desc_gen -from malib.rollout.envs.vector_env import VectorEnv -from malib.rollout.inference.utils import process_policy_outputs -from malib.rollout.rolloutworker import parse_rollout_info -from malib.utils.episode import Episode, NewEpisodeDict -from malib.utils.typing import AgentID, PolicyID -from malib.agent.indepdent_agent import IndependentAgent -from malib.common.strategy_spec import StrategySpec -from malib.scenarios.marl_scenario import MARLScenario -from malib.rollout.inference.ray.server import RayInferenceWorkerSet -from malib.rollout.inference.ray.client import env_runner, RayInferenceClient -from malib.utils.typing import BehaviorMode - - -def dqn(): - from malib.rl.dqn import DQNPolicy, DQNTrainer, DEFAULT_CONFIG - - algorithms = { - "default": ( - DQNPolicy, - DQNTrainer, - # model configuration, None for default - { - "net_type": "general_net", - "config": {"hidden_sizes": [64, 64]}, - }, - {}, - ) - } - trainer_config = DEFAULT_CONFIG["training_config"].copy() - return [algorithms, trainer_config] - - -def build_marl_scenario( - algorithms: Dict[str, Dict], - env_description: Dict[str, Any], - learner_cls, - trainer_config: Dict[str, Any], - agent_mapping_func: Callable, - runtime_logdir: str, -) -> MARLScenario: - training_config = { - "type": learner_cls, - "trainer_config": trainer_config, - "custom_config": {}, - } - rollout_config = { - "fragment_length": 200, # every thread - "max_step": 20, - "num_eval_episodes": 10, - "num_threads": 2, - "num_env_per_thread": 2, - "num_eval_threads": 1, - "use_subproc_env": False, - "batch_mode": "time_step", - "postprocessor_types": ["defaults"], - # every # rollout epoch run evaluation. - "eval_interval": 1, - "inference_server": "ray", # three kinds of inference server: `local`, `pipe` and `ray` - } - scenario = MARLScenario( - name="test_ray_inference", - log_dir=runtime_logdir, - algorithms=algorithms, - env_description=env_description, - training_config=training_config, - rollout_config=rollout_config, - agent_mapping_func=agent_mapping_func, - stopping_conditions={ - "training": {"max_iteration": int(1e10)}, - "rollout": {"max_iteration": 1000, "minimum_reward_improvement": 1.0}, - }, - ) - return scenario - - -def push_policy_to_parameter_server( - scenario: MARLScenario, parameter_server: ParameterServer -) -> Dict[AgentID, StrategySpec]: - """Generate a dict of strategy spec, generate policies and push them to the remote parameter server. - - Args: - scenario (MARLScenario): Scenario instance. - agents (List[AgentID]): A list of enviornment agents. - parameter_server (ParameterServer): Remote parameter server. - - Returns: - Dict[AgentID, StrategySpec]: A dict of strategy specs. - """ - - res = dict() - for agent in scenario.env_desc["possible_agents"]: - sid = scenario.agent_mapping_func(agent) - if sid in res: - continue - spec_pid = f"policy-0" - strategy_spec = StrategySpec( - identifier=sid, - policy_ids=[spec_pid], - meta_data={ - "policy_cls": scenario.algorithms["default"][0], - "experiment_tag": "test_ray_inference", - "kwargs": { - "observation_space": scenario.env_desc["observation_spaces"][agent], - "action_space": scenario.env_desc["action_spaces"][agent], - "model_config": scenario.algorithms["default"][2], - "custom_config": scenario.algorithms["default"][3], - "kwargs": {}, - }, - }, - ) - policy = strategy_spec.gen_policy() - ray.get(parameter_server.create_table.remote(strategy_spec)) - ray.get( - parameter_server.set_weights.remote( - spec_id=strategy_spec.id, - spec_policy_id=spec_pid, - state_dict=policy.state_dict(), - ) - ) - res[sid] = strategy_spec - return res - - -def generate_cs( - scenario: MARLScenario, dataset_server, parameter_server -) -> Tuple[RayInferenceClient, Dict[str, RayInferenceWorkerSet]]: - env_desc = scenario.env_desc - observation_spaces = env_desc["observation_spaces"] - action_spaces = env_desc["action_spaces"] - servers = dict.fromkeys(env_desc["possible_agents"], None) - agent_group = defaultdict(list) - for agent in env_desc["possible_agents"]: - sid = scenario.agent_mapping_func(agent) - agent_group[sid].append(agent) - - client = RayInferenceClient( - env_desc=scenario.env_desc, - dataset_server=dataset_server, - max_env_num=scenario.rollout_config["num_env_per_thread"], - use_subproc_env=scenario.rollout_config["use_subproc_env"], - batch_mode=scenario.rollout_config["batch_mode"], - postprocessor_types=scenario.rollout_config["postprocessor_types"], - training_agent_mapping=scenario.agent_mapping_func, - ) - - for sid, agents in agent_group.items(): - servers[sid] = RayInferenceWorkerSet( - agent_id=sid, - observation_space=observation_spaces[agent], - action_space=action_spaces[agent], - parameter_server=parameter_server, - governed_agents=agents.copy(), - ) - - return client, servers - - -from malib.rollout.inference.ray.client import process_env_rets - - -def rollout_func( - episode_dict: NewEpisodeDict, - client: RayInferenceClient, - servers: Dict[str, RayInferenceWorkerSet], - rollout_config, - server_runtime_config, - evaluate, -): - env_rets = client.env.reset( - fragment_length=rollout_config["fragment_length"], - max_step=rollout_config["max_step"], - ) - processed_env_ret, dataframes = process_env_rets( - env_rets, - preprocessor=server_runtime_config["preprocessor"], - preset_meta_data={"evaluate": evaluate}, - ) - if episode_dict is not None: - episode_dict.record(processed_env_ret, agent_first=False) - - cnt = 0 - while not client.env.is_terminated(): - grouped_dataframes = defaultdict(list) - for agent, dataframe in dataframes.items(): - runtime_id = client.training_agent_mapping(agent) - grouped_dataframes[runtime_id].append(dataframe) - - policy_outputs = { - rid: server.compute_action( - grouped_dataframes[rid], runtime_config=server_runtime_config - ) - for rid, server in servers.items() - } - - env_actions, processed_policy_outputs = process_policy_outputs( - policy_outputs, client.env - ) - - assert len(env_actions) > 0, "inference server may be stucked." - - if episode_dict is not None: - episode_dict.record(processed_policy_outputs, agent_first=True) - - env_rets = client.env.step(env_actions) - if len(env_rets) < 1: - dataframes = {} - continue - - processed_env_ret, dataframes = process_env_rets( - env_rets, - preprocessor=server_runtime_config["preprocessor"], - preset_meta_data={"evaluate": evaluate}, - ) - - if episode_dict is not None: - episode_dict.record(processed_env_ret, agent_first=False) - - cnt += 1 - - -def data_servers(): - if not ray.is_initialized(): - ray.init() - - parameter_server, offline_dataset_server = start_servers() - return parameter_server, offline_dataset_server - - -@pytest.mark.parametrize( - "env_desc", - [ - gym_env_desc_gen(env_id="CartPole-v1"), - # open_spiel_env_desc_gen(env_id="kuhn_poker"), - # mdp_env_desc_gen(env_id="two_round_dmdp"), - ], -) -@pytest.mark.parametrize("learner_cls", [IndependentAgent]) -@pytest.mark.parametrize("algorithms,trainer_config", [dqn()]) -def test_inference_mechanism(env_desc, learner_cls, algorithms, trainer_config): - parameter_server, dataset_server = data_servers() - scenario: MARLScenario = build_marl_scenario( - algorithms, - env_desc, - learner_cls, - trainer_config, - agent_mapping_func=lambda agent: agent, - runtime_logdir="./logs", - ) - client, servers = generate_cs(scenario, dataset_server, parameter_server) - training_manager = TrainingManager( - experiment_tag=scenario.name, - stopping_conditions=scenario.stopping_conditions, - algorithms=scenario.algorithms, - env_desc=scenario.env_desc, - agent_mapping_func=scenario.agent_mapping_func, - training_config=scenario.training_config, - log_dir=scenario.log_dir, - remote_mode=True, - resource_config=scenario.resource_config["training"], - verbose=True, - ) - data_entrypoints = {k: k for k in training_manager.agent_groups.keys()} - - # add policies and start training - strategy_specs = training_manager.add_policies(n=scenario.num_policy_each_interface) - strategy_specs = strategy_specs - data_entrypoints = data_entrypoints - - rollout_config = scenario.rollout_config.copy() - rollout_config["flag"] = "rollout" - - server_runtime_config = { - "strategy_specs": strategy_specs, - "behavior_mode": BehaviorMode.EXPLOITATION, - "preprocessor": client.preprocessor, - } - - dwriter_info_dict = dict.fromkeys(data_entrypoints.keys(), None) - - for rid, identifier in data_entrypoints.items(): - queue_id, queue = ray.get( - dataset_server.start_producer_pipe.remote(name=identifier) - ) - dwriter_info_dict[rid] = (queue_id, queue) - - eval_results, performance_results = env_runner( - client, - servers, - rollout_config, - server_runtime_config, - dwriter_info_dict, - ) - eval_results = parse_rollout_info([{"evaluation": eval_results}]) - print(eval_results["evaluation"]) - print(performance_results) - - for rid, identifier in data_entrypoints.items(): - ray.get(dataset_server.end_producer_pipe.remote(identifier)) - - ray.kill(parameter_server) - ray.kill(dataset_server) - ray.shutdown() - - -# def test_inference_pipeline(self): -# """This function tests the inference pipeline without using default env runner""" - -# training_manager.run(data_entrypoints) - -# rollout_config = scenario.rollout_config.copy() -# rollout_config["flag"] = "rollout" -# server_runtime_config = { -# "strategy_specs": strategy_specs, -# "behavior_mode": BehaviorMode.EXPLOITATION, -# "preprocessor": client.preprocessor, -# } - -# dwriter_info_dict = dict.fromkeys(data_entrypoints.keys(), None) - -# for rid, identifier in data_entrypoints.items(): -# queue_id, queue = ray.get( -# dataset_server.start_producer_pipe.remote(name=identifier) -# ) -# dwriter_info_dict[rid] = (queue_id, queue) - -# # collect episodes and run training -# rollout_env = client.env -# for n_epoch in range(2): -# episode_dict = NewEpisodeDict( -# lambda: Episode(agents=scenario.env_desc["possible_agents"]) -# ) -# rollout_func( -# episode_dict, -# client, -# servers, -# rollout_config, -# server_runtime_config, -# False, -# ) - -# episodes = episode_dict.to_numpy() -# for rid, writer_info in dwriter_info_dict.items(): -# agents = client.agent_group[rid] -# batches = [] -# for episode in episodes.values(): -# agent_buffer = [episode[aid] for aid in agents] -# batches.append(agent_buffer) -# writer_info[-1].put_nowait_batch(batches) -# rollout_info = client.env.collect_info() -# eval_results = list(rollout_info.values()) -# rollout_res = parse_rollout_info([{"evaluation": eval_results}]) - -# print("epoch: {}\nrollout_res: {}\n".format(n_epoch, rollout_res)) - -# client.env = rollout_env - -# training_manager.cancel_pending_tasks() -# # training_manager.terminate() From ab55fa10dd4268d0aa9702be7264a571b44732ca Mon Sep 17 00:00:00 2001 From: Ming Zhou Date: Wed, 8 Nov 2023 19:20:15 +0800 Subject: [PATCH 13/24] tmp save --- malib/common/rollout_config.py | 10 ++ malib/common/strategy_spec.py | 64 ++++---- malib/rl/common/policy.py | 25 ++- malib/rl/pg/policy.py | 15 +- malib/rl/random/policy.py | 5 +- malib/rollout/envs/env.py | 9 ++ malib/rollout/envs/mdp/env.py | 4 +- malib/rollout/inference/env_runner.py | 148 +++++++++++++++++- malib/utils/episode.py | 56 +++++++ tests/rollout/test_env_runner.py | 15 +- .../test_episode.py | 0 .../test_payoff_manager.py | 0 tests/structures/test_strategy_spec.py | 51 ++++++ 13 files changed, 335 insertions(+), 67 deletions(-) rename tests/{malib_utils => structures}/test_episode.py (100%) rename tests/{malib_utils => structures}/test_payoff_manager.py (100%) create mode 100644 tests/structures/test_strategy_spec.py diff --git a/malib/common/rollout_config.py b/malib/common/rollout_config.py index 4b4191d7..6a3ec2e1 100644 --- a/malib/common/rollout_config.py +++ b/malib/common/rollout_config.py @@ -6,6 +6,16 @@ @dataclass class RolloutConfig: inference_server_type: str + """Inference server type""" + + num_workers: int = 1 + """Defines how many workers will be used for executing one rollout task, default is 1""" + + n_envs_per_worker: int = 1 + """Indicates how many environments will be activated for a rollout task per rollout worker, default is 1""" + + timelimit: int = 256 + """Specifying how many time steps will be collected for each rollout, default is 256""" @classmethod def from_raw( diff --git a/malib/common/strategy_spec.py b/malib/common/strategy_spec.py index b23896af..ec581947 100644 --- a/malib/common/strategy_spec.py +++ b/malib/common/strategy_spec.py @@ -24,7 +24,7 @@ from typing import Dict, Any, Tuple, Type from argparse import Namespace - +from collections import namedtuple import numpy as np from malib.rl.common.policy import Policy @@ -48,9 +48,21 @@ def validate_meta_data(policy_ids: Tuple[PolicyID], meta_data: Dict[str, Any]): assert np.isclose(sum(meta_data["prob_list"]), 1.0) +import copy + +from gym import spaces + + class StrategySpec: def __init__( - self, identifier: str, policy_ids: Tuple[PolicyID], meta_data: Dict[str, Any] + self, + policy_cls: Type, + observation_space: spaces.Space, + action_space: spaces.Space, + model_config: Dict[str, Any] = None, + identifier: str = None, + policy_ids: Tuple[PolicyID] = None, + **kwargs, ) -> None: """Construct a strategy spec. @@ -60,10 +72,17 @@ def __init__( meta_data (Dict[str, Any]): Meta data, for policy construction. """ - validate_meta_data(policy_ids, meta_data) - self.id = identifier - self.policy_ids = tuple(policy_ids) - self.meta_data = meta_data + self.id = identifier or "StrategySpec" + self.policy_ids = tuple(policy_ids) if policy_ids else () + self.meta_data = { + "policy_cls": policy_cls, + "init_kwargs": { + "observation_space": observation_space, + "action_space": action_space, + "model_config": model_config, + **kwargs, + }, + } def __str__(self): return f"" @@ -85,7 +104,8 @@ def register_policy_id(self, policy_id: PolicyID): policy_id (PolicyID): Policy id to register. """ - assert policy_id not in self.policy_ids, (policy_id, self.policy_ids) + if policy_id in self.policy_ids: + raise KeyError("repected policy id detected: {}".format(policy_id)) self.policy_ids = self.policy_ids + (policy_id,) if "prob_list" in self.meta_data: @@ -104,10 +124,11 @@ def update_prob_list(self, policy_probs: Dict[PolicyID, float]): for pid, prob in policy_probs.items(): idx = self.policy_ids.index(pid) self.meta_data["prob_list"][idx] = prob - assert np.isclose(sum(self.meta_data["prob_list"]), 1.0), ( - self.meta_data["prob_list"], - sum(self.meta_data["prob_list"]), - ) + + if not np.isclose(sum(self.meta_data["prob_list"]), 1.0): + raise ValueError( + f"Prob list is not normalized: {self.meta_data['prob_list']}" + ) def get_meta_data(self) -> Dict[str, Any]: """Return meta data. Keys in meta-data contains @@ -121,7 +142,7 @@ def get_meta_data(self) -> Dict[str, Any]: Dict[str, Any]: A dict of meta data. """ - return self.meta_data + return copy.deepcopy(self.meta_data) def gen_policy(self, device=None) -> Policy: """Generate a policy instance with the given meta data. @@ -131,23 +152,8 @@ def gen_policy(self, device=None) -> Policy: """ policy_cls: Type[Policy] = self.meta_data["policy_cls"] - plist = self.meta_data["kwargs"] - plist = Namespace(**plist) - - custom_config = plist.custom_config.copy() - - if device is not None and "cuda" in device: - custom_config["use_cuda"] = True - else: - custom_config["use_cuda"] = False - - return policy_cls( - observation_space=plist.observation_space, - action_space=plist.action_space, - model_config=plist.model_config, - custom_config=custom_config, - **plist.kwargs, - ) + policy = policy_cls(**self.meta_data["init_kwargs"]) + return policy.to(device) def sample(self) -> PolicyID: """Sample a policy instance. Use uniform sample if there is no presetted prob list in meta data. diff --git a/malib/rl/common/policy.py b/malib/rl/common/policy.py index f3525308..39784cba 100644 --- a/malib/rl/common/policy.py +++ b/malib/rl/common/policy.py @@ -62,25 +62,21 @@ def state_dict(self): class Policy(metaclass=ABCMeta): - def __init__( - self, observation_space, action_space, model_config, custom_config, **kwargs - ): + def __init__(self, observation_space, action_space, model_config, **kwargs): _locals = locals() _locals.pop("self") self._init_args = _locals self._observation_space = observation_space self._action_space = action_space - self._model_config = model_config or {} - self._custom_config = custom_config or {} + self._model_config = model_config + self._custom_config = kwargs self._state_handler_dict = {} self._preprocessor = get_preprocessor( observation_space, - mode=self._custom_config.get("preprocess_mode", "flatten"), + mode=kwargs.get("preprocess_mode", "flatten"), )(observation_space) - self._device = torch.device( - "cuda" if self._custom_config.get("use_cuda") else "cpu" - ) + self._device = torch.device("cuda" if kwargs.get("use_cuda") else "cpu") self._registered_networks: Dict[str, nn.Module] = {} @@ -95,16 +91,13 @@ def __init__( ) ) - self.use_cuda = self._custom_config.get("use_cuda", False) + self.use_cuda = kwargs.get("use_cuda", False) self.dist_fn: Distribution = make_proba_distribution( action_space=action_space, - use_sde=custom_config.get("use_sde", False), - dist_kwargs=custom_config.get("dist_kwargs", None), + use_sde=kwargs.get("use_sde", False), + dist_kwargs=kwargs.get("dist_kwargs", None), ) - if kwargs.get("model_client"): - self.model = kwargs["model_client"] - else: - self.model = self.create_model() + self.model = kwargs.get("model_client", self.create_model()) def create_model(self): raise NotImplementedError diff --git a/malib/rl/pg/policy.py b/malib/rl/pg/policy.py index 8903ff73..c4773182 100644 --- a/malib/rl/pg/policy.py +++ b/malib/rl/pg/policy.py @@ -42,8 +42,7 @@ def __init__( self, observation_space: spaces.Space, action_space: spaces.Space, - model_config: Dict[str, Any], - custom_config: Dict[str, Any], + model_config: Dict[str, Any] = None, **kwargs ): """Build a REINFORCE policy whose input and output dims are determined by observation_space and action_space, respectively. @@ -52,8 +51,6 @@ def __init__( observation_space (spaces.Space): The observation space. action_space (spaces.Space): The action space. model_config (Dict[str, Any]): The model configuration dict. - custom_config (Dict[str, Any]): The custom configuration dict. - is_fixed (bool, optional): Indicates fixed policy or trainable policy. Defaults to False. Raises: NotImplementedError: Does not support other action space type settings except Box and Discrete. @@ -61,14 +58,12 @@ def __init__( """ # update model_config with default ones - model_config = merge_dicts(DEFAULT_CONFIG["model_config"].copy(), model_config) - custom_config = merge_dicts( - DEFAULT_CONFIG["custom_config"].copy(), custom_config + model_config = merge_dicts( + DEFAULT_CONFIG["model_config"].copy(), model_config or {} ) + kwargs = merge_dicts(DEFAULT_CONFIG["custom_config"].copy(), kwargs) - super().__init__( - observation_space, action_space, model_config, custom_config, **kwargs - ) + super().__init__(observation_space, action_space, model_config, **kwargs) def create_model(self): # update model preprocess_net config here diff --git a/malib/rl/random/policy.py b/malib/rl/random/policy.py index cadea858..8409be8f 100644 --- a/malib/rl/random/policy.py +++ b/malib/rl/random/policy.py @@ -11,9 +11,6 @@ def __init__( observation_space: spaces.Space, action_space: spaces.Space, model_config: Dict[str, Any], - custom_config: Dict[str, Any], **kwargs ): - super().__init__( - observation_space, action_space, model_config, custom_config, **kwargs - ) + super().__init__(observation_space, action_space, model_config, **kwargs) diff --git a/malib/rollout/envs/env.py b/malib/rollout/envs/env.py index 4c27f954..2325322c 100644 --- a/malib/rollout/envs/env.py +++ b/malib/rollout/envs/env.py @@ -54,6 +54,7 @@ def __init__(self, **configs): self._configs = configs self._current_players = [] self._state: Dict[str, np.ndarray] = None + self._deactivated = True def record_episode_info_step( self, @@ -103,11 +104,19 @@ def action_spaces(self) -> Dict[AgentID, gym.Space]: raise NotImplementedError + @property + def is_deactivated(self) -> bool: + return self._deactivated + + def deactivate(self): + self._deactivated = True + def reset(self, max_step: int = None) -> Union[None, Sequence[Dict[AgentID, Any]]]: """Reset environment and the episode info handler here.""" self.max_step = max_step or self.max_step self.cnt = 0 + self._deactivated = False self.episode_metrics = { "env_step": 0, diff --git a/malib/rollout/envs/mdp/env.py b/malib/rollout/envs/mdp/env.py index 3ad977a5..6b8f7fb3 100644 --- a/malib/rollout/envs/mdp/env.py +++ b/malib/rollout/envs/mdp/env.py @@ -35,7 +35,7 @@ def __init__(self, **configs): ) self.env = scenarios[env_id]().to_env() - self._possible_agents = ["agent"] + self._possible_agents = ["default"] @property def possible_agents(self) -> List[AgentID]: @@ -57,7 +57,7 @@ def time_step( Dict[AgentID, bool], Dict[AgentID, Any], ]: - obs, rew, done, info = self.env._step(actions["agent"]) + obs, rew, done, info = self.env._step(actions["default"]) obs = dict.fromkeys(self.possible_agents, obs) rew = dict.fromkeys(self.possible_agents, rew) diff --git a/malib/rollout/inference/env_runner.py b/malib/rollout/inference/env_runner.py index 1b23fa02..e2558b92 100644 --- a/malib/rollout/inference/env_runner.py +++ b/malib/rollout/inference/env_runner.py @@ -23,7 +23,7 @@ # SOFTWARE. from argparse import Namespace -from typing import Any, List, Dict, Tuple, Set +from typing import Any, List, Dict, Tuple, Set, Type from types import LambdaType from collections import defaultdict @@ -37,16 +37,156 @@ from ray.actor import ActorHandle from malib.utils.typing import AgentID, DataFrame, BehaviorMode -from malib.utils.episode import NewEpisodeList +from malib.utils.episode import ConventionalEpisodeList from malib.utils.preprocessor import Preprocessor, get_preprocessor from malib.utils.timing import Timing from malib.remote.interface import RemoteInterface from malib.rollout.envs.vector_env import VectorEnv, SubprocVecEnv +from malib.common.rollout_config import RolloutConfig from malib.rollout.inference.client import InferenceClient from malib.rollout.inference.utils import process_env_rets, process_policy_outputs +from malib.rollout.envs.env import Environment from malib.backend.dataset_server.utils import send_data +class AgentManager: + def __init__(self, episode_num, inference_clients) -> None: + self.inference_clients = inference_clients + self.episodes = ConventionalEpisodeList( + num=episode_num, agents=list(inference_clients.keys()) + ) + + def collect_and_act(self, episode_idx, raw_obs, last_dones, last_rews, states): + if not last_dones["__all__"]: + action_and_obs = { + k: ray.get(v.compute_action.remote(raw_obs[k], states[k])) + for k, v in self.inference_clients.items() + } + actions = {} + obs = {} + for k, v in action_and_obs.items(): + actions[k] = v[0] + obs[k] = v[1] + else: + actions = None + obs = { + k: ray.get(v.preprocess_obs.remote(raw_obs[k])) + for k, v in self.inference_clients.items() + } + + self.episodes.record(obs, last_dones, last_rews, states, episode_idx) + + return actions + + def merge_episodes(self): + return self.episodes.to_numpy() + + +class BasicEnvRunner(RemoteInterface): + def __repr__(self) -> str: + return super().__repr__() + + def __init__( + self, env_func: Type, max_env_num: int, use_subproc_env: bool = False + ) -> None: + super().__init__() + + self._use_subproc_env = use_subproc_env + self._max_env_num = max_env_num + self._env_func = env_func + self._envs = [] + + @property + def envs(self) -> Tuple[Environment]: + return tuple(self._envs) + + @property + def env_func(self) -> Type: + return self._env_func + + @property + def num_active_envs(self) -> int: + return len(self._envs) + + @property + def use_subproc_env(self) -> bool: + return self._use_subproc_env + + @property + def max_env_num(self) -> int: + return self._max_env_num + + def run( + self, + inference_clients: Dict[AgentID, InferenceClient], + rollout_config: RolloutConfig, + data_entrypoint_mapping: Dict[AgentID, str] = None, + ): + """Single thread env simulation stepping. + + Args: + inference_clients (Dict[AgentID, InferenceClient]): A dict of remote inference client. + rollout_config (RolloutConfig): Rollout configuration, which specifies how many data pieces will rollout. + data_entrypoint_mapping (Dict[AgentID, str], optional): A mapping which defines the data collection trigger, if not None, then return episodes. Defaults to None. + + Raises: + e: _description_ + + Returns: + _type_: _description_ + """ + + new_env_num = max(0, rollout_config.n_envs_per_worker - self.num_active_envs) + + for _ in range(new_env_num): + self._envs.append(self.env_func()) + + # reset envs + envs = self.envs[: rollout_config.n_envs_per_worker] + vec_states, vec_obs, vec_dones, vec_rews = [], [], [], [] + + for env in envs: + states, obs = env.reset(max_step=rollout_config.timelimit) + vec_states.append(states) + vec_obs.append(obs) + vec_dones.append(False) + vec_rews.append(0.0) + + active_env_num = len(envs) + agent_manager = AgentManager(active_env_num, inference_clients) + + while active_env_num: + for env_idx, (env, states, obs, dones, rews) in enumerate( + zip(envs, vec_states, vec_obs, vec_dones, vec_rews) + ): + if env.is_deactivated(): + continue + + actions = agent_manager.collect_and_act( + env_idx, + raw_obs=obs, + last_dones=dones, + last_rews=rews, + states=states, + ) + + if actions is None: + # which means done already + active_env_num -= 1 + env.set_done() + else: + states, obs, rews, dones = env.step(actions) + # update frames + vec_states[env_idx] = states + vec_obs[env_idx] = obs + vec_dones[env_idx] = dones + vec_rews[env_idx] = rews + + # merge agent episodes + data = agent_manager.merge_episodes() + return data + + class EnvRunner(RemoteInterface): def __repr__(self) -> str: return f"" @@ -142,7 +282,7 @@ def run( "strategy_specs": rollout_config["strategy_specs"], } - eval_results, performance = env_runner( + eval_results, performance = _env_runner( self, inference_clients, self.preprocessor, @@ -157,7 +297,7 @@ def run( return res -def env_runner( +def _env_runner( client: InferenceClient, agents: Dict[str, InferenceClient], preprocessors: Dict[str, Preprocessor], diff --git a/malib/utils/episode.py b/malib/utils/episode.py index a1d39961..a501a0db 100644 --- a/malib/utils/episode.py +++ b/malib/utils/episode.py @@ -143,6 +143,62 @@ def to_numpy(self) -> Dict[AgentID, Dict[str, np.ndarray]]: return dict(res) +class ConventionalEpisode(Episode): + def __init__(self, agents: List[AgentID], processors=None): + super().__init__(agents, processors) + self.agent_buffer = {} + + def record(self, obs, last_dones, last_rews, states): + for agent, _obs in obs.items(): + self.agent_entry[agent][Episode.CUR_OBS].append(_obs) + self.agent_entry[agent][Episode.PRE_DONE].append(last_dones[agent]) + self.agent_entry[agent][Episode.PRE_REWARD].append(last_rews[agent]) + self.agent_entry[agent][Episode.CUR_STATE].append(states[agent]) + + def clear_buffer(self): + self.agent_buffer = {} + + def to_numpy(self) -> Dict[AgentID, Dict[str, np.ndarray]]: + if len(self.agent_entry) == 0: + for agent, agent_trajectory in self.agent_entry.items(): + agent_traj_np = {} + agent_traj_np[Episode.CUR_OBS] = np.stack( + agent_trajectory[Episode.CUR_OBS][:-1] + ) + agent_traj_np[Episode.NEXT_OBS] = np.stack( + agent_trajectory[Episode.CUR_OBS][1:] + ) + agent_traj_np[Episode.DONE] = np.stack( + agent_trajectory[Episode.PRE_DONE][1:] + ) + agent_traj_np[Episode.REWARD] = np.stack( + agent_trajectory[Episode.PRE_REWARD][1:] + ) + agent_traj_np[Episode.ACTION] = np.stack( + agent_trajectory[Episode.ACTION] + ) + self.agent_buffer[agent] = agent_traj_np + return self.agent_buffer + + +class ConventionalEpisodeList: + def __init__(self, num: int, agents: List[AgentID]) -> None: + self.episodes = [ConventionalEpisode(agents) for _ in range(num)] + + def record(self, obs, last_dones, last_rews, states, idx: int = None): + if idx is not None: + self.episodes[i].record(obs, last_dones, last_rews, states) + else: + for i, episode in enumerate(self.episodes): + episode.record(obs[i], last_dones[i], last_rews[i], states[i]) + + def to_numpy(self) -> List[Dict[AgentID, Dict[str, np.ndarray]]]: + res = [] + for episode in self.episodes: + res.append(episode.to_numpy()) + return res + + class NewEpisodeDict(defaultdict): """Episode dict, for trajectory tracking for a bunch of environments.""" diff --git a/tests/rollout/test_env_runner.py b/tests/rollout/test_env_runner.py index 7085bd40..b56c83a7 100644 --- a/tests/rollout/test_env_runner.py +++ b/tests/rollout/test_env_runner.py @@ -7,6 +7,7 @@ from malib.rollout.inference import env_runner from malib.rollout.inference.client import InferenceClient from malib.rollout.envs import mdp +from malib.rl.random import RandomPolicy @pytest.mark.parametrize( @@ -16,8 +17,11 @@ ], ) def test_env_runner(env_desc: Dict[str, Any], max_env_num: int): + # mapping from agents to agents agent_groups = dict(zip(env_desc["possible_agents"], env_desc["possible_agents"])) - runner = env_runner.EnvRunner(env_desc, max_env_num, agent_groups) + runner = env_runner.EnvRunner( + env_desc, max_env_num, agent_groups, use_subproc_env=False + ) agents = env_desc["possible_agents"] observation_spaces = env_desc["observation_spaces"] @@ -27,7 +31,14 @@ def test_env_runner(env_desc: Dict[str, Any], max_env_num: int): rollout_config = { "flag": "evaluation", "strategy_specs": { - agent: StrategySpec(agent, ["policy-0"], meta_data={}) for agent in agents + agent: StrategySpec( + policy_cls=RandomPolicy, + observation_space=observation_spaces["default"], + action_space=action_spaces["default"], + identifier=agent, + policy_ids=["policy-0"], + ) + for agent in agents }, "behavior_mode": BehaviorMode.EXPLOITATION, } diff --git a/tests/malib_utils/test_episode.py b/tests/structures/test_episode.py similarity index 100% rename from tests/malib_utils/test_episode.py rename to tests/structures/test_episode.py diff --git a/tests/malib_utils/test_payoff_manager.py b/tests/structures/test_payoff_manager.py similarity index 100% rename from tests/malib_utils/test_payoff_manager.py rename to tests/structures/test_payoff_manager.py diff --git a/tests/structures/test_strategy_spec.py b/tests/structures/test_strategy_spec.py new file mode 100644 index 00000000..31439796 --- /dev/null +++ b/tests/structures/test_strategy_spec.py @@ -0,0 +1,51 @@ +import pytest +import argparse + +from malib.rl.random import RandomPolicy +from malib.common.strategy_spec import StrategySpec +from malib.rollout.envs.mdp import env_desc_gen + + +class TestStrategySpec: + @pytest.fixture + def spec(self) -> StrategySpec: + env_desc = argparse.Namespace(**env_desc_gen(env_id="one_round_dmdp")) + return StrategySpec( + policy_cls=RandomPolicy, + observation_space=env_desc.observation_spaces["default"], + action_space=env_desc.action_spaces["default"], + identifier="random", + ) + + def test_existing_policy_register(self, spec: StrategySpec): + spec.register_policy_id("random0") + assert len(spec) == 1 + with pytest.raises(KeyError): + spec.register_policy_id("random0") + spec.register_policy_id("random2") + assert len(spec) == 2 + assert spec.policy_ids == ("random0", "random2") + + def test_policy_generation(self, spec: StrategySpec): + model = spec.gen_policy(device="cpu") + model = spec.gen_policy(device="cuda:0") + + def test_policy_dist_reset(self, spec: StrategySpec): + spec.register_policy_id("random0") + spec.register_policy_id("random2") + spec.update_prob_list({"random0": 0.5, "random2": 0.5}) + + with pytest.raises(ValueError) as excinfo: + spec.update_prob_list({"random0": 0.5, "random2": 0.5, "random3": 0.5}) + + with pytest.raises(ValueError) as excinfo: + spec.update_prob_list({"random0": 0.5, "random2": 0.1}) + + def test_policy_sampling(self, spec: StrategySpec): + spec.register_policy_id("random0") + spec.register_policy_id("random2") + spec.sample() + spec.update_prob_list({"random0": 1, "random2": 0}) + for _ in range(10): + pid = spec.sample() + assert "random0" == pid, pid From f30e6e8d7dda3f89c81bf6c5f950da785048e25a Mon Sep 17 00:00:00 2001 From: Ming Zhou Date: Thu, 9 Nov 2023 20:19:05 +0800 Subject: [PATCH 14/24] env runner test passed --- malib/common/strategy_spec.py | 7 +- malib/models/config.py | 15 + .../inference => models}/model_client.py | 24 +- malib/rl/common/policy.py | 243 +++++++--------- malib/rl/pg/policy.py | 21 +- malib/rollout/__init__.py | 4 +- malib/rollout/inference/client.py | 103 +++++-- malib/rollout/inference/env_runner.py | 266 ++++++++---------- malib/rollout/pb_rolloutworker.py | 48 ++-- malib/{common => rollout}/rollout_config.py | 6 + malib/rollout/rolloutworker.py | 80 ++---- malib/utils/episode.py | 18 +- tests/models/test_model_client.py | 7 + tests/rl/test_policy_gradient.py | 7 + tests/rollout/test_env_runner.py | 59 ++-- 15 files changed, 457 insertions(+), 451 deletions(-) create mode 100644 malib/models/config.py rename malib/{rollout/inference => models}/model_client.py (74%) rename malib/{common => rollout}/rollout_config.py (86%) create mode 100644 tests/models/test_model_client.py create mode 100644 tests/rl/test_policy_gradient.py diff --git a/malib/common/strategy_spec.py b/malib/common/strategy_spec.py index ec581947..bdf24ccf 100644 --- a/malib/common/strategy_spec.py +++ b/malib/common/strategy_spec.py @@ -68,7 +68,7 @@ def __init__( Args: identifier (str): Runtime id as identifier. - policy_ids (Tuple[PolicyID]): A tuple of policy id, could be empty. + policy_ids (Tuple[PolicyID], optional): A tuple of policy id, could be empty. Defaults to None, then sampling will return a None. meta_data (Dict[str, Any]): Meta data, for policy construction. """ @@ -162,6 +162,11 @@ def sample(self) -> PolicyID: PolicyID: A sampled policy id. """ + if len(self) == 0: + raise RuntimeError( + "No policy id registered, it would be feasible for an active policy." + ) + prob_list = self.meta_data.get( "prob_list", [1 / self.num_policy] * self.num_policy ) diff --git a/malib/models/config.py b/malib/models/config.py new file mode 100644 index 00000000..bb1a2108 --- /dev/null +++ b/malib/models/config.py @@ -0,0 +1,15 @@ +from typing import Type, Dict, Any + +from dataclasses import dataclass + + +@dataclass +class ModelConfig: + + model_cls: Type + + model_args: Dict[str, Any] + + def to_dict(self): + _dict = self.__dict__.copy() + return _dict diff --git a/malib/rollout/inference/model_client.py b/malib/models/model_client.py similarity index 74% rename from malib/rollout/inference/model_client.py rename to malib/models/model_client.py index 5391e0a1..083b7092 100644 --- a/malib/rollout/inference/model_client.py +++ b/malib/models/model_client.py @@ -2,14 +2,13 @@ from concurrent import futures import threading +import torch +import ray from readerwriterlock import rwlock from torch import nn -import torch -import ray - -from malib.utils.typing import AgentID, DataFrame +from malib.models.config import ModelConfig def load_state_dict(client, timeout=10): @@ -20,9 +19,18 @@ def load_state_dict(client, timeout=10): class ModelClient: - def __init__( - self, entry_point: str, model_cls: nn.Module, model_args: Dict[str, Any] - ): + def __init__(self, entry_point: str, model_config: ModelConfig): + """Construct a model client for mantaining a model instance and its update. + + Args: + entry_point (str): Entry point for model update. + model_cls (nn.Module): Model class for constructing model instance. + model_args (Dict[str, Any]): Arguments for constructing model instance. + + Raises: + NotImplementedError: Unsupported cluster type. + """ + cluster_type, name_or_address = entry_point.split(":") if "ray" in cluster_type: @@ -36,7 +44,7 @@ def __init__( self.event = threading.Event() self.thread_pool.submit(self._model_update, self.event) - self.model: nn.Module = model_cls(**model_args).cpu() + self.model: nn.Module = model_config.model_cls(**model_config.model_args).cpu() self.model.share_memory() def __call__(self, *args: Any, **kwds: Any) -> Any: diff --git a/malib/rl/common/policy.py b/malib/rl/common/policy.py index 39784cba..ba346aa7 100644 --- a/malib/rl/common/policy.py +++ b/malib/rl/common/policy.py @@ -24,16 +24,20 @@ from abc import ABCMeta, abstractmethod from typing import Dict, Any, Tuple, Union -from enum import IntEnum +from collections import namedtuple +import copy import torch import torch.nn as nn import gym +import numpy as np from gym import spaces -from malib.utils.preprocessor import get_preprocessor +from malib.utils.preprocessor import get_preprocessor, Preprocessor from malib.common.distributions import make_proba_distribution, Distribution +from malib.models.config import ModelConfig +from malib.models.model_client import ModelClient class SimpleObject: @@ -56,13 +60,21 @@ def state_dict(self): return value -Action = Any -ActionDist = Any -Logits = Any +Action = np.ndarray +ActionDist = np.ndarray +Logits = np.ndarray + +PolicyReturn = namedtuple("PolicyReturn", "action,action_dist,logits,others") class Policy(metaclass=ABCMeta): - def __init__(self, observation_space, action_space, model_config, **kwargs): + def __init__( + self, + observation_space: spaces.Space, + action_space: spaces.Space, + model_config: Union[ModelConfig, Dict[str, Any]], + **kwargs, + ): _locals = locals() _locals.pop("self") self._init_args = _locals @@ -70,20 +82,17 @@ def __init__(self, observation_space, action_space, model_config, **kwargs): self._action_space = action_space self._model_config = model_config self._custom_config = kwargs - self._state_handler_dict = {} self._preprocessor = get_preprocessor( observation_space, mode=kwargs.get("preprocess_mode", "flatten"), )(observation_space) - self._device = torch.device("cuda" if kwargs.get("use_cuda") else "cpu") - - self._registered_networks: Dict[str, nn.Module] = {} + self._device = torch.device(kwargs.get("device", "cpu")) if isinstance(action_space, spaces.Discrete): - self.action_type = "discrete" + self._action_type = "discrete" elif isinstance(action_space, spaces.Box): - self.action_type = "continuous" + self._action_type = "continuous" else: raise NotImplementedError( "Does not support other action space type settings except Box and Discrete. {}".format( @@ -91,145 +100,100 @@ def __init__(self, observation_space, action_space, model_config, **kwargs): ) ) - self.use_cuda = kwargs.get("use_cuda", False) - self.dist_fn: Distribution = make_proba_distribution( + self._dist_fn: Distribution = make_proba_distribution( action_space=action_space, use_sde=kwargs.get("use_sde", False), dist_kwargs=kwargs.get("dist_kwargs", None), ) - self.model = kwargs.get("model_client", self.create_model()) - - def create_model(self): + self._model = kwargs.get("model_client") + if self._model is None: + if kwargs.get("model_entry_point"): + self._model = ModelClient(kwargs["model_entry_point"], model_config) + else: + self._model = self.create_model().to(self._device) + + def create_model(self) -> nn.Module: raise NotImplementedError @property - def action_space(self) -> gym.Space: - return self._action_space + def dist_fn(self) -> Distribution: + return self._dist_fn @property - def observation_space(self) -> gym.Space: - return self._observation_space + def action_type(self) -> str: + return self._action_type @property - def model_config(self): - return self._model_config + def action_space(self) -> spaces.Space: + return self._action_space @property - def device(self) -> str: - return self._device + def model(self) -> nn.Module: + return self._model @property - def custom_config(self) -> Dict[str, Any]: - return self._custom_config + def observation_space(self) -> spaces.Space: + return self._observation_space @property - def target_actor(self): - return self._target_actor - - @target_actor.setter - def target_actor(self, value: Any): - self._target_actor = value + def model_config(self) -> Dict[str, Any]: + if isinstance(self._model_config, ModelConfig): + return self._model_config.to_dict() + else: + return copy.deepcopy(self._model_config) @property - def actor(self): - return self._actor - - @actor.setter - def actor(self, value: Any): - self._actor = value + def preprocessor(self) -> Preprocessor: + return self._preprocessor @property - def critic(self): - return self._critic - - @critic.setter - def critic(self, value: Any): - self._critic = value + def device(self) -> str: + return self._device @property - def target_critic(self): - return self._target_critic - - @target_critic.setter - def target_critic(self, value: Any): - self._target_critic = value + def custom_config(self) -> Dict[str, Any]: + return copy.deepcopy(self._custom_config) - def load_state_dict(self, state_dict: Dict[str, Any]): + def load_state_dict( + self, state_dict: Dict[str, Any] = None, checkpoint: str = None + ) -> "Policy": """Load state dict outside. Args: state_dict (Dict[str, Any]): A dict of states. """ - for k, v in state_dict.items(): - self._state_handler_dict[k].load_state_dict(v) - - def state_dict(self, device=None): - """Return state dict in real time""" + if state_dict is not None: + self.model.load_state_dict(state_dict) + elif checkpoint is not None: + self.model.load_state_dict(torch.load(checkpoint)) - if device is None: - res = {k: v.state_dict() for k, v in self._state_handler_dict.items()} - else: - res = {} - for k, v in self._state_handler_dict.items(): - if isinstance(v, torch.nn.Module): - tmp = {} - for _k, _v in v.state_dict().items(): - tmp[_k] = _v.cpu() - else: - tmp = v.state_dict() - res[k] = tmp - return res + return self - def register_state(self, obj: Any, name: str) -> None: - """Register state of obj. Called in init function to register model states. - - Example: - >>> class CustomPolicy(Policy): - ... def __init__( - ... self, - ... registered_name, - ... observation_space, - ... action_space, - ... model_config, - ... custom_config - ... ): - ... # ... - ... actor = MLP(...) - ... self.register_state(actor, "actor") + def state_dict( + self, device: Union[torch.DeviceObjType, str] = None + ) -> Dict[str, Any]: + """Return state dict of model. Args: - obj (Any): Any object, for non `torch.nn.Module`, it will be wrapped as a `Simpleobject`. - name (str): Humanreadable name, to identify states. + device (Union[torch.DeviceObjType, str], optional): Device name. Defaults to None. - Raises: - errors.RepeatedAssignError: [description] + Returns: + Dict[str, Any]: A state dict """ - # if not isinstance(obj, nn.Module): - if obj.__class__.__module__ == "builtins": - n = SimpleObject(self, name) - n.load_state_dict(obj) - obj = n - - self._state_handler_dict[name] = obj - if isinstance(obj, nn.Module): - self._registered_networks[name] = obj - - def deregister_state(self, name: str): - if self._state_handler_dict.get(name) is None: - print(f"No such state tagged with: {name}") + if device is None: + res = self.model.state_dict() else: - self._state_handler_dict.pop(name) - print(f"Deregister state tagged with: {name}") + res = {} + for k, v in self.model.state_dict(): + res[k] = v.to(device) + + return res def get_initial_state(self, batch_size: int = None): return None - @property - def preprocessor(self): - return self._preprocessor - @abstractmethod def compute_action( self, @@ -238,7 +202,7 @@ def compute_action( evaluate: bool, hidden_state: Any = None, **kwargs, - ) -> Tuple[Action, ActionDist, Logits, Any]: + ) -> PolicyReturn: pass def save(self, path, global_step=0, hard: bool = False): @@ -254,16 +218,22 @@ def load(self, path: str): def reset(self, **kwargs): """Reset parameters or behavior policies.""" - pass @classmethod - def copy(cls, instance, replacement: Dict): - return cls(replacement=replacement, **instance._init_args) + def copy(cls, instance: "Policy", replacement: Dict) -> "Policy": + """Self copy, from a given instance. The replacement is a dict of new arguments to override. - @property - def registered_networks(self) -> Dict[str, nn.Module]: - return self._registered_networks + Args: + instance (Policy): A policy instance to copy from. Must be an instance of cls. + replacement (Dict): A dict of new arguments to override. + + Returns: + New policy instance. + """ + + kwargs = {**replacement, **instance._init_args} + return cls(**kwargs) def to(self, device: str = None, use_copy: bool = False) -> "Policy": """Convert policy to a given device. If `use_copy`, then return a copy. If device is None, do not change device. @@ -292,17 +262,13 @@ def to(self, device: str = None, use_copy: bool = False) -> "Policy": replacement = {} if cond1 or cond2: - # retrieve networks here - for k, v in self.registered_networks.items(): - _v = v.to(device) - if not use_copy: - setattr(self, k, _v) - else: - replacement[k] = _v + _model = self.model.to(device) + if not use_copy: + setattr(self, "model", _model) + else: + replacement["model_client"] = _model else: - # fixed bug: replacement cannot be None. - for k, v in self.registered_networks.items(): - replacement[k] = v + replacement["model_client"] = self.model if use_copy: ret = self.copy(self, replacement=replacement) @@ -312,27 +278,20 @@ def to(self, device: str = None, use_copy: bool = False) -> "Policy": return ret - def parameters(self) -> Dict[str, Dict]: - """Return trainable parameters.""" - - res = {} - for name, net in self.registered_networks.items(): - res[name] = net.parameters() - return res - - def update_parameters(self, parameter_dict: Dict[str, Any]): - """Update local parameters with given parameter dict. + def parameters(self, recurse: bool = True): + """Returns an iterator over module parameters. + This is typically passed to an optimizer. Args: - parameter_dict (Dict[str, Parameter]): A dict of paramters + recurse (bool, optional): If True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. Defaults to True. + + Yields: + Parameter: module parameter """ - for k, parameters in parameter_dict.items(): - target = self.registered_networks[k] - for target_param, param in zip(target.parameters(), parameters): - target_param.data.copy_(param.data) + return self.model.parameters(recurse=recurse) def coordinate(self, state: Dict[str, torch.Tensor], message: Any) -> Any: """Coordinate with other agents here""" - raise NotImplementedError + pass diff --git a/malib/rl/pg/policy.py b/malib/rl/pg/policy.py index c4773182..20b23a2c 100644 --- a/malib/rl/pg/policy.py +++ b/malib/rl/pg/policy.py @@ -30,10 +30,11 @@ from gym import spaces from torch import nn +from malib.utils.general import merge_dicts from malib.models.torch import net, discrete, continuous +from malib.models.config import ModelConfig from malib.rl.common import misc -from malib.rl.common.policy import Policy, Action, ActionDist, Logits -from malib.utils.general import merge_dicts +from malib.rl.common.policy import Policy, PolicyReturn from .config import DEFAULT_CONFIG @@ -42,7 +43,7 @@ def __init__( self, observation_space: spaces.Space, action_space: spaces.Space, - model_config: Dict[str, Any] = None, + model_config: Union[ModelConfig, Dict[str, Any]] = None, **kwargs ): """Build a REINFORCE policy whose input and output dims are determined by observation_space and action_space, respectively. @@ -80,7 +81,7 @@ def create_model(self): **self.model_config["preprocess_net"]["config"] ) if isinstance(self.action_space, spaces.Discrete): - self.actor = discrete.Actor( + return discrete.Actor( preprocess_net=preprocess_net, action_shape=action_shape, hidden_sizes=self.model_config["hidden_sizes"], @@ -88,7 +89,7 @@ def create_model(self): device=self.device, ) elif isinstance(self.action_space, spaces.Box): - self.actor = continuous.Actor( + return continuous.Actor( preprocess_net=preprocess_net, action_shape=action_shape, hidden_sizes=self.model_config["hidden_sizes"], @@ -100,8 +101,6 @@ def create_model(self): "Unexpected action space type: {}".format(type(self.action_space)) ) - self.register_state(self.actor, "actor") - def value_function(self, observation: torch.Tensor, evaluate: bool, **kwargs): """Compute values of critic.""" @@ -114,9 +113,9 @@ def compute_action( evaluate: bool, hidden_state: Any = None, **kwargs - ) -> Tuple[Action, ActionDist, Logits, Any]: - with torch.no_grad(): - logits, hidden = self.actor(observation, state=hidden_state) + ) -> PolicyReturn: + with torch.inference_mode(): + logits, hidden = self.model(observation, state=hidden_state) if isinstance(logits, tuple): dist = self.dist_fn.proba_distribution(*logits) else: @@ -135,4 +134,4 @@ def compute_action( action_dist = probs state = hidden - return action, action_dist, logits, state + return PolicyReturn(action, action_dist, logits, state) diff --git a/malib/rollout/__init__.py b/malib/rollout/__init__.py index 0e02ff91..bc26123e 100644 --- a/malib/rollout/__init__.py +++ b/malib/rollout/__init__.py @@ -23,8 +23,8 @@ # SOFTWARE. from .pb_rolloutworker import RolloutWorker -from .inference.env_runner import EnvRunner +from .inference.env_runner import BasicEnvRunner from .inference.client import InferenceClient -__all__ = ["RolloutWorker", "EnvRunner", "InferenceClient"] +__all__ = ["RolloutWorker", "BasicEnvRunner", "InferenceClient"] diff --git a/malib/rollout/inference/client.py b/malib/rollout/inference/client.py index fa28c7f5..3cf50bf1 100644 --- a/malib/rollout/inference/client.py +++ b/malib/rollout/inference/client.py @@ -22,69 +22,124 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -from typing import Any, List, Dict +from typing import Any, List, Dict, Type, Tuple from functools import reduce from operator import mul from collections import namedtuple -from concurrent.futures import ThreadPoolExecutor import os -import pickle as pkl import gym +import torch +import numpy as np -from malib import settings from malib.remote.interface import RemoteInterface from malib.utils.typing import AgentID, DataFrame from malib.utils.timing import Timing from malib.utils.episode import Episode from malib.common.strategy_spec import StrategySpec -from malib.rl.common.policy import Policy + +from malib.models.config import ModelConfig +from malib.rl.common.policy import Policy, PolicyReturn Connection = namedtuple("Connection", "sender,recver,runtime_config,rnn_states") +PolicyReturnWithObs = namedtuple("PolicyReturnWithObs", PolicyReturn._fields + ("obs",)) class InferenceClient(RemoteInterface): def __init__( self, - agent_id: AgentID, + entry_point: str, + policy_cls: Type, observation_space: gym.Space, action_space: gym.Space, + model_config: ModelConfig, ) -> None: """Create ray-based inference server. Args: - agent_id (AgentID): Runtime agent id, not environment agent id. + entry_point (str): Entrypoint for model update. observation_space (gym.Space): Observation space related to the governed environment agents. action_space (gym.Space): Action space related to the governed environment agents. """ - self.runtime_agent_id = agent_id self.observation_space = observation_space self.action_space = action_space - self.thread_pool = ThreadPoolExecutor() - self.policies: Dict[str, Policy] = {} - self.strategy_spec_dict: Dict[str, StrategySpec] = {} + self.fixed_policy: Policy = policy_cls( + observation_space, action_space, model_config + ) + self.active_policy: Policy = policy_cls( + observation_space, + action_space, + model_config, + model_entry_point=entry_point, + ) def shutdown(self): - self.thread_pool.shutdown(wait=True) - for _handler in self.connections.values(): - _handler.sender.shutdown(True) - _handler.recver.shutdown(True) - self.connections: Dict[int, Connection] = {} - - def save(self, model_dir: str) -> None: - if not os.path.exists(model_dir): - os.makedirs(model_dir) + pass - for pid, policy in self.policies.items(): - fp = os.path.join(model_dir, pid + ".pkl") - with open(fp, "wb") as f: - pkl.dump(policy, f, protocol=settings.PICKLE_PROTOCOL_VER) + def process_obs(self, raw_observation: Any) -> np.ndarray: + return self.fixed_policy.preprocessor.transform(raw_observation) def compute_action( + self, + raw_obs: Any, + state: Any, + last_reward: float, + last_done: float, + active_policy: bool = False, + checkpoint: str = None, + require_obs_return: bool = True, + ) -> PolicyReturnWithObs: + """Compute actions for given observations. + + Args: + raw_obs (Any): Raw observations. + state (Any): State. + last_reward (float): Last reward. + last_done (float): Last done. + active_policy (bool, optional): Whether to use active model. Defaults to False. + checkpoint (str, optional): Checkpoint path. Defaults to None. + + Returns: + PolicyReturnWithObs: An instance of PolicyReturnWithObs. + """ + + if active_policy: + policy = self.active_policy + evaluate = False + else: + policy = self.fixed_policy + evaluate = True + + if checkpoint is not None: + if not os.path.exists(checkpoint): + raise RuntimeError(f"Checkpoint {checkpoint} not found.") + policy.model.load_state_dict(torch.load(checkpoint)) + + with torch.inference_mode(): + obs = self.fixed_policy.preprocessor.transform(raw_obs) + obs = torch.from_numpy(obs).float() + # FIXME(ming): act mask and hidden state is set to None, + # not feasible for cases which require them + policy_return = policy.compute_action( + observation=obs, + act_mask=None, + evaluate=evaluate, + hidden_state=None, + state=state, + last_reward=last_reward, + last_done=last_done, + ) + _returns: dict = policy_return._asdict() + if require_obs_return: + _returns.update({"obs": obs}) + policy_return = PolicyReturnWithObs(**_returns) + return policy_return + + def compute_action_with_frames( self, dataframes: List[DataFrame], runtime_config: Dict[str, Any] ) -> List[DataFrame]: timer = Timing() diff --git a/malib/rollout/inference/env_runner.py b/malib/rollout/inference/env_runner.py index e2558b92..91ba976d 100644 --- a/malib/rollout/inference/env_runner.py +++ b/malib/rollout/inference/env_runner.py @@ -22,7 +22,6 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -from argparse import Namespace from typing import Any, List, Dict, Tuple, Set, Type from types import LambdaType from collections import defaultdict @@ -33,48 +32,96 @@ import pickle import ray +import numpy as np from ray.actor import ActorHandle from malib.utils.typing import AgentID, DataFrame, BehaviorMode from malib.utils.episode import ConventionalEpisodeList -from malib.utils.preprocessor import Preprocessor, get_preprocessor + from malib.utils.timing import Timing from malib.remote.interface import RemoteInterface from malib.rollout.envs.vector_env import VectorEnv, SubprocVecEnv -from malib.common.rollout_config import RolloutConfig -from malib.rollout.inference.client import InferenceClient +from malib.rollout.rollout_config import RolloutConfig +from malib.rollout.inference.client import InferenceClient, PolicyReturnWithObs from malib.rollout.inference.utils import process_env_rets, process_policy_outputs from malib.rollout.envs.env import Environment +from malib.common.strategy_spec import StrategySpec from malib.backend.dataset_server.utils import send_data class AgentManager: - def __init__(self, episode_num, inference_clients) -> None: + def __init__( + self, + episode_num: int, + inference_clients: Dict[AgentID, ray.ObjectRef], + strategy_specs: Dict[AgentID, StrategySpec], + ) -> None: + """Construct a unified API for multi-agent action caller. + + Args: + episode_num (int): Defines how many episodes will be collected, used for initialization. + inference_clients (Dict[AgentID, ray.ObjectRef]): A dict of remote inference clients. + """ + self.inference_clients = inference_clients + self.strategy_specs = strategy_specs self.episodes = ConventionalEpisodeList( num=episode_num, agents=list(inference_clients.keys()) ) + self.use_active_policy = dict.fromkeys(self.inference_clients.keys(), False) + self.checkpoints = {} - def collect_and_act(self, episode_idx, raw_obs, last_dones, last_rews, states): - if not last_dones["__all__"]: - action_and_obs = { - k: ray.get(v.compute_action.remote(raw_obs[k], states[k])) - for k, v in self.inference_clients.items() - } - actions = {} - obs = {} - for k, v in action_and_obs.items(): - actions[k] = v[0] - obs[k] = v[1] - else: - actions = None - obs = { - k: ray.get(v.preprocess_obs.remote(raw_obs[k])) - for k, v in self.inference_clients.items() - } + def set_behavior_policy(self): + """Specify behavior policy for each agent.""" - self.episodes.record(obs, last_dones, last_rews, states, episode_idx) + for agent_id, strategy_spec in self.strategy_specs.items(): + if len(strategy_spec) == 0: + self.use_active_policy[agent_id] = True + else: + self.checkpoints[agent_id] = strategy_spec.sample() + + def collect_and_act( + self, + episode_idx: int, + raw_obs: Dict[AgentID, Any], + last_dones: Dict[AgentID, bool], + last_rews: Dict[AgentID, float], + states: Dict[AgentID, np.ndarray] = None, + ) -> Dict[str, Any]: + """Collect give timestep, if last_dones['__all__'] is True, then return a None action. + + Args: + episode_idx (int): Episode index, for identifying episode buffer. + raw_obs (Dict[AgentID, Any]): A dict of raw agent observations. + last_dones (Dict[AgentID, bool]): A dict of agent dones, accompanying with __all__ to identify environment done. + last_rews (Dict[AgentID, float]): A dict of rewards of last timestep. + states (Dict[AgentID, np.ndarray], optional): A dict of states. Defaults to None. + + Returns: + Dict[str, Any]: A dict of actions. + """ + + policy_return_with_obs: Dict[AgentID, PolicyReturnWithObs] = { + k: ray.get( + v.compute_action.remote( + raw_obs=raw_obs[k], + state=states[k] if states is not None else None, + last_reward=last_rews[k], + last_done=last_dones[k], + active_policy=self.use_active_policy[k], + checkpoint=self.checkpoints.get(k), + ) + ) + for k, v in self.inference_clients.items() + } + actions = {} + obs = {} + for k, v in policy_return_with_obs.items(): + actions[k] = v.action + obs[k] = v.obs + + self.episodes.record(obs, actions, last_dones, last_rews, states, episode_idx) return actions @@ -82,6 +129,9 @@ def merge_episodes(self): return self.episodes.to_numpy() +from malib.utils.timing import Timing + + class BasicEnvRunner(RemoteInterface): def __repr__(self) -> str: return super().__repr__() @@ -120,6 +170,7 @@ def run( self, inference_clients: Dict[AgentID, InferenceClient], rollout_config: RolloutConfig, + strategy_specs: Dict[AgentID, StrategySpec], data_entrypoint_mapping: Dict[AgentID, str] = None, ): """Single thread env simulation stepping. @@ -127,6 +178,7 @@ def run( Args: inference_clients (Dict[AgentID, InferenceClient]): A dict of remote inference client. rollout_config (RolloutConfig): Rollout configuration, which specifies how many data pieces will rollout. + strategy_specs (Dict[AgentID, StrategySpec]): A dict of strategy specs, which rules the behavior policy of each agent. data_entrypoint_mapping (Dict[AgentID, str], optional): A mapping which defines the data collection trigger, if not None, then return episodes. Defaults to None. Raises: @@ -149,152 +201,56 @@ def run( states, obs = env.reset(max_step=rollout_config.timelimit) vec_states.append(states) vec_obs.append(obs) - vec_dones.append(False) - vec_rews.append(0.0) + vec_dones.append( + {"__all__": False, **dict.fromkeys(env.possible_agents, False)} + ) + vec_rews.append(dict.fromkeys(env.possible_agents, 0.0)) active_env_num = len(envs) - agent_manager = AgentManager(active_env_num, inference_clients) + agent_manager = AgentManager(active_env_num, inference_clients, strategy_specs) + + timer = Timing() + total_timestep = 0 while active_env_num: for env_idx, (env, states, obs, dones, rews) in enumerate( zip(envs, vec_states, vec_obs, vec_dones, vec_rews) ): - if env.is_deactivated(): + if env.is_deactivated: continue + total_timestep += 1 + + with timer.time_avg("avg_env_step"): + with timer.time_avg("avg_inference_client_step"): + actions = agent_manager.collect_and_act( + env_idx, + raw_obs=obs, + last_dones=dones, + last_rews=rews, + states=states, + ) - actions = agent_manager.collect_and_act( - env_idx, - raw_obs=obs, - last_dones=dones, - last_rews=rews, - states=states, - ) - - if actions is None: - # which means done already - active_env_num -= 1 - env.set_done() - else: - states, obs, rews, dones = env.step(actions) - # update frames - vec_states[env_idx] = states - vec_obs[env_idx] = obs - vec_dones[env_idx] = dones - vec_rews[env_idx] = rews + if dones["__all__"]: + # which means done already + active_env_num -= 1 + env.deactivate() + else: + states, obs, rews, dones, info = env.step(actions) + # update frames + vec_states[env_idx] = states + vec_obs[env_idx] = obs + vec_dones[env_idx] = dones + vec_rews[env_idx] = rews # merge agent episodes + # FIXME(ming): send data to remote dataset data = agent_manager.merge_episodes() - return data + stats = {"total_timesteps": total_timestep, **timer.todict()} + return stats -class EnvRunner(RemoteInterface): - def __repr__(self) -> str: - return f"" - - def __init__( - self, - env_desc: Dict[str, Any], - max_env_num: int, - agent_groups: Dict[str, Set], - use_subproc_env: bool = False, - batch_mode: str = "time_step", - postprocessor_types: Dict = None, - training_agent_mapping: LambdaType = None, - custom_config: Dict[str, Any] = {}, - ): - """Construct an inference client, one for each agent. - - Args: - env_desc (Dict[str, Any]): Environment description - dataset_server (_type_): A ray object reference. - max_env_num (int): The maximum of created environment instance. - use_subproc_env (bool, optional): Indicate subproc envrionment enabled or not. Defaults to False. - batch_mode (str, optional): Batch mode, could be `time_step` or `episode` mode. Defaults to "time_step". - postprocessor_types (Dict, optional): Post processor type list. Defaults to None. - training_agent_mapping (LambdaType, optional): Agent mapping function. Defaults to None. - custom_config (Dict[str, Any], optional): Custom configuration. Defaults to an empty dict. - """ - - self.use_subproc_env = use_subproc_env - self.batch_mode = batch_mode - self.postprocessor_types = postprocessor_types or ["defaults"] - self.process_id = os.getpid() - self.timer = Timing() - self.training_agent_mapping = training_agent_mapping or (lambda agent: agent) - self.max_env_num = max_env_num - self.custom_configs = custom_config - self.runtime_agent_ids = list(agent_groups.keys()) - self.agent_groups = agent_groups - - obs_spaces = env_desc["observation_spaces"] - act_spaces = env_desc["action_spaces"] - env_cls = env_desc["creator"] - env_config = env_desc["config"] - - self.preprocessor: Dict[str, Preprocessor] = { - agent: get_preprocessor(obs_spaces[agent])(obs_spaces[agent]) - for agent in env_desc["possible_agents"] - } - - if use_subproc_env: - self.env = SubprocVecEnv( - obs_spaces, act_spaces, env_cls, env_config, preset_num_envs=max_env_num - ) - else: - self.env = VectorEnv( - obs_spaces, act_spaces, env_cls, env_config, preset_num_envs=max_env_num - ) - - def close(self): - """Disconnects with inference servers and turns off environment.""" - - if self.recv_queue is not None: - _ = [e.shutdown(force=True) for e in self.recv_queue.values()] - _ = [e.shutdown(force=True) for e in self.send_queue.values()] - self.env.close() - - def run( - self, - inference_clients: Dict[AgentID, InferenceClient], - rollout_config: Dict[str, Any], - data_entrypoint_mapping: Dict[AgentID, str] = None, - ) -> Dict[str, Any]: - """Executes environment runner to collect training data or run purely simulation/evaluation. - - Note: - Only simulation/evaluation tasks return evaluation information. - - Args: - inference_clients (Dict[AgentID, InferenceClient]): A dict of agent interface servers. - rollout_config (Dict[str, Any]): Rollout configuration. - dataset_writer_info_dict (Dict[str, Tuple[str, Queue]], optional): Dataset writer info dict. Defaults to None. - - Returns: - Dict[str, Any]: A dict of simulation results. - """ - - # reset timer, ready for monitor - self.timer.clear() - task_type = rollout_config["flag"] - - server_runtime_config = { - "preprocessor": self.preprocessor, - "strategy_specs": rollout_config["strategy_specs"], - } - - eval_results, performance = _env_runner( - self, - inference_clients, - self.preprocessor, - rollout_config, - server_runtime_config, - data_entrypoint_mapping, - ) - - res = performance.copy() - if task_type != "rollout": - res["evaluation"] = eval_results - return res +from malib.utils.episode import NewEpisodeList +from malib.utils.preprocessor import Preprocessor def _env_runner( diff --git a/malib/rollout/pb_rolloutworker.py b/malib/rollout/pb_rolloutworker.py index 852dc68d..42ba620d 100644 --- a/malib/rollout/pb_rolloutworker.py +++ b/malib/rollout/pb_rolloutworker.py @@ -22,11 +22,16 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -from typing import Dict, Any +from typing import Dict, Any, List -from malib.rollout.rolloutworker import RolloutWorker, parse_rollout_info +import ray + +from malib.utils.typing import AgentID from malib.utils.logging import Logger +from malib.rollout.rolloutworker import RolloutWorker, parse_rollout_info +from malib.common.strategy_spec import StrategySpec + class PBRolloutWorker(RolloutWorker): """For experience collection and simulation, the operating unit is env.AgentInterface""" @@ -34,35 +39,20 @@ class PBRolloutWorker(RolloutWorker): def step_rollout( self, eval_step: bool, - rollout_config: Dict[str, Any], - data_entrypoint_mapping: Dict[str, Any], - ): - tasks = [rollout_config for _ in range(self.rollout_config["num_threads"])] - - # add tasks for evaluation - if eval_step: - eval_runtime_config = rollout_config.copy() - eval_runtime_config["flag"] = "evaluation" - tasks.extend( - [ - eval_runtime_config - for _ in range(self.rollout_config["num_eval_threads"]) - ] - ) - - rets = [ - x - for x in self.env_runner_pool.map( - lambda a, task: a.run.remote( - inference_clients=self.inference_clients, - rollout_config=task, - data_entrypoint_mapping=data_entrypoint_mapping, - ), - tasks, + strategy_specs: Dict[AgentID, StrategySpec], + data_entrypoint_mapping: Dict[AgentID, str], + ) -> List[Dict[str, Any]]: + results = ray.get( + self.env_runner.run.remote( + inference_clients=self.inference_clients, + rollout_config=self.rollout_config, + strategy_specs=strategy_specs, + data_entrypoint_mapping=data_entrypoint_mapping, ) - ] + ) # check evaluation info - parsed_results = parse_rollout_info(rets) + parsed_results = parse_rollout_info(results) Logger.debug(f"parsed results: {parsed_results}") + return parsed_results diff --git a/malib/common/rollout_config.py b/malib/rollout/rollout_config.py similarity index 86% rename from malib/common/rollout_config.py rename to malib/rollout/rollout_config.py index 6a3ec2e1..c569376c 100644 --- a/malib/common/rollout_config.py +++ b/malib/rollout/rollout_config.py @@ -11,9 +11,15 @@ class RolloutConfig: num_workers: int = 1 """Defines how many workers will be used for executing one rollout task, default is 1""" + eval_interval: int = 1 + """Defines evaluation frequency""" + n_envs_per_worker: int = 1 """Indicates how many environments will be activated for a rollout task per rollout worker, default is 1""" + use_subproc_env: bool = False + """Indicate whether use subproce env, better to use True for heavey environments""" + timelimit: int = 256 """Specifying how many time steps will be collected for each rollout, default is 256""" diff --git a/malib/rollout/rolloutworker.py b/malib/rollout/rolloutworker.py index 6daff6e9..c20df7fa 100644 --- a/malib/rollout/rolloutworker.py +++ b/malib/rollout/rolloutworker.py @@ -47,8 +47,9 @@ from malib.common.strategy_spec import StrategySpec from malib.common.task import RolloutTask from malib.remote.interface import RemoteInterface +from malib.rollout.rollout_config import RolloutConfig from malib.rollout.inference.client import InferenceClient -from malib.rollout.inference.env_runner import EnvRunner +from malib.rollout.inference.env_runner import BasicEnvRunner PARAMETER_GET_TIMEOUT = 3 @@ -186,23 +187,21 @@ def __init__( self.env_agents = env_desc["possible_agents"] self.runtime_agent_ids = list(agent_groups.keys()) self.agent_groups = agent_groups - self.rollout_config: Dict[str, Any] = rollout_config + self.rollout_config = RolloutConfig.from_raw(rollout_config) validate_runtime_configs(self.rollout_config) - self.inference_client_cls = InferenceClient.as_remote( - **resource_config["inference_client"] + # create environment runner, handling evaluation or rollout task + env_runner_resource_config = resource_config["inference_server"] + self.env_runner = self.create_env_runner( + env_desc, env_runner_resource_config, self.rollout_config ) - self.env_runner_cls = EnvRunner.as_remote( - **resource_config["inference_server"] - ).options(max_concurrency=100) - self.env_runner_pool: ActorPool = self.init_env_runner_pool( - env_desc, rollout_config, agent_mapping_func - ) + # create inference clients, for action execution + inferenc_client_configuration = resource_config["inference_client"] self.inference_clients: Dict[ AgentID, ray.ObjectRef - ] = self.create_inference_clients() + ] = self.create_inference_clients(inferenc_client_configuration) self.log_dir = log_dir self.rollout_callback = rollout_callback or default_rollout_callback @@ -214,22 +213,18 @@ def __init__( def create_inference_clients(self) -> Dict[AgentID, ray.ObjectRef]: raise NotImplementedError - def init_env_runner_pool( + def create_env_runner( self, env_desc: Dict[str, Any], - rollout_config: Dict[str, Any], - agent_mapping_func: Callable, + resource_config: Dict[str, Any], + rollout_config: RolloutConfig, ) -> ActorPool: """Initialize an actor pool for the management of simulation tasks. Note the size of the \ generated actor pool is determined by `num_threads + num_eval_threads`. Args: env_desc (Dict[str, Any]): Environment description. - rollout_config (Dict[str, Any]): Runtime configuration, the given keys in this configuration \ - include: - - `num_threads`: int, determines the size of this actor pool. - - `num_env_per_thread`: int, indicates how many environments will be created for each thread. - - `num_eval_threads`: int, determines how many threads will be created for the evaluation along the rollouts. + rollout_config (RolloutConfig): Rollout configuration. agent_mapping_func (Callable): Agent mapping function which maps environment agents \ to runtime ids, shared among all workers. @@ -237,25 +232,16 @@ def init_env_runner_pool( ActorPool: An instance of `ActorPool`. """ - num_threads = rollout_config["num_threads"] - num_env_per_thread = rollout_config["num_env_per_thread"] - num_eval_threads = rollout_config["num_eval_threads"] - - env_runner_pool = ActorPool( - [ - self.env_runner_cls.remote( - env_desc, - max_env_num=num_env_per_thread, - agent_groups=self.agent_groups, - use_subproc_env=rollout_config["use_subproc_env"], - batch_mode=rollout_config["batch_mode"], - postprocessor_types=rollout_config["postprocessor_types"], - training_agent_mapping=agent_mapping_func, - ) - for _ in range(num_threads + num_eval_threads) - ] + env_runner_cls = BasicEnvRunner.as_remote(**resource_config).options( + max_concurrency=100 ) - return env_runner_pool + env_runner = env_runner_cls.remote( + env_func=lambda: env_desc["creator"](**env_desc["config"]), + max_env_num=rollout_config.n_envs_per_worker, + use_subproc_env=rollout_config.use_subproc_env, + ) + + return env_runner def rollout(self, task: RolloutTask): """Rollout, collecting training data when `data_entrypoints` is given, until meets the stopping conditions. The `active_agents` should be None or a none-empty list to specify active agents if rollout is not serve for evaluation. @@ -268,17 +254,7 @@ def rollout(self, task: RolloutTask): stopper = get_stopper(task.stopping_conditions) active_agents = active_agents or self.env_agents - runtime_strategy_specs = task.strategy_specs - data_entrypoint_mapping = task.data_entrypoint_mapping - - rollout_config = self.rollout_config.copy() - rollout_config.update( - { - "flag": "rollout", - "strategy_specs": runtime_strategy_specs, - "behavior_mode": BehaviorMode.EXPLORATION, - } - ) + total_timesteps = 0 eval_results = {} epoch = 0 @@ -292,9 +268,12 @@ def rollout(self, task: RolloutTask): start_time = time.time() while self.is_running(): - eval_step = (epoch + 1) % self.rollout_config["eval_interval"] == 0 + eval_step = (epoch + 1) % self.rollout_config.eval_interval == 0 results = self.step_rollout( - eval_step, rollout_config, data_entrypoint_mapping + eval_step, + task.strategy_specs, + self.rollout_config, + task.data_entrypoint_mapping, ) total_timesteps += results["total_timesteps"] eval_results = results.get("evaluation", None) @@ -333,6 +312,7 @@ def rollout(self, task: RolloutTask): def step_rollout( self, eval_step: bool, + strategy_specs: Dict[AgentID, StrategySpec], rollout_config: Dict[str, Any], data_entrypoint_mapping: Dict[AgentID, str], ) -> List[Dict[str, Any]]: diff --git a/malib/utils/episode.py b/malib/utils/episode.py index a501a0db..6bd22023 100644 --- a/malib/utils/episode.py +++ b/malib/utils/episode.py @@ -148,18 +148,20 @@ def __init__(self, agents: List[AgentID], processors=None): super().__init__(agents, processors) self.agent_buffer = {} - def record(self, obs, last_dones, last_rews, states): + def record(self, obs, actions, last_dones, last_rews, states): for agent, _obs in obs.items(): self.agent_entry[agent][Episode.CUR_OBS].append(_obs) + self.agent_entry[agent][Episode.ACTION].append(actions[agent]) self.agent_entry[agent][Episode.PRE_DONE].append(last_dones[agent]) self.agent_entry[agent][Episode.PRE_REWARD].append(last_rews[agent]) - self.agent_entry[agent][Episode.CUR_STATE].append(states[agent]) + if states is not None: + self.agent_entry[agent][Episode.CUR_STATE].append(states[agent]) def clear_buffer(self): self.agent_buffer = {} def to_numpy(self) -> Dict[AgentID, Dict[str, np.ndarray]]: - if len(self.agent_entry) == 0: + if len(self.agent_buffer) == 0: for agent, agent_trajectory in self.agent_entry.items(): agent_traj_np = {} agent_traj_np[Episode.CUR_OBS] = np.stack( @@ -175,7 +177,7 @@ def to_numpy(self) -> Dict[AgentID, Dict[str, np.ndarray]]: agent_trajectory[Episode.PRE_REWARD][1:] ) agent_traj_np[Episode.ACTION] = np.stack( - agent_trajectory[Episode.ACTION] + agent_trajectory[Episode.ACTION][:-1] ) self.agent_buffer[agent] = agent_traj_np return self.agent_buffer @@ -185,12 +187,14 @@ class ConventionalEpisodeList: def __init__(self, num: int, agents: List[AgentID]) -> None: self.episodes = [ConventionalEpisode(agents) for _ in range(num)] - def record(self, obs, last_dones, last_rews, states, idx: int = None): + def record(self, obs, actions, last_dones, last_rews, states, idx: int = None): if idx is not None: - self.episodes[i].record(obs, last_dones, last_rews, states) + self.episodes[idx].record(obs, actions, last_dones, last_rews, states) else: for i, episode in enumerate(self.episodes): - episode.record(obs[i], last_dones[i], last_rews[i], states[i]) + episode.record( + obs[i], actions[i], last_dones[i], last_rews[i], states[i] + ) def to_numpy(self) -> List[Dict[AgentID, Dict[str, np.ndarray]]]: res = [] diff --git a/tests/models/test_model_client.py b/tests/models/test_model_client.py new file mode 100644 index 00000000..24db57d6 --- /dev/null +++ b/tests/models/test_model_client.py @@ -0,0 +1,7 @@ +import pytest + +from malib.models.model_client import ModelClient + + +def test_model_client(): + pass diff --git a/tests/rl/test_policy_gradient.py b/tests/rl/test_policy_gradient.py new file mode 100644 index 00000000..7305f59d --- /dev/null +++ b/tests/rl/test_policy_gradient.py @@ -0,0 +1,7 @@ +import pytest + +from malib.rl.pg import PGPolicy, PGTrainer + + +def test_policy_update(): + pass diff --git a/tests/rollout/test_env_runner.py b/tests/rollout/test_env_runner.py index b56c83a7..25448b59 100644 --- a/tests/rollout/test_env_runner.py +++ b/tests/rollout/test_env_runner.py @@ -1,12 +1,14 @@ from typing import List, Dict, Any import pytest +import ray from malib.utils.typing import BehaviorMode from malib.common.strategy_spec import StrategySpec from malib.rollout.inference import env_runner from malib.rollout.inference.client import InferenceClient from malib.rollout.envs import mdp +from malib.rollout.rollout_config import RolloutConfig from malib.rl.random import RandomPolicy @@ -18,19 +20,18 @@ ) def test_env_runner(env_desc: Dict[str, Any], max_env_num: int): # mapping from agents to agents - agent_groups = dict(zip(env_desc["possible_agents"], env_desc["possible_agents"])) - runner = env_runner.EnvRunner( - env_desc, max_env_num, agent_groups, use_subproc_env=False - ) - - agents = env_desc["possible_agents"] - observation_spaces = env_desc["observation_spaces"] - action_spaces = env_desc["action_spaces"] - - inference_remote_cls = InferenceClient.as_remote(num_cpus=1) - rollout_config = { - "flag": "evaluation", - "strategy_specs": { + with ray.init(local_mode=True): + runner = env_runner.BasicEnvRunner( + lambda: env_desc["creator"](**env_desc["config"]), + max_env_num, + use_subproc_env=False, + ) + agents = env_desc["possible_agents"] + observation_spaces = env_desc["observation_spaces"] + action_spaces = env_desc["action_spaces"] + + inference_remote_cls = InferenceClient.as_remote(num_cpus=1) + strategy_specs = { agent: StrategySpec( policy_cls=RandomPolicy, observation_space=observation_spaces["default"], @@ -39,15 +40,29 @@ def test_env_runner(env_desc: Dict[str, Any], max_env_num: int): policy_ids=["policy-0"], ) for agent in agents - }, - "behavior_mode": BehaviorMode.EXPLOITATION, - } + } + rollout_config = RolloutConfig( + inference_server_type="ray", + num_workers=1, + eval_interval=1, + n_envs_per_worker=10, + use_subproc_env=False, + timelimit=256, + ) + + infer_clients = { + agent: inference_remote_cls.remote( + entry_point=None, + policy_cls=RandomPolicy, + observation_space=observation_spaces[agent], + action_space=action_spaces[agent], + model_config=None, + ) + for agent in agents + } - infer_clients = { - agent: inference_remote_cls.remote( - agent, observation_spaces[agent], action_spaces[agent] + stats = runner.run( + infer_clients, rollout_config, strategy_specs, data_entrypoint_mapping=None ) - for agent in agents - } - runner.run(infer_clients, rollout_config) + print(stats) From 47cd9174050c057bc846b6f77379a754a159d708 Mon Sep 17 00:00:00 2001 From: Ming Zhou Date: Thu, 9 Nov 2023 22:02:49 +0800 Subject: [PATCH 15/24] tmp save: refactor learner --- examples/run_gym.py | 2 +- examples/run_psro.py | 2 +- examples/sarl/ppo_gym.py | 2 +- malib/backend/league.py | 11 +- malib/common/manager.py | 8 +- malib/{agent => learner}/__init__.py | 2 +- .../async_learner.py} | 2 +- .../indepdent_learner.py} | 39 +--- .../agent_interface.py => learner/learner.py} | 143 ++++---------- malib/{agent => learner}/manager.py | 17 +- .../team_agent.py => learner/team_learner.py} | 2 +- malib/mocker/mocker_utils.py | 2 +- malib/rollout/inference/client.py | 67 ------- malib/rollout/inference/env_runner.py | 186 ++++-------------- malib/rollout/inference/manager.py | 43 ++++ malib/rollout/manager.py | 19 +- malib/rollout/pb_rolloutworker.py | 1 - malib/rollout/rolloutworker.py | 20 +- malib/scenarios/marl_scenario.py | 2 +- malib/scenarios/psro_scenario.py | 2 +- malib/scenarios/sarl_scenario.py | 46 +++-- tests/agents/test_independent_agent.py | 2 +- tests/agents/test_manager.py | 4 +- tests/rollout/test_env_runner.py | 6 +- 24 files changed, 200 insertions(+), 430 deletions(-) rename malib/{agent => learner}/__init__.py (95%) rename malib/{agent/async_agent.py => learner/async_learner.py} (95%) rename malib/{agent/indepdent_agent.py => learner/indepdent_learner.py} (63%) rename malib/{agent/agent_interface.py => learner/learner.py} (69%) rename malib/{agent => learner}/manager.py (95%) rename malib/{agent/team_agent.py => learner/team_learner.py} (98%) create mode 100644 malib/rollout/inference/manager.py diff --git a/examples/run_gym.py b/examples/run_gym.py index df904ba7..91f9b2e0 100644 --- a/examples/run_gym.py +++ b/examples/run_gym.py @@ -27,7 +27,7 @@ import time from malib.runner import run -from malib.agent import IndependentAgent +from malib.learner import IndependentAgent from malib.scenarios.marl_scenario import MARLScenario from malib.rl.dqn import DQNPolicy, DQNTrainer, DEFAULT_CONFIG from malib.rollout.envs.gym import env_desc_gen diff --git a/examples/run_psro.py b/examples/run_psro.py index 05eee90c..b9382c5f 100644 --- a/examples/run_psro.py +++ b/examples/run_psro.py @@ -27,7 +27,7 @@ import time from malib.runner import run -from malib.agent import IndependentAgent +from malib.learner import IndependentAgent from malib.scenarios.psro_scenario import PSROScenario from malib.rl.dqn import DQNPolicy, DQNTrainer, DEFAULT_CONFIG from malib.rollout.envs.open_spiel import env_desc_gen diff --git a/examples/sarl/ppo_gym.py b/examples/sarl/ppo_gym.py index 34fd2806..0a734e9f 100644 --- a/examples/sarl/ppo_gym.py +++ b/examples/sarl/ppo_gym.py @@ -3,7 +3,7 @@ from argparse import ArgumentParser -from malib.agent import IndependentAgent +from malib.learner import IndependentAgent from malib.scenarios.marl_scenario import MARLScenario from malib.runner import run diff --git a/malib/backend/league.py b/malib/backend/league.py index b52d08d8..d33af6f8 100644 --- a/malib/backend/league.py +++ b/malib/backend/league.py @@ -11,9 +11,15 @@ class League: - def __init__(self, training_manager: Manager, rollout_manager: Manager) -> None: + def __init__( + self, + training_manager: Manager, + rollout_manager: Manager, + inference_manager: Manager, + ) -> None: self.training_manager = training_manager self.rollout_manager = rollout_manager + self.inferenc_managfer = inference_manager self.flight_servers = [] self.rw_lock = rwlock.RWLockFair() self.event = threading.Event() @@ -39,6 +45,9 @@ def list_learners(self): def list_rollout_workers(self): return self.rollout_manager.workers() + def list_inference_clients(self): + return self.inferenc_managfer.workers() + def get_results(self) -> Dict[str, Dict[str, Any]]: """Retrieve results from rollout and training manager. diff --git a/malib/common/manager.py b/malib/common/manager.py index 8de93772..80d394f7 100644 --- a/malib/common/manager.py +++ b/malib/common/manager.py @@ -32,10 +32,11 @@ class Manager(ABC): @abstractmethod - def __init__(self, verbose: bool): + def __init__(self, verbose: bool, namespace: str): self._force_stop = False self.pending_tasks = [] self.verbose = verbose + self._namespace = namespace def is_running(self): return len(self.pending_tasks) > 0 @@ -43,6 +44,10 @@ def is_running(self): def force_stop(self): self._force_stop = True + @property + def namespace(self) -> str: + return self._namespace + @property def workers(self) -> List[RemoteInterface]: raise NotImplementedError @@ -79,7 +84,6 @@ def cancel_pending_tasks(self): return rets - @abstractmethod def terminate(self): """Resource recall""" diff --git a/malib/agent/__init__.py b/malib/learner/__init__.py similarity index 95% rename from malib/agent/__init__.py rename to malib/learner/__init__.py index 1a34c050..13e1cde8 100644 --- a/malib/agent/__init__.py +++ b/malib/learner/__init__.py @@ -22,4 +22,4 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -from .indepdent_agent import IndependentAgent +from .indepdent_learner import IndependentAgent diff --git a/malib/agent/async_agent.py b/malib/learner/async_learner.py similarity index 95% rename from malib/agent/async_agent.py rename to malib/learner/async_learner.py index 2b4fede0..6d5503bc 100644 --- a/malib/agent/async_agent.py +++ b/malib/learner/async_learner.py @@ -22,7 +22,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -from malib.agent.agent_interface import AgentInterface +from malib.learner.learner import AgentInterface class AsyncAgent(AgentInterface): diff --git a/malib/agent/indepdent_agent.py b/malib/learner/indepdent_learner.py similarity index 63% rename from malib/agent/indepdent_agent.py rename to malib/learner/indepdent_learner.py index 82352b80..18540916 100644 --- a/malib/agent/indepdent_agent.py +++ b/malib/learner/indepdent_learner.py @@ -22,42 +22,19 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -from typing import Dict, Tuple, Any, Callable, List, Union +from typing import Dict, Tuple, Any, Callable, List, Type, Union + +import gym + +from gym import spaces +from malib.backend.dataset_server.data_loader import DynamicDataset from malib.utils.typing import AgentID from malib.utils.tianshou_batch import Batch -from malib.agent.agent_interface import AgentInterface +from malib.learner.learner import Learner -class IndependentAgent(AgentInterface): - def __init__( - self, - experiment_tag: str, - runtime_id: str, - log_dir: str, - env_desc: Dict[str, Any], - algorithms: Dict[str, Tuple[Dict, Dict, Dict]], - agent_mapping_func: Callable[[AgentID], str], - governed_agents: Tuple[AgentID], - trainer_config: Dict[str, Any], - custom_config: Dict[str, Any] = None, - local_buffer_config: Dict = None, - verbose: bool = True, - ): - super().__init__( - experiment_tag, - runtime_id, - log_dir, - env_desc, - algorithms, - agent_mapping_func, - governed_agents, - trainer_config, - custom_config, - local_buffer_config, - verbose, - ) - +class IndependentAgent(Learner): def multiagent_post_process( self, batch_info: Union[ diff --git a/malib/agent/agent_interface.py b/malib/learner/learner.py similarity index 69% rename from malib/agent/agent_interface.py rename to malib/learner/learner.py index 3cac506f..da87e2f4 100644 --- a/malib/agent/agent_interface.py +++ b/malib/learner/learner.py @@ -27,9 +27,6 @@ from abc import ABC, abstractmethod from collections import deque -import os -import copy -import time import traceback import gym @@ -38,21 +35,20 @@ from ray.util.queue import Queue from torch.utils import tensorboard +from torch.utils.data import DataLoader -from malib import settings -from malib.backend.offline_dataset_server import OfflineDataset -from malib.backend.data_loader import RLDataLoader -from malib.backend.parameter_server import ParameterServer from malib.utils.typing import AgentID from malib.utils.logging import Logger from malib.utils.tianshou_batch import Batch from malib.utils.monitor import write_to_tensorboard from malib.remote.interface import RemoteInterface from malib.rl.common.trainer import Trainer +from malib.common.task import OptimizationTask from malib.common.strategy_spec import StrategySpec +from malib.backend.dataset_server.data_loader import DynamicDataset -class AgentInterface(RemoteInterface, ABC): +class Learner(RemoteInterface, ABC): """Base class of agent interface, for training""" @abstractmethod @@ -70,7 +66,7 @@ def __init__( custom_config: Dict[str, Any] = None, local_buffer_config: Dict = None, verbose: bool = True, - dataloader: RLDataLoader = None, + dataset: DynamicDataset = None, ): """Construct agent interface for training. @@ -99,20 +95,11 @@ def __init__( # initialize a strategy spec for policy maintainance. strategy_spec = StrategySpec( - identifier=runtime_id, - policy_ids=[], - meta_data={ - "policy_cls": algorithms["default"][0], - "experiment_tag": experiment_tag, - # for policy initialize - "kwargs": { - "observation_space": observation_space, - "action_space": action_space, - "model_config": algorithms["default"][2], - "custom_config": algorithms["default"][3], - "kwargs": {}, - }, - }, + policy_cls=algorithms["default"][0], + observation_space=observation_space, + action_space=action_space, + model_config=algorithms["default"][2], + **algorithms["default"][3], ) self._runtime_id = runtime_id @@ -130,12 +117,19 @@ def __init__( self._trainer: Trainer = algorithms["default"][1](trainer_config) self._policies = {} - self._offline_dataset: OfflineDataset = None - self._parameter_server: ParameterServer = None - self._dataloader = dataloader or self.create_dataloader() + dataset = dataset or self.create_dataset() + self._data_loader = DataLoader(dataset, batch_size=trainer_config["batch_size"]) self._active_tups = deque() - self.verbose = verbose + self._verbose = verbose + + @property + def verbose(self) -> bool: + return self._verbose + + @property + def data_loader(self) -> DataLoader: + return self._data_loader @property def governed_agents(self) -> Tuple[str]: @@ -157,48 +151,9 @@ def device(self) -> Union[str, torch.DeviceObjType]: return self._device - def create_dataloader(self) -> RLDataLoader: - """Create a data loader instance. - - Raises: - NotImplementedError: Raise if this method is not implemented. - - Returns: - RLDataLoader: A data loader instance. - """ - + def create_dataset(self) -> DynamicDataset: raise NotImplementedError - def connect( - self, - max_tries: int = 10, - dataset_server_ref: str = None, - parameter_server_ref: str = None, - ): - """Try to connect with backend, i.e., parameter server and offline dataset server. If the reference of dataset server or parameter server is not been given, then the agent will use default settings. - - Args: - max_tries (int, optional): Maximum of trails. Defaults to 10. - dataset_server_ref (str, optional): Name of ray-based dataset server. Defaults to None. - parameter_server_ref (str, optional): Name of ray-based parameter server. Defaults to None. - """ - - parameter_server_ref = parameter_server_ref or settings.PARAMETER_SERVER_ACTOR - dataset_server_ref = dataset_server_ref or settings.OFFLINE_DATASET_ACTOR - - while max_tries > 0: - try: - if self._parameter_server is None: - self._parameter_server = ray.get_actor(parameter_server_ref) - if self._offline_dataset is None: - self._offline_dataset = ray.get_actor(dataset_server_ref) - break - except Exception as e: - Logger.debug(f"{e}") - max_tries -= 1 - time.sleep(1) - continue - def add_policies(self, n: int) -> StrategySpec: """Construct `n` new policies and return the latest strategy spec. @@ -306,11 +261,7 @@ def sync_remote_parameters(self): ) ) - def train( - self, - data_request_identifier: str, - reset_state: bool = True, - ) -> Dict[str, Any]: + def train(self, task: OptimizationTask) -> Dict[str, Any]: """Executes a optimization task and returns the final interface state. Args: @@ -323,52 +274,28 @@ def train( # XXX(ming): why we need to reset the state here? I think it is not necessary as # an optimization task should be independent with other tasks. - if reset_state: - self.reset() - - reader_info_dict: Dict[str, Tuple[str, Queue]] = {} - assert len(self._active_tups) == 1, "the length of active tups can be only 1." self.set_running(True) try: while self.is_running(): - if data_request_identifier not in reader_info_dict: - reader_info_dict[data_request_identifier] = ray.get( - self._offline_dataset.start_consumer_pipe.remote( - name=data_request_identifier, - batch_size=self._trainer_config["batch_size"], + for data in self.data_loader: + batch_info = self.multiagent_post_process(data) + step_info_list = self._trainer(batch_info) + for step_info in step_info_list: + self._total_step += 1 + write_to_tensorboard( + self._summary_writer, + info=step_info, + global_step=self._total_step, + prefix=f"Training/{self._runtime_id}", ) - ) - reader_info: Tuple[str, Queue] = reader_info_dict[ - data_request_identifier - ] - - # XXX(ming): what if queue has been killed by remote server? - batch_info = reader_info[-1].get() - if len(batch_info[-1]) == 0: - continue - batch = self.multiagent_post_process(batch_info) - step_info_list = self._trainer(batch) - for step_info in step_info_list: - self._total_step += 1 - write_to_tensorboard( - self._summary_writer, - info=step_info, - global_step=self._total_step, - prefix=f"Training/{self._runtime_id}", - ) - self.sync_remote_parameters() - self._total_epoch += 1 - self._active_tups.popleft() + self.sync_remote_parameters() + self._total_epoch += 1 except Exception as e: Logger.warning( f"training pipe is terminated. caused by: {traceback.format_exc()}" ) - # close the data pipeline - ray.get( - self._offline_dataset.end_consumer_pipe.remote(data_request_identifier) - ) if self.verbose: Logger.info( diff --git a/malib/agent/manager.py b/malib/learner/manager.py similarity index 95% rename from malib/agent/manager.py rename to malib/learner/manager.py index 7eca9f3a..c51e18b8 100644 --- a/malib/agent/manager.py +++ b/malib/learner/manager.py @@ -46,7 +46,7 @@ from malib.utils.logging import Logger from malib.utils.exploitability import measure_exploitability from malib.remote.interface import RemoteInterface -from malib.agent.agent_interface import AgentInterface +from malib.learner.learner import Learner from malib.common.strategy_spec import StrategySpec from malib.common.manager import Manager from malib.common.training_config import TrainingConfig @@ -70,6 +70,7 @@ def __init__( log_dir: str, remote_mode: bool = True, resource_config: Dict[str, Any] = None, + ray_actor_namespace: str = "learner", verbose: bool = True, ): """Create an TrainingManager instance which is responsible for the multi agent training @@ -88,7 +89,7 @@ def __init__( remote_mode (bool, Optional): Init learners as remote actor or not. Default is True. """ - super().__init__(verbose=verbose) + super().__init__(verbose=verbose, namespace=ray_actor_namespace) resource_config = resource_config or DEFAULT_RESOURCE_CONFIG training_config = TrainingConfig.from_raw(training_config) @@ -109,7 +110,7 @@ def __init__( learner_cls = learner_cls.as_remote(**resource_config).options( max_concurrency=10 ) - learners: Dict[str, Union[AgentInterface, ray.ObjectRef]] = {} + learners: Dict[str, Union[Learner, ray.ObjectRef]] = {} assert ( "training" in stopping_conditions @@ -121,7 +122,8 @@ def __init__( experiment_tag=experiment_tag, runtime_id=rid, log_dir=f"{log_dir}/learner_{rid}", - env_desc=env_desc, + observation_space=group_info["observation_space"][rid], + action_space=group_info["action_space"][rid], algorithms=algorithms, agent_mapping_func=agent_mapping_func, governed_agents=tuple(agents), @@ -131,12 +133,13 @@ def __init__( ) # ensure all interfaces have been started up - if remote_mode: - _ = ray.get([x.connect.remote() for x in learners.values()]) + tasks = list(learners.values()) + while len(tasks): + _, tasks = ray.wait(tasks, num_returns=1, timeout=1) # TODO(ming): collect data entrypoints from learners self._group_info = group_info - self._runtime_ids = tuple(self._agent_groups.keys()) + self._runtime_ids = tuple(group_info["agent_groups"].keys()) self._experiment_tag = experiment_tag self._env_description = env_desc self._training_config = training_config diff --git a/malib/agent/team_agent.py b/malib/learner/team_learner.py similarity index 98% rename from malib/agent/team_agent.py rename to malib/learner/team_learner.py index e8c63660..68e45f8c 100644 --- a/malib/agent/team_agent.py +++ b/malib/learner/team_learner.py @@ -27,7 +27,7 @@ from malib.utils.typing import AgentID from malib.utils.tianshou_batch import Batch from malib.models.torch import make_net -from malib.agent.agent_interface import AgentInterface +from malib.learner.learner import AgentInterface class TeamAgent(AgentInterface): diff --git a/malib/mocker/mocker_utils.py b/malib/mocker/mocker_utils.py index 3db4deef..ff8407f8 100644 --- a/malib/mocker/mocker_utils.py +++ b/malib/mocker/mocker_utils.py @@ -214,7 +214,7 @@ def terminate(self): from typing import Union, Type from collections import defaultdict -from malib.agent.manager import TrainingManager +from malib.learner.manager import TrainingManager class FakeTrainingManager(TrainingManager): diff --git a/malib/rollout/inference/client.py b/malib/rollout/inference/client.py index 3cf50bf1..4892ecf3 100644 --- a/malib/rollout/inference/client.py +++ b/malib/rollout/inference/client.py @@ -139,73 +139,6 @@ def compute_action( policy_return = PolicyReturnWithObs(**_returns) return policy_return - def compute_action_with_frames( - self, dataframes: List[DataFrame], runtime_config: Dict[str, Any] - ) -> List[DataFrame]: - timer = Timing() - strategy_specs: Dict[AgentID, StrategySpec] = runtime_config["strategy_specs"] - return_dataframes: List[DataFrame] = [] - - assert len(dataframes) > 0 - - for dataframe in dataframes: - with timer.time_avg("others"): - agent_id = dataframe.identifier - spec = strategy_specs[agent_id] - batch_size = dataframe.meta_data["env_num"] - spec_policy_id = spec.sample() - policy_id = f"{spec.id}/{spec_policy_id}" - policy: Policy = self.policies[policy_id] - kwargs = { - Episode.DONE: dataframe.data[Episode.DONE], - Episode.ACTION_MASK: dataframe.data[Episode.ACTION_MASK], - "evaluate": dataframe.meta_data["evaluate"], - } - observation = dataframe.data[Episode.CUR_OBS] - kwargs[Episode.RNN_STATE] = _get_initial_states( - self, - None, - observation, - policy, - identifier=dataframe.identifier, - ) - - rets = {} - - with timer.time_avg("compute_action"): - ( - rets[Episode.ACTION], - rets[Episode.ACTION_LOGITS], - rets[Episode.ACTION_DIST], - rets[Episode.RNN_STATE], - ) = policy.compute_action( - observation=observation.reshape(batch_size, -1), **kwargs - ) - - # compute state value - with timer.time_avg("compute_value"): - rets[Episode.STATE_VALUE] = policy.value_function( - observation=observation, - action_dist=rets[Episode.ACTION_DIST].copy(), - **kwargs, - ) - - with timer.time_avg("tail_handler"): - for k, v in rets.items(): - if k == Episode.RNN_STATE: - continue - if len(v.shape) < 1: - rets[k] = v.reshape(-1) - elif v.shape[0] == 1: - continue - else: - rets[k] = v.reshape(batch_size, -1) - - return_dataframes.append( - DataFrame(identifier=agent_id, data=rets, meta_data=dataframe.meta_data) - ) - return return_dataframes - def _get_initial_states(self, client_id, observation, policy: Policy, identifier): if ( diff --git a/malib/rollout/inference/env_runner.py b/malib/rollout/inference/env_runner.py index 91ba976d..29c7bdbf 100644 --- a/malib/rollout/inference/env_runner.py +++ b/malib/rollout/inference/env_runner.py @@ -23,20 +23,11 @@ # SOFTWARE. from typing import Any, List, Dict, Tuple, Set, Type -from types import LambdaType -from collections import defaultdict -import os -import time -import traceback - -import pickle import ray import numpy as np -from ray.actor import ActorHandle - -from malib.utils.typing import AgentID, DataFrame, BehaviorMode +from malib.utils.typing import AgentID from malib.utils.episode import ConventionalEpisodeList from malib.utils.timing import Timing @@ -44,7 +35,6 @@ from malib.rollout.envs.vector_env import VectorEnv, SubprocVecEnv from malib.rollout.rollout_config import RolloutConfig from malib.rollout.inference.client import InferenceClient, PolicyReturnWithObs -from malib.rollout.inference.utils import process_env_rets, process_policy_outputs from malib.rollout.envs.env import Environment from malib.common.strategy_spec import StrategySpec from malib.backend.dataset_server.utils import send_data @@ -137,7 +127,12 @@ def __repr__(self) -> str: return super().__repr__() def __init__( - self, env_func: Type, max_env_num: int, use_subproc_env: bool = False + self, + env_func: Type, + max_env_num: int, + use_subproc_env: bool = False, + agent_groups: Dict[str, Set] = None, + inferenc_client_namespace: str = None, ) -> None: super().__init__() @@ -145,6 +140,13 @@ def __init__( self._max_env_num = max_env_num self._env_func = env_func self._envs = [] + self._agent_groups = agent_groups + self._infer_client_namespace = inferenc_client_namespace + self._inference_clients = None + + @property + def inference_clients(self) -> Dict[str, ray.ObjectRef]: + return self._inference_clients @property def envs(self) -> Tuple[Environment]: @@ -168,17 +170,17 @@ def max_env_num(self) -> int: def run( self, - inference_clients: Dict[AgentID, InferenceClient], rollout_config: RolloutConfig, strategy_specs: Dict[AgentID, StrategySpec], + inference_clients: Dict[AgentID, InferenceClient] = None, data_entrypoint_mapping: Dict[AgentID, str] = None, ): """Single thread env simulation stepping. Args: - inference_clients (Dict[AgentID, InferenceClient]): A dict of remote inference client. rollout_config (RolloutConfig): Rollout configuration, which specifies how many data pieces will rollout. strategy_specs (Dict[AgentID, StrategySpec]): A dict of strategy specs, which rules the behavior policy of each agent. + inference_clients (Dict[AgentID, InferenceClient]): A dict of remote inference client. data_entrypoint_mapping (Dict[AgentID, str], optional): A mapping which defines the data collection trigger, if not None, then return episodes. Defaults to None. Raises: @@ -188,6 +190,23 @@ def run( _type_: _description_ """ + if inference_clients is None: + assert ( + self._infer_client_namespace is not None + ), "Inference client namespace should be specified if infer_clients is not given." + assert ( + self._agent_groups is not None + ), "Agent groups should be specified if infer_clients is not given." + if self.inference_clients is None: + res = {} + for rid, _agents in self._agent_groups.items(): + client = ray.get_actor( + f"inference_{rid}", namespace=self._infer_client_namespace + ) + res.update(dict.fromkeys(_agents, client)) + self._inference_clients = res + inference_clients = self.inference_clients + new_env_num = max(0, rollout_config.n_envs_per_worker - self.num_active_envs) for _ in range(new_env_num): @@ -247,142 +266,3 @@ def run( data = agent_manager.merge_episodes() stats = {"total_timesteps": total_timestep, **timer.todict()} return stats - - -from malib.utils.episode import NewEpisodeList -from malib.utils.preprocessor import Preprocessor - - -def _env_runner( - client: InferenceClient, - agents: Dict[str, InferenceClient], - preprocessors: Dict[str, Preprocessor], - rollout_config: Dict[str, Any], - server_runtime_config: Dict[str, Any], - data_entrypoint_mapping: Dict[AgentID, str], -) -> Tuple[List[Dict[str, Any]], Dict[str, float]]: - """The main logic of environment stepping, also for data collections. - - Args: - client (InferenceClient): The inference client. - rollout_config (Dict[str, Any]): Rollout configuration. - server_runtime_config (Dict[str, Any]): A dict which gives the runtime configuration of inference server. Keys including - - - `preprocessor`: observation preprocessor. - - `behavior_mode`: a value of `BehaviorMode`. - - `strategy_spec`: a dict of strategy specs, mapping from runtime agent id to strategy spces. - - dwriter_info_dict (Dict[str, Tuple[str, Queue]], optional): A dict maps from runtime ids to a tuple of dataset writer info. Defaults to None. - - Raises: - e: General exceptions. - - Returns: - Tuple[List[Dict[str, Any]], Dict[str, float]]: A tuple of evaluation results and performance results. - """ - - # check whether remote server or not - evaluate_on = rollout_config["behavior_mode"] == BehaviorMode.EXPLOITATION - remote_actor = isinstance(list(agents.values())[0], ActorHandle) - - try: - if data_entrypoint_mapping is not None: - episodes = NewEpisodeList( - num=client.env.num_envs, agents=client.env.possible_agents - ) - else: - episodes = None - - with client.timer.timeit("environment_reset"): - env_rets = client.env.reset( - fragment_length=rollout_config["fragment_length"], - max_step=rollout_config["max_step"], - ) - - env_dones, processed_env_ret, dataframes = process_env_rets( - env_rets=env_rets, - preprocessors=preprocessors, - preset_meta_data={"evaluate": evaluate_on}, - ) - - # env ret is key first, not agent first: state, obs - if episodes is not None: - episodes.record( - processed_env_ret, agent_first=False, is_episode_done=env_dones - ) - - start = time.time() - cnt = 0 - - while not client.env.is_terminated(): - # group dataframes by runtime ids. - grouped_data_frames: Dict[str, List[DataFrame]] = defaultdict(lambda: []) - for agent, dataframe in dataframes.items(): - runtime_id = client.training_agent_mapping(agent) - grouped_data_frames[runtime_id].append(dataframe) - - with client.timer.time_avg("policy_step"): - if remote_actor: - policy_outputs: Dict[str, List[DataFrame]] = { - rid: ray.get( - agent.compute_action.remote( - grouped_data_frames[rid], - runtime_config=server_runtime_config, - ) - ) - for rid, agent in agents.items() - } - else: - policy_outputs: Dict[str, List[DataFrame]] = { - rid: agent.compute_action( - grouped_data_frames[rid], - runtime_config=server_runtime_config, - ) - for rid, agent in agents.items() - } - - with client.timer.time_avg("process_policy_output"): - # TODO(ming): do not use async stepping - env_actions, processed_policy_outputs = process_policy_outputs( - policy_outputs, client.env - ) - - if episodes is not None: - episodes.record( - processed_policy_outputs, - agent_first=True, - is_episode_done=env_dones, - ) - - with client.timer.time_avg("environment_step"): - env_rets = client.env.step(env_actions) - env_dones, processed_env_ret, dataframes = process_env_rets( - env_rets=env_rets, - preprocessor=server_runtime_config["preprocessor"], - preset_meta_data={"evaluate": evaluate_on}, - ) - # state, obs, rew, done - if episodes is not None: - episodes.record( - processed_env_ret, agent_first=False, is_episode_done=env_dones - ) - - cnt += 1 - - if data_entrypoint_mapping is not None: - # episode_id: agent_id: dict_data - episodes = episodes.to_numpy() - for entrypoint in data_entrypoint_mapping.values(): - send_data(pickle.dumps(episodes), entrypoint) - end = time.time() - rollout_info = client.env.collect_info() - except Exception as e: - traceback.print_exc() - raise e - - performance = client.timer.todict() - performance["FPS"] = client.env.batched_step_cnt / (end - start) - eval_results = rollout_info - performance["total_timesteps"] = client.env.batched_step_cnt - - return eval_results, performance diff --git a/malib/rollout/inference/manager.py b/malib/rollout/inference/manager.py new file mode 100644 index 00000000..bd73e8c1 --- /dev/null +++ b/malib/rollout/inference/manager.py @@ -0,0 +1,43 @@ +from typing import Dict, Set + +import ray + +from malib.common.manager import Manager +from malib.scenarios import Scenario +from malib.rollout.inference.client import InferenceClient + + +class InferenceManager(Manager): + def __init__( + self, + group_info: Dict[str, Set], + ray_actor_namespace: str, + entrypoints: Dict[str, str], + scenario: Scenario, + verbose: bool = False, + ): + super().__init__(verbose, namespace=ray_actor_namespace) + + inference_remote_cls = InferenceClient.as_remote(num_cpus=1).options( + namespace=self.namespace + ) + obs_spaces = group_info["observation_space"] + act_spaces = group_info["action_space"] + agent_groups = group_info["agent_groups"] + + self.infer_clients = {} + for rid, _ in agent_groups.items(): + self.infer_clients[rid] = inference_remote_cls.options( + name=f"inference_{rid}" + ).remote( + entry_point=entrypoints[rid], + policy_cls=scenario.algorithms[rid].policy_cls, + observation_space=obs_spaces[rid], + action_space=act_spaces[rid], + model_config=scenario.training_config["model_config"], + ) + + # check ready + tasks = list(self.infer_clients.values()) + while len(tasks): + _, tasks = ray.wait(tasks, num_returns=1, timeout=1) diff --git a/malib/rollout/manager.py b/malib/rollout/manager.py index 1a4ddbd7..e3a99274 100644 --- a/malib/rollout/manager.py +++ b/malib/rollout/manager.py @@ -39,6 +39,7 @@ from malib.common.manager import Manager from malib.remote.interface import RemoteInterface from malib.common.strategy_spec import StrategySpec +from malib.rollout.rollout_config import RolloutConfig from malib.rollout.pb_rolloutworker import PBRolloutWorker @@ -77,12 +78,12 @@ def __init__( experiment_tag: str, stopping_conditions: Dict[str, Any], num_worker: int, - agent_mapping_func: Callable, group_info: Dict[str, Any], - rollout_config: Dict[str, Any], + rollout_config: Union[RolloutConfig, Dict[str, Any]], env_desc: Dict[str, Any], log_dir: str, resource_config: Dict[str, Any] = None, + ray_actor_namespace: str = "rollout_worker", verbose: bool = True, ): """Construct a manager for multiple rollout workers. @@ -90,7 +91,6 @@ def __init__( Args: experiment_tag (str): Experiment tag. num_worker (int): Indicates how many rollout workers will be initialized. - agent_mapping_func (Callable): Agent mapping function, maps agents to runtime id. rollout_config (Dict[str, Any]): Runtime rollout configuration. env_desc (Dict[str, Any]): Environment description. log_dir (str): Log directory. @@ -98,20 +98,21 @@ def __init__( verbose (bool, optional): Enable logging or not. Defaults to True. """ - super().__init__(verbose=verbose) + super().__init__(verbose=verbose, namespace=ray_actor_namespace) rollout_worker_cls = PBRolloutWorker - worker_cls = rollout_worker_cls.as_remote(num_cpus=0, num_gpus=0) + worker_cls = rollout_worker_cls.as_remote(num_cpus=0, num_gpus=0).options( + namespace=self.namespace + ) workers = [] - for _ in range(num_worker): + for i in range(num_worker): workers.append( - worker_cls.options(max_concurrency=100).remote( + worker_cls.options(max_concurrency=100, name=f"actor_{i}").remote( experiment_tag=experiment_tag, env_desc=env_desc, - agent_mapping_func=agent_mapping_func, agent_groups=group_info["agent_groups"], - rollout_config=rollout_config, + rollout_config=RolloutConfig.from_raw(rollout_config), log_dir=log_dir, rollout_callback=None, simulate_callback=None, diff --git a/malib/rollout/pb_rolloutworker.py b/malib/rollout/pb_rolloutworker.py index 42ba620d..7433edc2 100644 --- a/malib/rollout/pb_rolloutworker.py +++ b/malib/rollout/pb_rolloutworker.py @@ -44,7 +44,6 @@ def step_rollout( ) -> List[Dict[str, Any]]: results = ray.get( self.env_runner.run.remote( - inference_clients=self.inference_clients, rollout_config=self.rollout_config, strategy_specs=strategy_specs, data_entrypoint_mapping=data_entrypoint_mapping, diff --git a/malib/rollout/rolloutworker.py b/malib/rollout/rolloutworker.py index c20df7fa..4284e29f 100644 --- a/malib/rollout/rolloutworker.py +++ b/malib/rollout/rolloutworker.py @@ -22,7 +22,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -from typing import Dict, Any, List, Callable, Sequence, Tuple, Set +from typing import Dict, Any, List, Callable, Sequence, Tuple, Set, Union from abc import abstractmethod from collections import defaultdict @@ -149,9 +149,8 @@ def __init__( self, experiment_tag: str, env_desc: Dict[str, Any], - agent_mapping_func: Callable, agent_groups: Dict[str, Set], - rollout_config: Dict[str, Any], + rollout_config: Union[RolloutConfig, Dict[str, Any]], log_dir: str, rollout_callback: Callable[[ray.ObjectRef, Dict[str, Any]], Any] = None, simulate_callback: Callable[[ray.ObjectRef, Dict[str, Any]], Any] = None, @@ -162,8 +161,6 @@ def __init__( Args: env_desc (Dict[str, Any]): The environment description. - agent_mapping_func (Callable): The agent mapping function, maps environment agents to runtime ids. \ - It is shared among all workers. rollout_config (Dict[str, Any]): Basic runtime configuration to control the rollout. Keys including * `fragment_length`: int, how many steps for each data collection and broadcasting. * `max_step`: int, the maximum step of each episode. @@ -197,12 +194,6 @@ def __init__( env_desc, env_runner_resource_config, self.rollout_config ) - # create inference clients, for action execution - inferenc_client_configuration = resource_config["inference_client"] - self.inference_clients: Dict[ - AgentID, ray.ObjectRef - ] = self.create_inference_clients(inferenc_client_configuration) - self.log_dir = log_dir self.rollout_callback = rollout_callback or default_rollout_callback self.simulate_callback = simulate_callback or default_simulate_callback @@ -210,9 +201,6 @@ def __init__( self.experiment_tag = experiment_tag self.verbose = verbose - def create_inference_clients(self) -> Dict[AgentID, ray.ObjectRef]: - raise NotImplementedError - def create_env_runner( self, env_desc: Dict[str, Any], @@ -232,9 +220,7 @@ def create_env_runner( ActorPool: An instance of `ActorPool`. """ - env_runner_cls = BasicEnvRunner.as_remote(**resource_config).options( - max_concurrency=100 - ) + env_runner_cls = BasicEnvRunner.as_remote(**resource_config) env_runner = env_runner_cls.remote( env_func=lambda: env_desc["creator"](**env_desc["config"]), max_env_num=rollout_config.n_envs_per_worker, diff --git a/malib/scenarios/marl_scenario.py b/malib/scenarios/marl_scenario.py index 3dc41b54..a49ceeb8 100644 --- a/malib/scenarios/marl_scenario.py +++ b/malib/scenarios/marl_scenario.py @@ -29,7 +29,7 @@ from malib.scenarios import Scenario from malib.utils.logging import Logger -from malib.agent.manager import TrainingManager +from malib.learner.manager import TrainingManager from malib.rollout.manager import RolloutWorkerManager diff --git a/malib/scenarios/psro_scenario.py b/malib/scenarios/psro_scenario.py index 5c5c7425..7febd5ad 100644 --- a/malib/scenarios/psro_scenario.py +++ b/malib/scenarios/psro_scenario.py @@ -32,7 +32,7 @@ from malib.utils.logging import Logger from malib.utils.stopping_conditions import get_stopper from malib.utils.exploitability import measure_exploitability -from malib.agent.manager import TrainingManager +from malib.learner.manager import TrainingManager from malib.rollout.manager import RolloutWorkerManager from malib.common.payoff_manager import PayoffManager from malib.common.strategy_spec import StrategySpec diff --git a/malib/scenarios/sarl_scenario.py b/malib/scenarios/sarl_scenario.py index 64a6da1f..808df037 100644 --- a/malib/scenarios/sarl_scenario.py +++ b/malib/scenarios/sarl_scenario.py @@ -23,14 +23,15 @@ # SOFTWARE. from typing import Dict, Any -from malib.common.task import OptimizationTask +from malib.common.task import OptimizationTask, RolloutTask from malib.scenarios import Scenario from malib.utils.logging import Logger from malib.backend.league import League -from malib.agent.manager import TrainingManager +from malib.learner.manager import TrainingManager from malib.rollout.manager import RolloutWorkerManager, TaskType +from malib.rollout.inference.manager import InferenceManager class SARLScenario(Scenario): @@ -76,6 +77,7 @@ def execution_plan(experiment_tag: str, scenario: SARLScenario, verbose: bool = log_dir=scenario.log_dir, remote_mode=True, resource_config=scenario.resource_config["training"], + ray_actor_namespace="learner_{}".format(experiment_tag), verbose=verbose, ) @@ -83,48 +85,52 @@ def execution_plan(experiment_tag: str, scenario: SARLScenario, verbose: bool = experiment_tag=experiment_tag, stopping_conditions=scenario.stopping_conditions, num_worker=scenario.num_worker, - agent_mapping_func=scenario.agent_mapping_func, group_info=scenario.group_info, rollout_config=scenario.rollout_config, env_desc=scenario.env_desc, log_dir=scenario.log_dir, resource_config=scenario.resource_config["rollout"], + ray_actor_namespace="rollout_{}".format(experiment_tag), verbose=verbose, ) - league = League(rollout_manager, training_manager) + inference_manager = InferenceManager( + group_info=scenario.group_info, + ray_actor_namespace="inference_{}".format(experiment_tag), + entrypoints=training_manager.get_data_entrypoints(), + scenario=scenario, + ) + + league = League(rollout_manager, training_manager, inference_manager) + # NOTE(ming): if all agents are active, the strategy specs should not contain any pids strategy_specs = training_manager.add_policies(n=1) Logger.info( f"Training manager was inistialized with a strategy spec:\n{strategy_specs}" ) optimization_task = OptimizationTask( - active_agents=None, + active_agents=scenario.env_desc["possible_agents"], stop_conditions=scenario.stopping_conditions["training"], ) training_manager.submit(optimization_task) - rollout_task = { - "num_workers": 1, - "runtime_strategy_specs": strategy_specs, - "data_entrypoints": training_manager.get_data_entrypoint_mapping(), - "rollout_config": scenario.rollout_config, - "active_agents": None, - } - evaluation_task = { - "num_workers": 1, - "runtime_strategy_specs": strategy_specs, - "rollout_config": getattr( - scenario, "evaluation_config", scenario.rollout_config - ), - } + rollout_task = RolloutTask( + task_type=TaskType.ROLLOUT, + strategy_specs=strategy_specs, + stopping_conditions=scenario.stopping_conditions["rollout"], + data_entrypoint_mapping=training_manager.get_data_entrypoint_mapping(), + ) + + evaluation_task = RolloutTask( + task_type=TaskType.EVALUATION, + strategy_specs=strategy_specs, + ) rollout_manager.submit(rollout_task) rollout_manager.submit(evaluation_task) results = league.get_results() - league.terminate() return results diff --git a/tests/agents/test_independent_agent.py b/tests/agents/test_independent_agent.py index 68ba2f3b..df0f8d79 100644 --- a/tests/agents/test_independent_agent.py +++ b/tests/agents/test_independent_agent.py @@ -32,7 +32,7 @@ from malib.utils.episode import Episode from malib.mocker.mocker_utils import use_ray_env from malib.rollout.envs.mdp import env_desc_gen -from malib.agent.indepdent_agent import IndependentAgent +from malib.learner.indepdent_learner import IndependentAgent def start_learner(env_id: str, algorithm: Any): diff --git a/tests/agents/test_manager.py b/tests/agents/test_manager.py index 1ada504d..6bb9369f 100644 --- a/tests/agents/test_manager.py +++ b/tests/agents/test_manager.py @@ -34,8 +34,8 @@ from malib.runner import start_servers from malib.utils.typing import AgentID -from malib.agent import IndependentAgent -from malib.agent.manager import TrainingManager +from malib.learner import IndependentAgent +from malib.learner.manager import TrainingManager from malib.rl.random import RandomPolicy, RandomTrainer, DEFAULT_CONFIG diff --git a/tests/rollout/test_env_runner.py b/tests/rollout/test_env_runner.py index 25448b59..b3c368db 100644 --- a/tests/rollout/test_env_runner.py +++ b/tests/rollout/test_env_runner.py @@ -3,7 +3,6 @@ import pytest import ray -from malib.utils.typing import BehaviorMode from malib.common.strategy_spec import StrategySpec from malib.rollout.inference import env_runner from malib.rollout.inference.client import InferenceClient @@ -62,7 +61,10 @@ def test_env_runner(env_desc: Dict[str, Any], max_env_num: int): } stats = runner.run( - infer_clients, rollout_config, strategy_specs, data_entrypoint_mapping=None + rollout_config, + strategy_specs, + inference_clients=infer_clients, + data_entrypoint_mapping=None, ) print(stats) From 8ddc7631c93660d3353c7e7b69f16d5afddb9655 Mon Sep 17 00:00:00 2001 From: Ming Zhou Date: Fri, 10 Nov 2023 22:05:21 +0800 Subject: [PATCH 16/24] pb rollout worker test passed --- examples/run_gym.py | 1 - malib/backend/dataset_server/data_loader.py | 6 + malib/backend/league.py | 35 ++--- malib/common/manager.py | 15 +- malib/common/task.py | 9 +- malib/learner/learner.py | 122 +++++---------- malib/learner/manager.py | 111 ++++++------- malib/models/model_client.py | 17 +- malib/remote/interface.py | 6 +- malib/rl/config.py | 16 ++ malib/rollout/envs/env.py | 25 ++- malib/rollout/envs/mdp/env.py | 15 +- malib/rollout/envs/random/__init__.py | 36 +++++ malib/rollout/envs/random/env.py | 62 ++++++++ malib/rollout/inference/client.py | 8 +- malib/rollout/inference/env_runner.py | 11 +- malib/rollout/inference/manager.py | 59 +++++-- malib/rollout/manager.py | 30 ++-- malib/rollout/pb_rolloutworker.py | 1 - malib/rollout/rollout_config.py | 4 +- malib/rollout/rolloutworker.py | 104 ++++-------- malib/runner.py | 101 ------------ malib/scenarios/sarl_scenario.py | 61 ++++---- malib/scenarios/scenario.py | 50 +++--- malib/utils/stopping_conditions.py | 24 +-- tests/rollout/test_env_runner.py | 3 +- tests/rollout/test_pb_rollout_worker.py | 165 +++++++++----------- 27 files changed, 525 insertions(+), 572 deletions(-) create mode 100644 malib/rl/config.py create mode 100644 malib/rollout/envs/random/__init__.py create mode 100644 malib/rollout/envs/random/env.py delete mode 100644 malib/runner.py diff --git a/examples/run_gym.py b/examples/run_gym.py index 91f9b2e0..1ad8c40c 100644 --- a/examples/run_gym.py +++ b/examples/run_gym.py @@ -26,7 +26,6 @@ import os import time -from malib.runner import run from malib.learner import IndependentAgent from malib.scenarios.marl_scenario import MARLScenario from malib.rl.dqn import DQNPolicy, DQNTrainer, DEFAULT_CONFIG diff --git a/malib/backend/dataset_server/data_loader.py b/malib/backend/dataset_server/data_loader.py index f151190d..9d774ffc 100644 --- a/malib/backend/dataset_server/data_loader.py +++ b/malib/backend/dataset_server/data_loader.py @@ -2,6 +2,7 @@ import threading import grpc +import socket from concurrent import futures from torch.utils.data import DataLoader, Dataset @@ -33,6 +34,7 @@ def __init__( max_message_length, find_free_port(), ) + self.host = socket.gethostbyname(socket.gethostbyname()) def _start_servicer( self, max_workers: int, max_message_length: int, grpc_port: int @@ -52,6 +54,10 @@ def _start_servicer( return server + @property + def entrypoint(self) -> str: + return f"{self.host}:{self.server._state.port}" + def __len__(self): return self.feature_handler_caller.block_size diff --git a/malib/backend/league.py b/malib/backend/league.py index d33af6f8..39d6e5ce 100644 --- a/malib/backend/league.py +++ b/malib/backend/league.py @@ -8,39 +8,34 @@ from malib.utils.logging import Logger from malib.common.manager import Manager +from malib.common.task import Task, RolloutTask, OptimizationTask class League: def __init__( self, - training_manager: Manager, + learner_manager: Manager, rollout_manager: Manager, inference_manager: Manager, ) -> None: - self.training_manager = training_manager + self.learner_manager = learner_manager self.rollout_manager = rollout_manager self.inferenc_managfer = inference_manager - self.flight_servers = [] self.rw_lock = rwlock.RWLockFair() self.event = threading.Event() self.thread_pool = futures.ThreadPoolExecutor() - def register_flight_server(self, flight_server_address: str): - raise NotImplementedError - - def list_flight_servers(self) -> List[str]: - raise NotImplementedError - - def _flight_server_check(self): - while not self.event.is_set(): - with self.rw_lock.gen_rlock(): - for flight_server in self.flight_servers: - if not ray.util.check_connection(flight_server): - self.flight_servers.remove(flight_server) - self.event.wait(10) - def list_learners(self): - return self.training_manager.workers() + return self.learner_manager.workers() + + def submit(self, task_desc: Task, wait: bool = False): + if isinstance(task_desc, RolloutTask): + res = self.rollout_manager.submit(task_desc, wait) + elif isinstance(task_desc, OptimizationTask): + res = self.learner_manager.submit(task_desc, wait) + else: + raise ValueError(f"Unexpected task type: {isinstance(task_desc)}") + return res def list_rollout_workers(self): return self.rollout_manager.workers() @@ -62,7 +57,7 @@ def get_results(self) -> Dict[str, Dict[str, Any]]: while True: for result in self.rollout_manager.get_results(): rollout_results.append(result) - for result in self.training_manager.get_results(): + for result in self.learner_manager.get_results(): training_results.append(result) except KeyboardInterrupt: Logger.info("Keyboard interruption was detected, recalling resources ...") @@ -76,5 +71,5 @@ def get_results(self) -> Dict[str, Dict[str, Any]]: def terminate(self): self.event.set() self.thread_pool.shutdown() - self.training_manager.terminate() + self.learner_manager.terminate() self.rollout_manager.terminate() diff --git a/malib/common/manager.py b/malib/common/manager.py index 80d394f7..87e4c0ca 100644 --- a/malib/common/manager.py +++ b/malib/common/manager.py @@ -20,8 +20,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -import traceback -from typing import List, Generator, Any +from typing import List, Any from abc import abstractmethod, ABC import ray @@ -52,8 +51,18 @@ def namespace(self) -> str: def workers(self) -> List[RemoteInterface]: raise NotImplementedError + @abstractmethod def retrive_results(self): - raise NotImplementedError + """Retrieve execution results.""" + + @abstractmethod + def submit(self, task: Any, wait: bool = False) -> Any: + """Submit task to workers. + + Args: + task (Any): Task description. + wait (bool, optional): Block or not. Defaults to False. + """ def wait(self) -> List[Any]: """Wait workers to be terminated, and retrieve the executed results. diff --git a/malib/common/task.py b/malib/common/task.py index e59df234..1155e967 100644 --- a/malib/common/task.py +++ b/malib/common/task.py @@ -12,9 +12,12 @@ class TaskType(IntEnum): OPTIMIZATION = 2 +class Task: + pass + + @dataclass -class RolloutTask: - task_type: int +class RolloutTask(Task): strategy_specs: Dict[str, Any] = field(default_factory=dict()) stopping_conditions: Dict[str, Any] = field(default_factory=dict()) data_entrypoint_mapping: Dict[str, Any] = field(default_factory=dict()) @@ -32,7 +35,7 @@ def from_raw( @dataclass -class OptimizationTask: +class OptimizationTask(Task): stop_conditions: Dict[str, Any] """stopping conditions for optimization task, e.g., max iteration, max time, etc.""" diff --git a/malib/learner/learner.py b/malib/learner/learner.py index da87e2f4..56fc4114 100644 --- a/malib/learner/learner.py +++ b/malib/learner/learner.py @@ -33,7 +33,7 @@ import torch import ray -from ray.util.queue import Queue +from gym import spaces from torch.utils import tensorboard from torch.utils.data import DataLoader @@ -42,10 +42,11 @@ from malib.utils.tianshou_batch import Batch from malib.utils.monitor import write_to_tensorboard from malib.remote.interface import RemoteInterface -from malib.rl.common.trainer import Trainer from malib.common.task import OptimizationTask from malib.common.strategy_spec import StrategySpec from malib.backend.dataset_server.data_loader import DynamicDataset +from malib.rl.common.trainer import Trainer +from malib.rl.common.policy import Policy class Learner(RemoteInterface, ABC): @@ -57,8 +58,8 @@ def __init__( experiment_tag: str, runtime_id: str, log_dir: str, - observation_space: gym.Space, - action_space: gym.Space, + observation_space: spaces.Space, + action_space: spaces.Space, algorithms: Dict[str, Tuple[Type, Type, Dict, Dict]], agent_mapping_func: Callable[[AgentID], str], governed_agents: Tuple[AgentID], @@ -109,24 +110,32 @@ def __init__( self._strategy_spec = strategy_spec self._agent_mapping_func = agent_mapping_func self._custom_config = custom_config + self._policy = strategy_spec.gen_policy(device=device) self._summary_writer = tensorboard.SummaryWriter(log_dir=log_dir) self._trainer_config = trainer_config + + # load policy for trainer + self._trainer: Trainer = algorithms["default"][1](trainer_config, self._policy) self._total_step = 0 self._total_epoch = 0 - self._trainer: Trainer = algorithms["default"][1](trainer_config) - self._policies = {} dataset = dataset or self.create_dataset() self._data_loader = DataLoader(dataset, batch_size=trainer_config["batch_size"]) - self._active_tups = deque() - self._verbose = verbose @property def verbose(self) -> bool: return self._verbose + @property + def strategy_spec(self) -> StrategySpec: + return self._strategy_spec + + @property + def policy(self) -> Policy: + return self._policy + @property def data_loader(self) -> DataLoader: return self._data_loader @@ -151,9 +160,28 @@ def device(self) -> Union[str, torch.DeviceObjType]: return self._device + @property + def trainer(self) -> Trainer: + return self._trainer + + def get_data_entrypoint(self) -> str: + return self.data_loader.dataset.entrypoint + + def get_strategy_spec(self) -> StrategySpec: + return self._strategy_spec + + def get_state_dict(self) -> Dict[str, torch.Tensor]: + return self.policy.state_dict(device="cpu") + + @abstractmethod def create_dataset(self) -> DynamicDataset: - raise NotImplementedError + """Create dataset + + Returns: + DynamicDataset: Must be an subinstance of DynamicDataset + """ + @abstractmethod def add_policies(self, n: int) -> StrategySpec: """Construct `n` new policies and return the latest strategy spec. @@ -164,61 +192,6 @@ def add_policies(self, n: int) -> StrategySpec: StrategySpec: The latest strategy spec instance. """ - for _ in range(n): - spec_pid = f"policy-{len(self._strategy_spec)}" - self._strategy_spec.register_policy_id(policy_id=spec_pid) - policy = self._strategy_spec.gen_policy() - policy_id = f"{self._strategy_spec.id}/{spec_pid}" - self._policies[policy_id] = policy - # active tups store the policy info tuple for training, the - # the data request relies on it. - self._active_tups.append((self._strategy_spec.id, spec_pid)) - self._trainer.reset(policy_instance=policy) - - ray.get(self._parameter_server.create_table.remote(self._strategy_spec)) - ray.get( - self._parameter_server.set_weights.remote( - spec_id=self._strategy_spec.id, - spec_policy_id=spec_pid, - state_dict=policy.state_dict(), - ) - ) - - return self._strategy_spec - - def push(self): - """Push local weights to remote server""" - - pending_tasks = [] - for spec_pid in self._strategy_spec.policy_ids: - pid = f"{self._strategy_spec.id}/{spec_pid}" - task = self._parameter_server.set_weights.remote( - spec_id=self._strategy_spec.id, - spec_policy_id=spec_pid, - state_dict=self._policies[pid].state_dict(), - ) - pending_tasks.append(task) - while len(pending_tasks) > 0: - dones, pending_tasks = ray.wait(pending_tasks) - - def pull(self): - """Pull remote weights to update local version.""" - - pending_tasks = [] - - for spec_pid in self._strategy_spec.policy_ids: - pid = f"{self._strategy_spec.id}/{spec_pid}" - task = self._parameter_server.get_weights.remote( - spec_id=self._strategy_spec.id, spec_policy_id=spec_pid - ) - pending_tasks.append(task) - - while len(pending_tasks) > 0: - dones, pending_tasks = ray.wait(pending_tasks) - for done in ray.get(dones): - pid = "{}/{}".format(done["spec_id"], done["spec_policy_id"]) - self._policies[pid].load_state_dict(done["weights"]) - @abstractmethod def multiagent_post_process( self, @@ -246,21 +219,8 @@ def get_interface_state(self) -> Dict[str, Any]: "total_step": self._total_step, "total_epoch": self._total_epoch, "policy_num": len(self._strategy_spec), - "active_tups": list(self._active_tups), } - def sync_remote_parameters(self): - """Push latest network parameters of active policies to remote parameter server.""" - - top_active_tup = self._active_tups[0] - ray.get( - self._parameter_server.set_weights.remote( - spec_id=top_active_tup[0], - spec_policy_id=top_active_tup[1], - state_dict=self._trainer.policy.state_dict(device="cpu"), - ) - ) - def train(self, task: OptimizationTask) -> Dict[str, Any]: """Executes a optimization task and returns the final interface state. @@ -272,25 +232,21 @@ def train(self, task: OptimizationTask) -> Dict[str, Any]: Dict[str, Any]: A dict that describes the final state. """ - # XXX(ming): why we need to reset the state here? I think it is not necessary as - # an optimization task should be independent with other tasks. - self.set_running(True) try: while self.is_running(): for data in self.data_loader: batch_info = self.multiagent_post_process(data) - step_info_list = self._trainer(batch_info) + step_info_list = self.trainer(batch_info) for step_info in step_info_list: self._total_step += 1 write_to_tensorboard( self._summary_writer, info=step_info, global_step=self._total_step, - prefix=f"Training/{self._runtime_id}", + prefix=f"Learner/{self._runtime_id}", ) - self.sync_remote_parameters() self._total_epoch += 1 except Exception as e: Logger.warning( diff --git a/malib/learner/manager.py b/malib/learner/manager.py index c51e18b8..ea15b225 100644 --- a/malib/learner/manager.py +++ b/malib/learner/manager.py @@ -53,11 +53,11 @@ DEFAULT_RESOURCE_CONFIG = dict( - num_cpus=None, num_gpus=None, memory=None, object_store_memory=None, resources=None + num_cpus=None, num_gpus=None, memory=None, resources=None ) -class TrainingManager(Manager): +class LearnerManager(Manager): def __init__( self, experiment_tag: str, @@ -68,12 +68,11 @@ def __init__( group_info: Dict[str, Any], training_config: Union[Dict[str, Any], TrainingConfig], log_dir: str, - remote_mode: bool = True, resource_config: Dict[str, Any] = None, ray_actor_namespace: str = "learner", verbose: bool = True, ): - """Create an TrainingManager instance which is responsible for the multi agent training + """Create an LearnerManager instance which is responsible for the multi agent training tasks execution and rollout task requests sending. Args: @@ -86,7 +85,6 @@ def __init__( training_config (Dict[str, Any]): Training configuration, for agent interface, keys include \ `type`, `trainer_config` and `custom_config`. log_dir (str): Directory for logging. - remote_mode (bool, Optional): Init learners as remote actor or not. Default is True. """ super().__init__(verbose=verbose, namespace=ray_actor_namespace) @@ -96,7 +94,7 @@ def __init__( # interface config give the agent type used here and the group mapping if needed - # FIXME(ming): resource configuration is not available now, will open in the next version + # FIXME(ming): resource configuration is not available now, will turn-on in the next version if training_config.trainer_config.get("use_cuda", False): num_gpus = 1 / len(group_info["agent_groups"]) else: @@ -107,18 +105,19 @@ def __init__( learner_cls = training_config.learner_type # update num gpus resource_config["num_gpus"] = num_gpus - learner_cls = learner_cls.as_remote(**resource_config).options( - max_concurrency=10 - ) - learners: Dict[str, Union[Learner, ray.ObjectRef]] = {} + learner_cls = learner_cls.as_remote(**resource_config) + learners: Dict[str, ray.ObjectRef] = {} assert ( "training" in stopping_conditions ), f"Stopping conditions should contains `training` stoppong conditions: {stopping_conditions}" + ready_check = [] + for rid, agents in group_info["agent_groups"].items(): - _cls = learner_cls.remote if remote_mode else learner_cls - learners[rid] = _cls( + learners[rid] = learner_cls.options( + name=f"learner_{rid}", max_concurrency=10, namespace=self.namespace + ).remote( experiment_tag=experiment_tag, runtime_id=rid, log_dir=f"{log_dir}/learner_{rid}", @@ -131,11 +130,22 @@ def __init__( custom_config=training_config.custom_config, verbose=verbose, ) + ready_check.append(learners[rid].ready.remote()) # ensure all interfaces have been started up - tasks = list(learners.values()) - while len(tasks): - _, tasks = ray.wait(tasks, num_returns=1, timeout=1) + while len(ready_check): + _, ready_check = ray.wait(ready_check, num_returns=1, timeout=1) + + data_entrypoints = ray.get( + [x.get_data_entrypoint.remote() for x in learners.values()] + ) + self._data_entrypoints = dict(zip(learners.keys(), data_entrypoints)) + self._learner_entrypoints = dict( + zip( + learners.keys(), + [f"{self.namespace}:learner_{rid}" for rid in learners.keys()], + ) + ) # TODO(ming): collect data entrypoints from learners self._group_info = group_info @@ -146,7 +156,6 @@ def __init__( self._log_dir = log_dir self._agent_mapping_func = agent_mapping_func self._learners = learners - self._remote_mode = remote_mode self._thread_pool = ThreadPoolExecutor(max_workers=len(learners)) self._stopping_conditions = stopping_conditions @@ -165,14 +174,18 @@ def agent_groups(self) -> Dict[str, Set[AgentID]]: return self._group_info["agent_groups"] @property - def get_data_entrypoints(self) -> Dict[str, str]: + def data_entrypoints(self) -> Dict[str, str]: """Return a dict of data entrypoints, maps from runtime ids to data entrypoints. Returns: Dict[str, str]: A dict of data entrypoints. """ - return {rid: rid for rid in self._runtime_ids} + return self._data_entrypoints + + @property + def learner_entrypoints(self) -> Dict[str, str]: + return self._learner_entrypoints @property def workers(self) -> List[RemoteInterface]: @@ -194,9 +207,6 @@ def runtime_ids(self) -> Tuple[str]: return self._runtime_ids - def get_data_entrypoint_mapping(self) -> Dict[AgentID, str]: - raise NotImplementedError - def add_policies( self, interface_ids: Sequence[str] = None, n: Union[int, Dict[str, int]] = 1 ) -> Dict[str, Type[StrategySpec]]: @@ -217,21 +227,15 @@ def add_policies( policy_nums = dict.fromkeys(interface_ids, n) if isinstance(n, int) else n - if self._remote_mode: - strategy_spec_list: List[StrategySpec] = ray.get( - [ - self._learners[k].add_policies.remote(n=policy_nums[k]) - for k in interface_ids - ] - ) - strategy_spec_dict: Dict[str, StrategySpec] = dict( - zip(interface_ids, strategy_spec_list) - ) - else: - strategy_spec_dict = { - k: self._learners[k].add_policies(n=policy_nums[k]) + strategy_spec_list: List[StrategySpec] = ray.get( + [ + self._learners[k].add_policies.remote(n=policy_nums[k]) for k in interface_ids - } + ] + ) + strategy_spec_dict: Dict[str, StrategySpec] = dict( + zip(interface_ids, strategy_spec_list) + ) return strategy_spec_dict @@ -249,11 +253,8 @@ def submit(self, task: OptimizationTask): raise RuntimeError(f"Agent {aid} is not registered in training manager") else: learner = self._learners[rid] - if self._remote_mode: - ray_task = learner.train.remote(task) - self.pending_tasks.append(ray_task) - else: - raise NotImplementedError + ray_task = learner.train.remote(task) + self.pending_tasks.append(ray_task) def retrive_results(self) -> Generator: """Return a generator of results. @@ -262,36 +263,18 @@ def retrive_results(self) -> Generator: Generator: A generator for task results. """ - if self._remote_mode: - while len(self.pending_tasks) > 0: - dones, self.pending_tasks = ray.wait(self.pending_tasks) - for done in ray.get(dones): - yield done - else: - for task in self.pending_tasks: - assert isinstance(task, Future) - try: - if task.done(): - yield task.result(timeout=10) - except TimeoutError: - Logger.error( - f"Retrieving results of training task is timeout: {traceback.format_exc()}" - ) - except CancelledError: - Logger.error( - f"Try to retrieve results of a cancelled task: {traceback.format_exc()}" - ) - except Exception: - Logger.error(traceback.format_exc()) + while len(self.pending_tasks): + dones, self.pending_tasks = ray.wait(self.pending_tasks) + for done in ray.get(dones): + yield done def terminate(self) -> None: """Terminate all training actors.""" super().terminate() - if self._remote_mode: - for x in self._learners.values(): - ray.kill(x) + for x in self._learners.values(): + ray.kill(x) self._thread_pool.shutdown() del self._learners diff --git a/malib/models/model_client.py b/malib/models/model_client.py index 083b7092..e239fd45 100644 --- a/malib/models/model_client.py +++ b/malib/models/model_client.py @@ -31,15 +31,9 @@ def __init__(self, entry_point: str, model_config: ModelConfig): NotImplementedError: Unsupported cluster type. """ - cluster_type, name_or_address = entry_point.split(":") + namespace, name = entry_point.split(":") - if "ray" in cluster_type: - self.client = ray.get_actor(name_or_address) - else: - raise NotImplementedError - - self.cluster_type = cluster_type - self.server_address = name_or_address + self.client = ray.get_actor(name=name, namespace=namespace) self.thread_pool = futures.ThreadPoolExecutor(max_workers=10) self.event = threading.Event() @@ -59,10 +53,11 @@ def critic(self, *args, **kwargs): def _model_update(self, event: threading.Event): while not event.is_set(): - # TODO(ming): update model from remote server try: - state_dict = load_state_dict(self.client) - + state_dict = load_state_dict( + ray.get(self.client.get_state_dict.remote(), timeout=10) + ) + self.model.load_state_dict(state_dict) event.wait(0.5) except TimeoutError: # TODO(ming): count or reconnect diff --git a/malib/remote/interface.py b/malib/remote/interface.py index 43828237..4664ca4c 100644 --- a/malib/remote/interface.py +++ b/malib/remote/interface.py @@ -39,7 +39,6 @@ def as_remote( num_cpus: int = None, num_gpus: int = None, memory: int = None, - object_store_memory: int = None, resources: dict = None, ) -> type: """Return a remote class for Actor initialization""" @@ -48,10 +47,13 @@ def as_remote( num_cpus=num_cpus, num_gpus=num_gpus, memory=memory, - object_store_memory=object_store_memory, resources=resources, )(cls) + def ready(self): + """For initialization checking. Always return True.""" + return True + def stop_pending_tasks(self): """External object can call this method to stop all pending tasks.""" diff --git a/malib/rl/config.py b/malib/rl/config.py new file mode 100644 index 00000000..03e4c6d4 --- /dev/null +++ b/malib/rl/config.py @@ -0,0 +1,16 @@ +from typing import Dict, Any + +from dataclasses import dataclass + +from malib.rl.common.policy import Policy +from malib.rl.common.trainer import Trainer + + +@dataclass +class Algorithm: + + policy: Policy + + trainer: Trainer + + model_config: Dict[str, Any] diff --git a/malib/rollout/envs/env.py b/malib/rollout/envs/env.py index 2325322c..81492a77 100644 --- a/malib/rollout/envs/env.py +++ b/malib/rollout/envs/env.py @@ -22,6 +22,7 @@ from typing import Dict, List, Any, Union, Tuple, Sequence +import copy import uuid import gym import numpy as np @@ -55,6 +56,9 @@ def __init__(self, **configs): self._current_players = [] self._state: Dict[str, np.ndarray] = None self._deactivated = True + self._agents = self.register_agents() + self._observation_spaces = self.register_observation_spaces() + self._action_spaces = self.register_action_spaces() def record_episode_info_step( self, @@ -87,27 +91,40 @@ def record_episode_info_step( self.episode_metrics["episode_reward"] += sum(rewards.values()) @property - def possible_agents(self) -> List[AgentID]: + def configs(self): + return copy.deepcopy(self._configs) + + @property + def possible_agents(self) -> Tuple[AgentID]: """Return a list of environment agent ids""" - raise NotImplementedError + return tuple(self._agents) @property def observation_spaces(self) -> Dict[AgentID, gym.Space]: """A dict of agent observation spaces""" - raise NotImplementedError + return self._observation_spaces @property def action_spaces(self) -> Dict[AgentID, gym.Space]: """A dict of agent action spaces""" - raise NotImplementedError + return self._action_spaces @property def is_deactivated(self) -> bool: return self._deactivated + def register_observation_spaces(self): + raise NotImplementedError + + def register_action_spaces(self): + raise NotImplementedError + + def register_agents(self): + raise NotImplementedError + def deactivate(self): self._deactivated = True diff --git a/malib/rollout/envs/mdp/env.py b/malib/rollout/envs/mdp/env.py index 6b8f7fb3..e97c99ee 100644 --- a/malib/rollout/envs/mdp/env.py +++ b/malib/rollout/envs/mdp/env.py @@ -9,7 +9,6 @@ class MDPEnvironment(Environment): def __init__(self, **configs): - super().__init__(**configs) try: from blackhc import mdp @@ -35,18 +34,16 @@ def __init__(self, **configs): ) self.env = scenarios[env_id]().to_env() - self._possible_agents = ["default"] - @property - def possible_agents(self) -> List[AgentID]: - return self._possible_agents + super().__init__(**configs) + + def register_agents(self): + return ["default"] - @property - def observation_spaces(self) -> Dict[AgentID, gym.Space]: + def register_observation_spaces(self): return dict.fromkeys(self.possible_agents, self.env.observation_space) - @property - def action_spaces(self) -> Dict[AgentID, gym.Space]: + def register_action_spaces(self): return dict.fromkeys(self.possible_agents, self.env.action_space) def time_step( diff --git a/malib/rollout/envs/random/__init__.py b/malib/rollout/envs/random/__init__.py new file mode 100644 index 00000000..92e3686e --- /dev/null +++ b/malib/rollout/envs/random/__init__.py @@ -0,0 +1,36 @@ +# MIT License + +# Copyright (c) 2021 MARL @ SJTU + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from .env import RandomEnv + + +def env_desc_gen(**config): + env = RandomEnv(**config) + env_desc = { + "creator": RandomEnv, + "possible_agents": env.possible_agents, + "action_spaces": env.action_spaces, + "observation_spaces": env.observation_spaces, + "config": config, + } + env.close() + return env_desc diff --git a/malib/rollout/envs/random/env.py b/malib/rollout/envs/random/env.py new file mode 100644 index 00000000..72ec8aa3 --- /dev/null +++ b/malib/rollout/envs/random/env.py @@ -0,0 +1,62 @@ +from typing import Any, Dict, List, Sequence, Tuple, Union + +import gym +import random + +from gym import spaces + +from malib.rollout.envs.env import Environment +from malib.utils.typing import AgentID + + +class RandomEnv(Environment): + def __init__(self, **configs): + assert "num_agents" in configs + super().__init__(**configs) + + def register_agents(self): + return {f"agent_{i}" for i in range(self.configs["num_agents"])} + + def register_observation_spaces(self): + return { + agent: spaces.Box(low=-1, high=1, shape=(2,)) + for agent in self.possible_agents + } + + def register_action_spaces(self): + return {agent: spaces.Discrete(4) for agent in self.possible_agents} + + def get_state(self) -> Any: + return None + + def reset(self, max_step: int = None): + super().reset(max_step) + obs = {k: v.sample() for k, v in self.observation_spaces.items()} + return self.get_state(), obs + + def time_step( + self, actions: Dict[AgentID, Any] + ) -> Tuple[ + Dict[AgentID, Any], + Dict[AgentID, float], + Dict[AgentID, bool], + Dict[AgentID, Any], + ]: + # assert action whether in space + for k, v in actions.items(): + _space = self.action_spaces[k] + assert _space.contains(v), (k, v, _space) + obs = {k: v.sample() for k, v in self.observation_spaces.items()} + rews = {k: random.random() for k in self.possible_agents} + state = self.get_state() + + return ( + state, + obs, + rews, + {k: False for k in self.possible_agents}, + {k: {} for k in self.possible_agents}, + ) + + def close(self): + pass diff --git a/malib/rollout/inference/client.py b/malib/rollout/inference/client.py index 4892ecf3..82715e71 100644 --- a/malib/rollout/inference/client.py +++ b/malib/rollout/inference/client.py @@ -34,10 +34,6 @@ import numpy as np from malib.remote.interface import RemoteInterface -from malib.utils.typing import AgentID, DataFrame -from malib.utils.timing import Timing -from malib.utils.episode import Episode -from malib.common.strategy_spec import StrategySpec from malib.models.config import ModelConfig from malib.rl.common.policy import Policy, PolicyReturn @@ -50,7 +46,7 @@ class InferenceClient(RemoteInterface): def __init__( self, - entry_point: str, + model_entry_point: str, policy_cls: Type, observation_space: gym.Space, action_space: gym.Space, @@ -74,7 +70,7 @@ def __init__( observation_space, action_space, model_config, - model_entry_point=entry_point, + model_entry_point=model_entry_point, ) def shutdown(self): diff --git a/malib/rollout/inference/env_runner.py b/malib/rollout/inference/env_runner.py index 29c7bdbf..bc28882e 100644 --- a/malib/rollout/inference/env_runner.py +++ b/malib/rollout/inference/env_runner.py @@ -132,7 +132,7 @@ def __init__( max_env_num: int, use_subproc_env: bool = False, agent_groups: Dict[str, Set] = None, - inferenc_client_namespace: str = None, + inference_entry_points: Dict[str, str] = None, ) -> None: super().__init__() @@ -141,7 +141,7 @@ def __init__( self._env_func = env_func self._envs = [] self._agent_groups = agent_groups - self._infer_client_namespace = inferenc_client_namespace + self._inference_entry_points = inference_entry_points self._inference_clients = None @property @@ -192,7 +192,7 @@ def run( if inference_clients is None: assert ( - self._infer_client_namespace is not None + self._inference_entry_points is not None ), "Inference client namespace should be specified if infer_clients is not given." assert ( self._agent_groups is not None @@ -200,9 +200,8 @@ def run( if self.inference_clients is None: res = {} for rid, _agents in self._agent_groups.items(): - client = ray.get_actor( - f"inference_{rid}", namespace=self._infer_client_namespace - ) + namespace, name = self._inference_entry_points[rid].split(":") + client = ray.get_actor(name=name, namespace=namespace) res.update(dict.fromkeys(_agents, client)) self._inference_clients = res inference_clients = self.inference_clients diff --git a/malib/rollout/inference/manager.py b/malib/rollout/inference/manager.py index bd73e8c1..edac1189 100644 --- a/malib/rollout/inference/manager.py +++ b/malib/rollout/inference/manager.py @@ -1,8 +1,9 @@ -from typing import Dict, Set +from typing import Any, Dict, Set import ray from malib.common.manager import Manager +from malib.rl.config import Algorithm from malib.scenarios import Scenario from malib.rollout.inference.client import InferenceClient @@ -12,32 +13,60 @@ def __init__( self, group_info: Dict[str, Set], ray_actor_namespace: str, - entrypoints: Dict[str, str], - scenario: Scenario, + model_entry_point: Dict[str, str], + algorithm: Algorithm, verbose: bool = False, ): super().__init__(verbose, namespace=ray_actor_namespace) - inference_remote_cls = InferenceClient.as_remote(num_cpus=1).options( - namespace=self.namespace - ) + inference_remote_cls = InferenceClient.as_remote(num_cpus=1) obs_spaces = group_info["observation_space"] act_spaces = group_info["action_space"] agent_groups = group_info["agent_groups"] - self.infer_clients = {} + self._infer_clients = {} + self._inference_entry_points = {} + # FIXME(Ming): for debug only + model_entry_point = model_entry_point or { + rid: None for rid in agent_groups.keys() + } + + infer_client_ready_check = [] for rid, _ in agent_groups.items(): - self.infer_clients[rid] = inference_remote_cls.options( - name=f"inference_{rid}" + actor_name = f"inference_{rid}" + self._infer_clients[rid] = inference_remote_cls.options( + namespace=self.namespace, name=actor_name ).remote( - entry_point=entrypoints[rid], - policy_cls=scenario.algorithms[rid].policy_cls, + model_entry_point=model_entry_point[rid], + policy_cls=algorithm.policy, observation_space=obs_spaces[rid], action_space=act_spaces[rid], - model_config=scenario.training_config["model_config"], + model_config=algorithm.model_config, + ) + infer_client_ready_check.append(self._infer_clients[rid].ready.remote()) + self._inference_entry_points[rid] = "{}:{}".format( + self.namespace, actor_name ) # check ready - tasks = list(self.infer_clients.values()) - while len(tasks): - _, tasks = ray.wait(tasks, num_returns=1, timeout=1) + while len(infer_client_ready_check): + _, infer_client_ready_check = ray.wait( + infer_client_ready_check, num_returns=1, timeout=1 + ) + + def get_inference_client(self, runtime_id: str) -> InferenceClient: + return self.inference_clients[runtime_id] + + @property + def inference_clients(self) -> Dict[str, ray.ObjectRef]: + return self._infer_clients + + @property + def inference_entry_points(self) -> str: + return self._inference_entry_points + + def submit(self, task: Any, wait: bool = False): + pass + + def retrive_results(self): + pass diff --git a/malib/rollout/manager.py b/malib/rollout/manager.py index e3a99274..e669a426 100644 --- a/malib/rollout/manager.py +++ b/malib/rollout/manager.py @@ -75,7 +75,6 @@ def validate_strategy_specs(specs: Dict[str, StrategySpec]): class RolloutWorkerManager(Manager): def __init__( self, - experiment_tag: str, stopping_conditions: Dict[str, Any], num_worker: int, group_info: Dict[str, Any], @@ -89,7 +88,6 @@ def __init__( """Construct a manager for multiple rollout workers. Args: - experiment_tag (str): Experiment tag. num_worker (int): Indicates how many rollout workers will be initialized. rollout_config (Dict[str, Any]): Runtime rollout configuration. env_desc (Dict[str, Any]): Environment description. @@ -101,15 +99,14 @@ def __init__( super().__init__(verbose=verbose, namespace=ray_actor_namespace) rollout_worker_cls = PBRolloutWorker - worker_cls = rollout_worker_cls.as_remote(num_cpus=0, num_gpus=0).options( - namespace=self.namespace - ) + worker_cls = rollout_worker_cls.as_remote(num_cpus=0, num_gpus=0).options() workers = [] - + ready_check = [] for i in range(num_worker): workers.append( - worker_cls.options(max_concurrency=100, name=f"actor_{i}").remote( - experiment_tag=experiment_tag, + worker_cls.options( + max_concurrency=100, namespace=self.namespace, name=f"actor_{i}" + ).remote( env_desc=env_desc, agent_groups=group_info["agent_groups"], rollout_config=RolloutConfig.from_raw(rollout_config), @@ -120,12 +117,15 @@ def __init__( verbose=verbose, ) ) + ready_check.append(workers[-1].ready.remote()) + + while len(ready_check): + _, ready_check = ray.wait(ready_check, num_returns=1, timeout=1) - self._workers: List[ray.actor] = workers + self._workers: List[ray.ObjectRef] = workers self._actor_pool = ActorPool(self._workers) self._runtime_ids = tuple(group_info["agent_groups"].keys()) self._group_info = group_info - self.experiment_tag = experiment_tag assert ( "rollout" in stopping_conditions @@ -163,7 +163,9 @@ def workers(self) -> List[RemoteInterface]: return self._workers - def submit(self, task: Union[Dict[str, Any], List[Dict[str, Any]]]): + def submit( + self, task: Union[Dict[str, Any], List[Dict[str, Any]]], wait: bool = False + ) -> Any: """Submit a task to workers Args: @@ -180,6 +182,12 @@ def submit(self, task: Union[Dict[str, Any], List[Dict[str, Any]]]): validate_strategy_specs(_task.strategy_specs) self._actor_pool.submit(lambda actor, _task: actor.rollout.remote(_task)) + if wait: + result_list = self.wait() + return result_list + else: + return None + def retrive_results(self): """Retrieve task results diff --git a/malib/rollout/pb_rolloutworker.py b/malib/rollout/pb_rolloutworker.py index 7433edc2..deb15a83 100644 --- a/malib/rollout/pb_rolloutworker.py +++ b/malib/rollout/pb_rolloutworker.py @@ -49,7 +49,6 @@ def step_rollout( data_entrypoint_mapping=data_entrypoint_mapping, ) ) - # check evaluation info parsed_results = parse_rollout_info(results) Logger.debug(f"parsed results: {parsed_results}") diff --git a/malib/rollout/rollout_config.py b/malib/rollout/rollout_config.py index c569376c..4e462d6b 100644 --- a/malib/rollout/rollout_config.py +++ b/malib/rollout/rollout_config.py @@ -5,8 +5,6 @@ @dataclass class RolloutConfig: - inference_server_type: str - """Inference server type""" num_workers: int = 1 """Defines how many workers will be used for executing one rollout task, default is 1""" @@ -23,6 +21,8 @@ class RolloutConfig: timelimit: int = 256 """Specifying how many time steps will be collected for each rollout, default is 256""" + inference_entry_points: Dict[str, str] = field(default_factory=dict) + @classmethod def from_raw( cls, config: Union["RolloutConfig", Dict[str, Any]] diff --git a/malib/rollout/rolloutworker.py b/malib/rollout/rolloutworker.py index 4284e29f..aa35ac44 100644 --- a/malib/rollout/rolloutworker.py +++ b/malib/rollout/rolloutworker.py @@ -76,49 +76,18 @@ def parse_rollout_info(raw_statistics: List[Dict[str, Any]]) -> Dict[str, Any]: Dict[str, Any]: A merged dict. """ - results = {"total_timesteps": 0, "FPS": 0.0} - evaluation = [] - - for e in raw_statistics: - # when task mode is `simualtion` or `evaluation`, then - # evaluation result is not empty. - if "evaluation" in e: - evaluation.extend(e["evaluation"]) - - for k, v in e.items(): - if k == "total_timesteps": - results[k] += v - elif k == "FPS": - results[k] += v - # else: - # raise ValueError(f"Unknow key: {k} / {v}") - - if len(evaluation) > 0: - raw_eval_results = defaultdict(lambda: []) - for e in evaluation: - for k, v in e.items(): - if isinstance(v, (Tuple, List)): - v = sum(v) - raw_eval_results[k].append(v) - eval_results = {} - for k, v in raw_eval_results.items(): - # convert v to array - eval_results.update( - {f"{k}_max": np.max(v), f"{k}_min": np.min(v), f"{k}_mean": np.mean(v)} - ) - results["evaluation"] = eval_results - return results + return raw_statistics def log(message: str): logger.log(settings.LOG_LEVEL, f"(rollout worker) {message}") -def default_rollout_callback(coordinator: ray.ObjectRef, results: Dict[str, Any]): +def default_rollout_callback(results: Dict[str, Any]): pass -def default_simulate_callback(coordinator: ray.ObjectRef, results: Dict[str, Any]): +def default_simulate_callback(results: Dict[str, Any]): pass @@ -147,9 +116,8 @@ def validate_runtime_configs(configs: Dict[str, Any]): class RolloutWorker(RemoteInterface): def __init__( self, - experiment_tag: str, env_desc: Dict[str, Any], - agent_groups: Dict[str, Set], + agent_groups: Dict[str, Tuple], rollout_config: Union[RolloutConfig, Dict[str, Any]], log_dir: str, rollout_callback: Callable[[ray.ObjectRef, Dict[str, Any]], Any] = None, @@ -166,7 +134,6 @@ def __init__( * `max_step`: int, the maximum step of each episode. * `num_eval_episodes`: int, the number of epsiodes for each evaluation. log_dir (str): Log directory. - experiment_tag (str): Experiment tag, to create a data table. rollout_callback (Callable[[ray.ObjectRef, Dict[str, Any]], Any], optional): Callback function for rollout task, users can determine how \ to cordinate with coordinator here. Defaults by None, indicating no coordination. simulate_callback (Callable[[ray.ObjectRef, Dict[str, Any]], Any]): Callback function for simulation task, users can determine \ @@ -186,8 +153,6 @@ def __init__( self.agent_groups = agent_groups self.rollout_config = RolloutConfig.from_raw(rollout_config) - validate_runtime_configs(self.rollout_config) - # create environment runner, handling evaluation or rollout task env_runner_resource_config = resource_config["inference_server"] self.env_runner = self.create_env_runner( @@ -198,7 +163,6 @@ def __init__( self.rollout_callback = rollout_callback or default_rollout_callback self.simulate_callback = simulate_callback or default_simulate_callback self.tb_writer = tensorboard.SummaryWriter(log_dir=log_dir) - self.experiment_tag = experiment_tag self.verbose = verbose def create_env_runner( @@ -225,6 +189,8 @@ def create_env_runner( env_func=lambda: env_desc["creator"](**env_desc["config"]), max_env_num=rollout_config.n_envs_per_worker, use_subproc_env=rollout_config.use_subproc_env, + agent_groups=self.agent_groups, + inference_entry_points=rollout_config.inference_entry_points, ) return env_runner @@ -239,7 +205,6 @@ def rollout(self, task: RolloutTask): """ stopper = get_stopper(task.stopping_conditions) - active_agents = active_agents or self.env_agents total_timesteps = 0 eval_results = {} @@ -258,29 +223,27 @@ def rollout(self, task: RolloutTask): results = self.step_rollout( eval_step, task.strategy_specs, - self.rollout_config, task.data_entrypoint_mapping, ) - total_timesteps += results["total_timesteps"] - eval_results = results.get("evaluation", None) - - performance["rollout_iter_rate"] = (epoch + 1) / (time.time() - start_time) - performance["rollout_FPS"] = results["FPS"] - performance["ave_rollout_FPS"] = ( - performance["ave_rollout_FPS"] * epoch + results["FPS"] - ) / (epoch + 1) - - if eval_results is not None: - if self.verbose: - eval_results["performance"] = performance - formatted_results = pprint.pformat(eval_results) - Logger.info(f"Evaluation at epoch: {epoch}\n{formatted_results}") - write_to_tensorboard( - self.tb_writer, - eval_results, - global_step=total_timesteps, - prefix="Evaluation", - ) + # total_timesteps += results["total_timesteps"] + + # performance["rollout_iter_rate"] = (epoch + 1) / (time.time() - start_time) + # performance["rollout_FPS"] = results["FPS"] + # performance["ave_rollout_FPS"] = ( + # performance["ave_rollout_FPS"] * epoch + results["FPS"] + # ) / (epoch + 1) + + # if self.verbose: + # eval_results["performance"] = performance + # formatted_results = pprint.pformat(eval_results) + # Logger.info(f"Evaluation at epoch: {epoch}\n{formatted_results}") + + # write_to_tensorboard( + # self.tb_writer, + # results, + # global_step=total_timesteps, + # prefix="Evaluation", + # ) write_to_tensorboard( self.tb_writer, performance, global_step=epoch, prefix="Performance" @@ -291,7 +254,8 @@ def rollout(self, task: RolloutTask): break epoch += 1 - self.rollout_callback(self.coordinator, results) + self.rollout_callback(results) + return results @abstractmethod @@ -299,26 +263,12 @@ def step_rollout( self, eval_step: bool, strategy_specs: Dict[AgentID, StrategySpec], - rollout_config: Dict[str, Any], data_entrypoint_mapping: Dict[AgentID, str], ) -> List[Dict[str, Any]]: """The logic function to run rollout. Users must implment this method. Args: eval_step (bool): Indicate evaluation or not. - rollout_config (Dict[str, Any]): Runtime configurations to control the amount of sampled data. Keys include: - - `flag`: indicate the task type, the value is rollout. - - `max_step`: indicates the maximum length of an episode. - - `num_episodes`: indicates how many episodes will be collected. - - `policy_distribution`: a dict describes the policy distribution. - - `parameter_desc_dict`: a dict describes the parameter description. - - `trainable_pairs`: a dict describes the trainable policy configuration, it is a mapping from `runtime_ids` \ - to a tuple of policy id and policy configuration. - - `behavior_policies`: a dict maps runtime agents to policy ids, it specifies the behavior policy for available agents, \ - could be a subset of the full agent set. - - `agent_group`: a dict that maps runtime agents to a list of environment agents, which describes the envrionment agents \ - governed by what runtime agent interface. - - `fragment_length`: the maximum of collected data frames. data_entrypoint_mapping: ... Raises: diff --git a/malib/runner.py b/malib/runner.py deleted file mode 100644 index 380efb8c..00000000 --- a/malib/runner.py +++ /dev/null @@ -1,101 +0,0 @@ -# MIT License - -# Copyright (c) 2021 MARL @ SJTU - -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. - -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -import time -import ray - -from malib import settings -from malib.utils.logging import Logger -from malib.scenarios import marl_scenario, psro_scenario -from malib.scenarios.scenario import Scenario -from malib.backend.offline_dataset_server import OfflineDataset -from malib.backend.parameter_server import ParameterServer - - -def start_servers(data_table_capacity: int = 100000): - try: - offline_dataset_server = ( - OfflineDataset.as_remote(num_cpus=0) - .options(name=settings.OFFLINE_DATASET_ACTOR, max_concurrency=100) - .remote(table_capacity=data_table_capacity) - ) - ray.get(offline_dataset_server.start.remote()) - except ValueError: - Logger.warning("detected existing offline dataset server") - offline_dataset_server = ray.get_actor(settings.OFFLINE_DATASET_ACTOR) - - try: - parameter_server = ( - ParameterServer.as_remote(num_cpus=1) - .options(name=settings.PARAMETER_SERVER_ACTOR, max_concurrency=100) - .remote() - ) - ray.get(parameter_server.start.remote()) - except ValueError: - Logger.warning("detected exisitng parameter server") - parameter_server = ray.get_actor(settings.PARAMETER_SERVER_ACTOR) - - return parameter_server, offline_dataset_server - - -def run(scenario: Scenario, cluster_address: str = "auto"): - """Load scenario to the execution plan and lauch a cluster. The instance will search an active \ - cluster by default. Users can also determine the specified cluster with given `cluster_address`. - - Args: - scenario (Scenario): Scenario instance. - cluster_address (str, optional): Ray cluster address. Defaults to "auto", which means the \ - training instance will search an active cluster. - - Raises: - TypeError: Unexpected scenario type. - """ - - try: - start_ray_info = ray.init(address="auto", dashboard_port=8265) - except ConnectionError: - Logger.warning("No active cluster deteced, will create a local ray instance.") - start_ray_info = ray.init() - - try: - Logger.info("Ray lauched: {}".format(start_ray_info)) - Logger.info("Ray cluster resources info: {}".format(ray.cluster_resources())) - - parameter_server, offline_dataset_server = start_servers() - scenario.parameter_server = parameter_server - scenario.offline_dataset_server = offline_dataset_server - - experiment_tag = f"malib-{scenario.name}-{time.strftime('%Y-%m-%d-%H%M%S')}" - - if isinstance(scenario, psro_scenario.PSROScenario): - psro_scenario.execution_plan(experiment_tag, scenario) - elif isinstance(scenario, marl_scenario.MARLScenario): - marl_scenario.execution_plan(experiment_tag, scenario) - else: - raise TypeError("Unexpected scenario type: {}".format(scenario)) - except KeyboardInterrupt: - ray.shutdown() - except TypeError as e: - ray.shutdown() - raise e - except Exception as e: - raise e diff --git a/malib/scenarios/sarl_scenario.py b/malib/scenarios/sarl_scenario.py index 808df037..653f6cdc 100644 --- a/malib/scenarios/sarl_scenario.py +++ b/malib/scenarios/sarl_scenario.py @@ -23,14 +23,14 @@ # SOFTWARE. from typing import Dict, Any -from malib.common.task import OptimizationTask, RolloutTask +from malib.common.task import TaskType, OptimizationTask, RolloutTask from malib.scenarios import Scenario - +from malib.utils.stopping_conditions import StoppingCondition, get_stopper from malib.utils.logging import Logger from malib.backend.league import League -from malib.learner.manager import TrainingManager -from malib.rollout.manager import RolloutWorkerManager, TaskType +from malib.learner.manager import LearnerManager +from malib.rollout.manager import RolloutWorkerManager from malib.rollout.inference.manager import InferenceManager @@ -44,8 +44,6 @@ def __init__( training_config: Dict[str, Any], rollout_config: Dict[str, Any], stopping_conditions: Dict[str, Any], - dataset_config: Dict[str, Any], - parameter_server_config: Dict[str, Any], resource_config: Dict[str, Any] = None, ): super().__init__( @@ -57,16 +55,17 @@ def __init__( training_config, rollout_config, stopping_conditions, - dataset_config, - parameter_server_config, ) self.num_policy_each_interface = 1 self.resource_config = resource_config or {"training": None, "rollout": None} + def create_global_stopper(self) -> StoppingCondition: + return get_stopper(self.stopping_conditions) + def execution_plan(experiment_tag: str, scenario: SARLScenario, verbose: bool = True): # TODO(ming): simplify the initialization of training and rollout manager with a scenario instance as input - training_manager = TrainingManager( + learner_manager = LearnerManager( experiment_tag=experiment_tag, stopping_conditions=scenario.stopping_conditions, algorithms=scenario.algorithms, @@ -81,8 +80,14 @@ def execution_plan(experiment_tag: str, scenario: SARLScenario, verbose: bool = verbose=verbose, ) + inference_manager = InferenceManager( + group_info=scenario.group_info, + ray_actor_namespace="inference_{}".format(experiment_tag), + model_entry_point=learner_manager.learner_entrypoints, + scenario=scenario, + ) + rollout_manager = RolloutWorkerManager( - experiment_tag=experiment_tag, stopping_conditions=scenario.stopping_conditions, num_worker=scenario.num_worker, group_info=scenario.group_info, @@ -94,32 +99,22 @@ def execution_plan(experiment_tag: str, scenario: SARLScenario, verbose: bool = verbose=verbose, ) - inference_manager = InferenceManager( - group_info=scenario.group_info, - ray_actor_namespace="inference_{}".format(experiment_tag), - entrypoints=training_manager.get_data_entrypoints(), - scenario=scenario, - ) - - league = League(rollout_manager, training_manager, inference_manager) - - # NOTE(ming): if all agents are active, the strategy specs should not contain any pids - strategy_specs = training_manager.add_policies(n=1) - Logger.info( - f"Training manager was inistialized with a strategy spec:\n{strategy_specs}" + league = League( + learner_manager, rollout_manager, inference_manager, namespace=experiment_tag ) optimization_task = OptimizationTask( active_agents=scenario.env_desc["possible_agents"], stop_conditions=scenario.stopping_conditions["training"], ) - training_manager.submit(optimization_task) + + strategy_specs = learner_manager.get_strategy_specs() rollout_task = RolloutTask( task_type=TaskType.ROLLOUT, strategy_specs=strategy_specs, stopping_conditions=scenario.stopping_conditions["rollout"], - data_entrypoint_mapping=training_manager.get_data_entrypoint_mapping(), + data_entrypoint_mapping=learner_manager.data_entrypoints, ) evaluation_task = RolloutTask( @@ -127,8 +122,20 @@ def execution_plan(experiment_tag: str, scenario: SARLScenario, verbose: bool = strategy_specs=strategy_specs, ) - rollout_manager.submit(rollout_task) - rollout_manager.submit(evaluation_task) + stopper = scenario.create_global_stopper() + epoch_cnt = 0 + + while True: + rollout_results = league.submit(rollout_task, wait=True) + training_results = league.submit(optimization_task, wait=True) + evaluation_results = league.submit(evaluation_task, wait=True) + epoch_cnt += 1 + if stopper.should_stop( + evaluation_results, training_results, rollout_results, epoch_cnt + ): + break + if epoch_cnt % scenario.save_interval == 0: + league.save_checkpoint(global_step=epoch_cnt) results = league.get_results() league.terminate() diff --git a/malib/scenarios/scenario.py b/malib/scenarios/scenario.py index f009a085..03af7069 100644 --- a/malib/scenarios/scenario.py +++ b/malib/scenarios/scenario.py @@ -22,13 +22,14 @@ from abc import ABC, abstractmethod from types import LambdaType -from typing import Callable, Union, Dict, Any, Set, List +from typing import Dict, Any, Set, Tuple from copy import deepcopy from collections import defaultdict import gym from malib.utils.typing import AgentID +from malib.utils.stopping_conditions import StoppingCondition DEFAULT_STOPPING_CONDITIONS = {} @@ -40,7 +41,7 @@ def validate_spaces(agent_groups: Dict[str, Set[AgentID]], env_desc: Dict[str, A def validate_agent_group( - agent_group: Dict[str, List[AgentID]], + agent_group: Dict[str, Tuple[AgentID]], observation_spaces: Dict[AgentID, gym.Space], action_spaces: Dict[AgentID, gym.Space], ) -> None: @@ -66,6 +67,23 @@ def validate_agent_group( assert select_act_space.shape == action_spaces[agent].shape +def form_group_info(env_desc, agent_mapping_func): + agent_groups = defaultdict(lambda: list()) + grouped_obs_space = {} + grouped_act_space = {} + for agent in env_desc["possible_agents"]: + rid = agent_mapping_func(agent) + agent_groups[rid].append(agent) + grouped_obs_space[rid] = env_desc["observation_spaces"][agent] + grouped_act_space[rid] = env_desc["action_spaces"][agent] + agent_groups = {k: tuple(v) for k, v in agent_groups.items()} + return { + "observation_space": grouped_obs_space, + "action_space": grouped_act_space, + "agent_groups": agent_groups, + } + + class Scenario(ABC): @abstractmethod def __init__( @@ -78,8 +96,6 @@ def __init__( training_config: Dict[str, Any], rollout_config: Dict[str, Any], stopping_conditions: Dict[str, Any], - dataset_config: Dict[str, Any], - parameter_server_config: Dict[str, Any], ): self.name = name self.log_dir = log_dir @@ -87,33 +103,23 @@ def __init__( self.algorithms = algorithms self.agent_mapping_func = agent_mapping_func # then generate grouping information here - agent_groups = defaultdict(lambda: set()) - grouped_obs_space = {} - grouped_act_space = {} - for agent in env_desc["possible_agents"]: - rid = agent_mapping_func(agent) - agent_groups[rid].add(agent) - grouped_obs_space[rid] = env_desc["observation_spaces"][agent] - grouped_act_space[rid] = env_desc["action_spaces"][agent] - self.group_info = { - "observation_space": grouped_obs_space, - "action_space": grouped_act_space, - "agent_groups": agent_groups, - } + self.group_info = form_group_info(env_desc, agent_mapping_func) validate_agent_group( - agent_groups, env_desc["observation_spaces"], env_desc["action_spaces"] + self.group_info["agent_groups"], + env_desc["observation_spaces"], + env_desc["action_spaces"], ) self.training_config = training_config self.rollout_config = rollout_config self.stopping_conditions = stopping_conditions or DEFAULT_STOPPING_CONDITIONS - self.dataset_config = dataset_config or {"table_capacity": 1000} - self.parameter_server_config = parameter_server_config or {} - self.parameter_server = None - self.offline_dataset_server = None def copy(self): return deepcopy(self) + @abstractmethod + def create_global_stopper(self) -> StoppingCondition: + """Create a global stopper.""" + def with_updates(self, **kwargs) -> "Scenario": new_copy = self.copy() for k, v in kwargs.items(): diff --git a/malib/utils/stopping_conditions.py b/malib/utils/stopping_conditions.py index 7a67b329..27584ede 100644 --- a/malib/utils/stopping_conditions.py +++ b/malib/utils/stopping_conditions.py @@ -33,17 +33,17 @@ class StoppingCondition(ABC): @abstractmethod - def should_stop(self, latest_trainer_result: dict, *args, **kwargs) -> bool: + def should_stop(self, results, **kwargs) -> bool: pass class NoStoppingCondition(StoppingCondition): - def should_stop(self, latest_trainer_result: dict, *args, **kwargs) -> bool: + def should_stop(self, results, **kwargs) -> bool: return False class StopImmediately(StoppingCondition): - def should_stop(self, latest_trainer_result: dict, *args, **kwargs) -> bool: + def should_stop(self, results, **kwargs) -> bool: return True @@ -51,10 +51,8 @@ class RewardImprovementStopping(StoppingCondition): def __init__(self, mininum_reward_improvement: float) -> None: self.minium_reward_improvement = mininum_reward_improvement - def should_stop(self, latest_trainer_result: dict, *args, **kwargs) -> bool: - reward_this_iter = latest_trainer_result.get( - "evaluation", {"episode_reward_mean": float("inf")} - )["episode_reward_mean"] + def should_stop(self, results, **kwargs) -> bool: + reward_this_iter = results.get("episode_reward_mean", float("inf")) if reward_this_iter == float("inf"): return False should_stop = False @@ -69,7 +67,7 @@ def __init__( self.max_iteration = max_iteration self.n_iteration = 0 - def should_stop(self, latest_trainer_result: dict, *args, **kwargs) -> bool: + def should_stop(self, results, **kwargs) -> bool: self.n_iteration += 1 should_stop = False @@ -87,8 +85,14 @@ def __init__(self, stoppings: List[StoppingCondition]) -> None: super().__init__() self.stoppings = stoppings - def should_stop(self, latest_trainer_result: dict, *args, **kwargs) -> bool: - stops = [e.should_stop(latest_trainer_result) for e in self.stoppings] + def should_stop(self, results, **kwargs) -> bool: + stops = [ + e.should_stop( + results, + **kwargs, + ) + for e in self.stoppings + ] return all(stops) diff --git a/tests/rollout/test_env_runner.py b/tests/rollout/test_env_runner.py index b3c368db..ddddaa3a 100644 --- a/tests/rollout/test_env_runner.py +++ b/tests/rollout/test_env_runner.py @@ -41,7 +41,6 @@ def test_env_runner(env_desc: Dict[str, Any], max_env_num: int): for agent in agents } rollout_config = RolloutConfig( - inference_server_type="ray", num_workers=1, eval_interval=1, n_envs_per_worker=10, @@ -51,7 +50,7 @@ def test_env_runner(env_desc: Dict[str, Any], max_env_num: int): infer_clients = { agent: inference_remote_cls.remote( - entry_point=None, + model_entry_point=None, policy_cls=RandomPolicy, observation_space=observation_spaces[agent], action_space=action_spaces[agent], diff --git a/tests/rollout/test_pb_rollout_worker.py b/tests/rollout/test_pb_rollout_worker.py index a553f315..a8652705 100644 --- a/tests/rollout/test_pb_rollout_worker.py +++ b/tests/rollout/test_pb_rollout_worker.py @@ -25,15 +25,20 @@ from typing import Dict, Any import pytest -import threading -import time import ray from pytest_mock import MockerFixture from gym import spaces -from malib.runner import start_servers -from malib.mocker.mocker_utils import FakeInferenceClient, FakeInferenceServer +from malib.common.task import RolloutTask +from malib.common.strategy_spec import StrategySpec +from malib.rl.random import RandomPolicy +from malib.rl.config import Algorithm +from malib.rollout.envs.random import env_desc_gen +from malib.rollout.rollout_config import RolloutConfig +from malib.rollout.pb_rolloutworker import PBRolloutWorker +from malib.rollout.inference.manager import InferenceManager +from malib.scenarios.scenario import form_group_info def gen_rollout_config(inference_server_type: str): @@ -52,93 +57,69 @@ def gen_rollout_config(inference_server_type: str): } -def create_rollout_worker( - mocker: MockerFixture, env_desc: Dict[str, Any], rollout_config: Dict[str, Any] -): - mocker.patch( - "malib.rollout.rolloutworker.RayInferenceClient", new=FakeInferenceClient - ) - mocker.patch( - "malib.rollout.rolloutworker.RayInferenceServer", new=FakeInferenceServer - ) - from malib.rollout.pb_rolloutworker import PBRolloutWorker - - worker = PBRolloutWorker( - experiment_tag="test_rollout_worker", - env_desc=env_desc, - agent_mapping_func=lambda agent: agent, - rollout_config=rollout_config, - log_dir="./logs", - ) - return worker - - @pytest.mark.parametrize("n_player", [1, 2]) -@pytest.mark.parametrize("inference_server_type", ["local", "ray"]) class TestRolloutWorker: - def test_rollout( - self, mocker: MockerFixture, n_player: int, inference_server_type: str - ): - if not ray.is_initialized(): - ray.init() - - parameter_server, dataset_server = start_servers() - - agents = [f"player_{i}" for i in range(n_player)] - - worker = create_rollout_worker( - mocker, - env_desc={ - "possible_agents": agents, - "observation_spaces": { - agent: spaces.Box(-1, 1.0, shape=(2,)) for agent in agents - }, - "action_spaces": { - agent: spaces.Box(-1, 1, shape=(2,)) for agent in agents - }, - }, - rollout_config=gen_rollout_config(inference_server_type), - ) - - data_entrypoints = {agent: agent for agent in agents} - results = worker.rollout( - None, - {"max_iteration": 2}, - data_entrypoints, - None, - ) - print("rollout results:", results) - - ray.kill(parameter_server) - ray.kill(dataset_server) - ray.shutdown() - - def test_simulation( - self, mocker: MockerFixture, n_player: int, inference_server_type: str - ): - if not ray.is_initialized(): - ray.init() - - parameter_server, dataset_server = start_servers() - - agents = [f"player_{i}" for i in range(n_player)] - - worker = create_rollout_worker( - mocker, - env_desc={ - "possible_agents": agents, - "observation_spaces": { - agent: spaces.Box(-1, 1.0, shape=(2,)) for agent in agents - }, - "action_spaces": { - agent: spaces.Box(-1, 1, shape=(2,)) for agent in agents - }, - }, - rollout_config=gen_rollout_config(inference_server_type), - ) - - results = worker.simulate({}) - - ray.kill(parameter_server) - ray.kill(dataset_server) - ray.shutdown() + def test_rollout(self, n_player: int): + with ray.init(local_mode=True): + env_desc = env_desc_gen(num_agents=n_player) + obs_spaces = env_desc["observation_spaces"] + act_spaces = env_desc["action_spaces"] + agents = env_desc["possible_agents"] + log_dir = "./logs" + + algorithm = Algorithm( + policy=RandomPolicy, + trainer=None, + model_config=None, + ) + + rollout_config = RolloutConfig( + num_workers=1, + eval_interval=1, + n_envs_per_worker=10, + use_subproc_env=False, + timelimit=256, + ) + + group_info = form_group_info(env_desc, lambda agent: "default") + + inference_namespace = "test_pb_rolloutworker" + + infer_manager = InferenceManager( + group_info=group_info, + ray_actor_namespace=inference_namespace, + algorithm=algorithm, + model_entry_point=None, + ) + + rollout_config.inference_entry_points = infer_manager.inference_entry_points + + strategy_specs = { + agent: StrategySpec( + policy_cls=algorithm.policy, + observation_space=obs_spaces[agent], + action_space=act_spaces[agent], + identifier=agent, + model_config=algorithm.model_config, + policy_ids=["policy-0"], + ) + for agent in agents + } + + worker = PBRolloutWorker( + env_desc=env_desc, + agent_groups=group_info["agent_groups"], + rollout_config=rollout_config, + log_dir=log_dir, + ) + + task = RolloutTask( + strategy_specs=strategy_specs, + stopping_conditions={"max_iteration": 10}, + data_entrypoint_mapping=None, # no data collect + ) + stats = worker.rollout(task) + + # def test_rollout_with_data_entrypoint(self, mocker: MockerFixture, n_player: int): + # with ray.init(local_mode=True): + # pass From 9c3af38fdd46ff78b9d62f478b5bac87d0b6b7c4 Mon Sep 17 00:00:00 2001 From: Ming Zhou Date: Tue, 14 Nov 2023 17:38:21 +0800 Subject: [PATCH 17/24] test passed for dynamic dataset --- malib/backend/dataset_server/data_loader.py | 53 +++--- malib/backend/dataset_server/feature.py | 107 +++++++++-- malib/backend/dataset_server/service.py | 4 +- malib/backend/dataset_server/utils.py | 36 +++- malib/common/task.py | 2 +- .../training_config.py => learner/config.py} | 6 +- malib/learner/indepdent_learner.py | 7 +- malib/learner/learner.py | 88 ++++----- malib/learner/manager.py | 18 +- malib/mocker/mocker_utils.py | 178 ++---------------- malib/rl/common/trainer.py | 9 +- malib/rl/config.py | 6 +- malib/rollout/inference/env_runner.py | 9 +- malib/rollout/inference/manager.py | 8 +- malib/rollout/pb_rolloutworker.py | 4 +- malib/rollout/rolloutworker.py | 28 +-- tests/backend/test_dataset_server.py | 144 -------------- tests/backend/test_dynamic_dataset.py | 164 ++++++++++++++++ tests/backend/test_parameter_server.py | 140 -------------- tests/rollout/test_pb_rollout_worker.py | 129 ++++++++++--- 20 files changed, 512 insertions(+), 628 deletions(-) rename malib/{common/training_config.py => learner/config.py} (87%) delete mode 100644 tests/backend/test_dataset_server.py create mode 100644 tests/backend/test_dynamic_dataset.py delete mode 100644 tests/backend/test_parameter_server.py diff --git a/malib/backend/dataset_server/data_loader.py b/malib/backend/dataset_server/data_loader.py index 9d774ffc..a62a217e 100644 --- a/malib/backend/dataset_server/data_loader.py +++ b/malib/backend/dataset_server/data_loader.py @@ -1,16 +1,14 @@ from typing import Type, Any +import socket import threading import grpc -import socket -from concurrent import futures -from torch.utils.data import DataLoader, Dataset +from torch.utils.data import Dataset from malib.utils.general import find_free_port +from malib.backend.dataset_server.utils import service_wrapper -from .service import DatasetServer -from . import data_pb2_grpc from .feature import BaseFeature @@ -23,43 +21,34 @@ def __init__( self, grpc_thread_num_workers: int, max_message_length: int, - feature_handler_caller: Type, + feature_handler_cls: Type[BaseFeature], + **feature_handler_kwargs, ) -> None: super().__init__() # start a service as thread - self.feature_handler: BaseFeature = feature_handler_caller() - self.server = self._start_servicer( + self.feature_handler: BaseFeature = feature_handler_cls( + **feature_handler_kwargs + ) + self.server_port = find_free_port() + self.server = service_wrapper( grpc_thread_num_workers, max_message_length, - find_free_port(), - ) - self.host = socket.gethostbyname(socket.gethostbyname()) - - def _start_servicer( - self, max_workers: int, max_message_length: int, grpc_port: int - ): - server = grpc.server( - futures.ThreadPoolExecutor(max_workers=max_workers), - options=[ - ("grpc.max_send_message_length", max_message_length), - ("grpc.max_receive_message_length", max_message_length), - ], - ) - servicer = DatasetServer(self.feature_handler) - data_pb2_grpc.add_SendDataServicer_to_server(servicer, server) - - server.add_insecure_port(f"[::]:{grpc_port}") - server.start() - - return server + self.server_port, + )(self.feature_handler) + self.server.start() + self.host = socket.gethostbyname(socket.gethostname()) @property def entrypoint(self) -> str: - return f"{self.host}:{self.server._state.port}" + return f"{self.host}:{self.server_port}" + + @property + def readable_block_size(self) -> str: + return len(self.feature_handler) def __len__(self): - return self.feature_handler_caller.block_size + return self.feature_handler.block_size def __getitem__(self, index) -> Any: if index >= len(self): @@ -71,4 +60,4 @@ def __getitem__(self, index) -> Any: return self.feature_handler.safe_get(index) def close(self): - self.server.stop() + self.server.wait_for_termination(3) diff --git a/malib/backend/dataset_server/feature.py b/malib/backend/dataset_server/feature.py index 3276e96d..a759258d 100644 --- a/malib/backend/dataset_server/feature.py +++ b/malib/backend/dataset_server/feature.py @@ -1,30 +1,105 @@ -from typing import Any +from typing import Any, Dict +from abc import ABC, abstractmethod + +import copy +import numpy as np +import torch + +from gym import spaces from readerwriterlock import rwlock -class BaseFeature: - def __init__(self) -> None: +numpy_to_torch_dtype_dict = { + np.bool_: torch.bool, + np.uint8: torch.uint8, + np.int8: torch.int8, + np.int16: torch.int16, + np.int32: torch.int32, + np.int64: torch.int64, + np.float16: torch.float16, + np.float32: torch.float32, + np.float64: torch.float64, + np.complex64: torch.complex64, + np.complex128: torch.complex128, +} + + +class BaseFeature(ABC): + def __init__( + self, + spaces: Dict[str, spaces.Space], + np_memory: Dict[str, np.ndarray], + block_size: int = None, + device: str = "cpu", + ) -> None: self.rw_lock = rwlock.RWLockFair() - self._readable_index = [] - self._writable_index = [] + self._device = device + self._spaces = spaces + self._block_size = block_size or list(np_memory.values())[0].shape[0] + self._available_size = 0 + self._flag = 0 + self._shared_memory = { + k: torch.from_numpy(v).to(device).share_memory_() + for k, v in np_memory.items() + } + + def get(self, index: int): + """Get data from this feature. + + Args: + index (int): Index of the data. + + Returns: + Any: Data + """ + data = {} + for k, v in self._shared_memory.items(): + data[k] = v[index] + return data + + def write(self, data: Dict[str, Any], start: int, end: int): + for k, v in data.items(): + self._shared_memory[k][start:end] = torch.as_tensor(v).to( + self._device, dtype=self._shared_memory[k].dtype + ) + + def generate_timestep(self) -> Dict[str, np.ndarray]: + return {k: space.sample() for k, space in self.spaces.items()} + + def generate_batch(self, batch_size: int = 1) -> Dict[str, np.ndarray]: + batch = {} + for k, space in self.spaces.items(): + data = np.stack( + [space.sample() for _ in range(batch_size)], dtype=space.dtype + ) + batch[k] = data + return batch + + @property + def spaces(self) -> Dict[str, spaces.Space]: + return copy.deepcopy(self._spaces) @property def block_size(self) -> int: - raise NotImplementedError + return self._block_size def __len__(self): - return len(self._readable_index) - - def _get(self, index: int): - raise NotImplementedError + return self._available_size def safe_get(self, index: int): with self.rw_lock.gen_rlock(): - return self._get(index) - - def _write(self, data: Any): - raise NotImplementedError + if len(self) == 0: + raise IndexError(f"index:{index} exceeds for available size is 0") + elif index >= len(self): + # re-sampling + index = index % self._available_size + return self.get(index) - def safe_put(self, data: Any): + def safe_put(self, data: Any, batch_size: int): with self.rw_lock.gen_wlock(): - self._write(data) + # request segment asscessment + self.write(data, self._flag, self._flag + batch_size) + self._flag = (self._flag + batch_size) % self._block_size + self._available_size = min( + self._available_size + batch_size, self._block_size + ) diff --git a/malib/backend/dataset_server/service.py b/malib/backend/dataset_server/service.py index d2af656c..2119787a 100644 --- a/malib/backend/dataset_server/service.py +++ b/malib/backend/dataset_server/service.py @@ -20,9 +20,11 @@ def __init__( def Collect(self, request, context): try: data = pickle.loads(request.data) - self.feature_handler.safe_put(data) + batch_size = len(list(data.values())[0]) + self.feature_handler.safe_put(data, batch_size) message = "success" except Exception as e: message = traceback.format_exc() + print(message) return data_pb2.Reply(message=message) diff --git a/malib/backend/dataset_server/utils.py b/malib/backend/dataset_server/utils.py index a78e6517..a24bce74 100644 --- a/malib/backend/dataset_server/utils.py +++ b/malib/backend/dataset_server/utils.py @@ -1,12 +1,14 @@ from typing import Any, Union +from concurrent import futures -import pickle -import grpc import sys import os +import pickle +import grpc sys.path.append(os.path.dirname(__file__)) +from .service import DatasetServer from . import data_pb2 from . import data_pb2_grpc @@ -25,3 +27,33 @@ def send_data(data: Any, host: str = None, port: int = None, entrypoint: str = N reply = stub.Collect(data_pb2.Data(data=data)) return reply.message + + +def service_wrapper(max_workers: int, max_message_length: int, grpc_port: int): + def func(feature_handler): + server = grpc.server( + futures.ThreadPoolExecutor(max_workers=max_workers), + options=[ + ("grpc.max_send_message_length", max_message_length), + ("grpc.max_receive_message_length", max_message_length), + ], + ) + servicer = DatasetServer(feature_handler) + data_pb2_grpc.add_SendDataServicer_to_server(servicer, server) + + server.add_insecure_port(f"[::]:{grpc_port}") + return server + + return func + + +def start_server( + max_workers: int, max_message_length: int, grpc_port: int, feature_handler +): + server = service_wrapper( + max_workers=max_workers, + max_message_length=max_message_length, + grpc_port=grpc_port, + )(feature_handler) + server.start() + server.wait_for_termination() diff --git a/malib/common/task.py b/malib/common/task.py index 1155e967..15279f6c 100644 --- a/malib/common/task.py +++ b/malib/common/task.py @@ -20,7 +20,7 @@ class Task: class RolloutTask(Task): strategy_specs: Dict[str, Any] = field(default_factory=dict()) stopping_conditions: Dict[str, Any] = field(default_factory=dict()) - data_entrypoint_mapping: Dict[str, Any] = field(default_factory=dict()) + data_entrypoints: Dict[str, Any] = field(default_factory=dict()) @classmethod def from_raw( diff --git a/malib/common/training_config.py b/malib/learner/config.py similarity index 87% rename from malib/common/training_config.py rename to malib/learner/config.py index 4fc1ae0e..ff36d106 100644 --- a/malib/common/training_config.py +++ b/malib/learner/config.py @@ -1,13 +1,15 @@ -from typing import Dict, Any, Union +from typing import Dict, Any, Union, Type from dataclasses import dataclass, field +from malib.learner.learner import Learner + # TODO(ming): rename it as LearnerConfig @dataclass class TrainingConfig: trainer_config: Dict[str, Any] - learner_type: str + learner_type: Type[Learner] custom_config: Dict[str, Any] = field(default_factory=dict()) @classmethod diff --git a/malib/learner/indepdent_learner.py b/malib/learner/indepdent_learner.py index 18540916..a21f97de 100644 --- a/malib/learner/indepdent_learner.py +++ b/malib/learner/indepdent_learner.py @@ -22,12 +22,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -from typing import Dict, Tuple, Any, Callable, List, Type, Union - -import gym - -from gym import spaces -from malib.backend.dataset_server.data_loader import DynamicDataset +from typing import Dict, Tuple, Any, List, Union from malib.utils.typing import AgentID from malib.utils.tianshou_batch import Batch diff --git a/malib/learner/learner.py b/malib/learner/learner.py index 56fc4114..76d5b103 100644 --- a/malib/learner/learner.py +++ b/malib/learner/learner.py @@ -29,7 +29,6 @@ import traceback -import gym import torch import ray @@ -47,20 +46,19 @@ from malib.backend.dataset_server.data_loader import DynamicDataset from malib.rl.common.trainer import Trainer from malib.rl.common.policy import Policy +from malib.rl.config import Algorithm class Learner(RemoteInterface, ABC): """Base class of agent interface, for training""" - @abstractmethod def __init__( self, - experiment_tag: str, runtime_id: str, log_dir: str, observation_space: spaces.Space, action_space: spaces.Space, - algorithms: Dict[str, Tuple[Type, Type, Dict, Dict]], + algorithm: Algorithm, agent_mapping_func: Callable[[AgentID], str], governed_agents: Tuple[AgentID], trainer_config: Dict[str, Any], @@ -72,13 +70,11 @@ def __init__( """Construct agent interface for training. Args: - experiment_tag (str): Experiment tag. runtime_id (str): Assigned runtime id, should be an element of the agent mapping results. log_dir (str): The directory for logging. observation_space (gym.Space): Observation space. action_space (gym.Space): Action space. - algorithms (Dict[str, Tuple[Type, Type, Dict]]): A dict that describes the algorithm candidates. Each is \ - a tuple of `policy_cls`, `trainer_cls`, `model_config` and `custom_config`. + algorithms (Algorithm): Algorithm configuration. agent_mapping_func (Callable[[AgentID], str]): A function that defines the rule of agent groupping. governed_agents (Tuple[AgentID]): A tuple that records which agents is related to this learner. \ Note that it should be a subset of the original set of environment agents. @@ -96,16 +92,15 @@ def __init__( # initialize a strategy spec for policy maintainance. strategy_spec = StrategySpec( - policy_cls=algorithms["default"][0], + policy_cls=algorithm.policy, observation_space=observation_space, action_space=action_space, - model_config=algorithms["default"][2], - **algorithms["default"][3], + model_config=algorithm.model_config, ) self._runtime_id = runtime_id self._device = device - self._algorithms = algorithms + self._algorithm = algorithm self._governed_agents = governed_agents self._strategy_spec = strategy_spec self._agent_mapping_func = agent_mapping_func @@ -113,17 +108,45 @@ def __init__( self._policy = strategy_spec.gen_policy(device=device) self._summary_writer = tensorboard.SummaryWriter(log_dir=log_dir) - self._trainer_config = trainer_config # load policy for trainer - self._trainer: Trainer = algorithms["default"][1](trainer_config, self._policy) - self._total_step = 0 - self._total_epoch = 0 + self._trainer: Trainer = algorithm.trainer(trainer_config, self._policy) dataset = dataset or self.create_dataset() self._data_loader = DataLoader(dataset, batch_size=trainer_config["batch_size"]) + + self._total_step = 0 + self._total_epoch = 0 self._verbose = verbose + def create_dataset(self) -> DynamicDataset: + """Create dataset + + Returns: + DynamicDataset: Must be an subinstance of DynamicDataset + """ + return DynamicDataset( + grpc_thread_num_workers=1, + max_message_length=1024, + feature_handler_caller=None, + ) + + @abstractmethod + def multiagent_post_process( + self, + batch_info: Union[ + Dict[AgentID, Tuple[Batch, List[int]]], Tuple[Batch, List[int]] + ], + ) -> Dict[str, Any]: + """Merge agent buffer here and return the merged buffer. + + Args: + batch_info (Union[Dict[AgentID, Tuple[Batch, List[int]]], Tuple[Batch, List[int]]]): Batch info, could be a dict of agent batch info or a tuple. + + Returns: + Dict[str, Any]: A merged buffer dict. + """ + @property def verbose(self) -> bool: return self._verbose @@ -173,41 +196,6 @@ def get_strategy_spec(self) -> StrategySpec: def get_state_dict(self) -> Dict[str, torch.Tensor]: return self.policy.state_dict(device="cpu") - @abstractmethod - def create_dataset(self) -> DynamicDataset: - """Create dataset - - Returns: - DynamicDataset: Must be an subinstance of DynamicDataset - """ - - @abstractmethod - def add_policies(self, n: int) -> StrategySpec: - """Construct `n` new policies and return the latest strategy spec. - - Args: - n (int): Indicates how many new policies will be added. - - Returns: - StrategySpec: The latest strategy spec instance. - """ - - @abstractmethod - def multiagent_post_process( - self, - batch_info: Union[ - Dict[AgentID, Tuple[Batch, List[int]]], Tuple[Batch, List[int]] - ], - ) -> Dict[str, Any]: - """Merge agent buffer here and return the merged buffer. - - Args: - batch_info (Union[Dict[AgentID, Tuple[Batch, List[int]]], Tuple[Batch, List[int]]]): Batch info, could be a dict of agent batch info or a tuple. - - Returns: - Dict[str, Any]: A merged buffer dict. - """ - def get_interface_state(self) -> Dict[str, Any]: """Return a dict that describes the current learning state. diff --git a/malib/learner/manager.py b/malib/learner/manager.py index ea15b225..39b5f140 100644 --- a/malib/learner/manager.py +++ b/malib/learner/manager.py @@ -49,7 +49,8 @@ from malib.learner.learner import Learner from malib.common.strategy_spec import StrategySpec from malib.common.manager import Manager -from malib.common.training_config import TrainingConfig +from malib.learner.config import TrainingConfig +from malib.rl.config import Algorithm DEFAULT_RESOURCE_CONFIG = dict( @@ -60,9 +61,8 @@ class LearnerManager(Manager): def __init__( self, - experiment_tag: str, stopping_conditions: Dict[str, Any], - algorithms: Dict[str, Any], + algorithm: Algorithm, env_desc: Dict[str, Any], agent_mapping_func: Callable[[AgentID], str], group_info: Dict[str, Any], @@ -77,7 +77,7 @@ def __init__( Args: experiment_tag (str): Experiment identifier, for data tracking. - algorithms (Dict[str, Any]): The algorithms configuration candidates. + algorithm (Dict[str, Any]): The algorithms configuration candidates. env_desc (Dict[str, Any]): The description for environment generation. interface_config (Dict[str, Any]): Configuration for agent training inferece construction, keys include \ `type` and `custom_config`, a dict. @@ -118,12 +118,11 @@ def __init__( learners[rid] = learner_cls.options( name=f"learner_{rid}", max_concurrency=10, namespace=self.namespace ).remote( - experiment_tag=experiment_tag, runtime_id=rid, log_dir=f"{log_dir}/learner_{rid}", observation_space=group_info["observation_space"][rid], action_space=group_info["action_space"][rid], - algorithms=algorithms, + algorithm=algorithm, agent_mapping_func=agent_mapping_func, governed_agents=tuple(agents), trainer_config=training_config.trainer_config, @@ -150,7 +149,6 @@ def __init__( # TODO(ming): collect data entrypoints from learners self._group_info = group_info self._runtime_ids = tuple(group_info["agent_groups"].keys()) - self._experiment_tag = experiment_tag self._env_description = env_desc self._training_config = training_config self._log_dir = log_dir @@ -185,6 +183,12 @@ def data_entrypoints(self) -> Dict[str, str]: @property def learner_entrypoints(self) -> Dict[str, str]: + """Return a mapping from runtime ids to learner entrypoints. + + Returns: + Dict[str, str]: A dict of learner entrypoints. + """ + return self._learner_entrypoints @property diff --git a/malib/mocker/mocker_utils.py b/malib/mocker/mocker_utils.py index ff8407f8..1f85a5d2 100644 --- a/malib/mocker/mocker_utils.py +++ b/malib/mocker/mocker_utils.py @@ -22,82 +22,20 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -from typing import Sequence, Dict, Any, Callable, List +from typing import Sequence, Dict, Any, Callable, List, Union import time import ray from ray.util import ActorPool +from malib.rollout.rollout_config import RolloutConfig from malib.utils.typing import AgentID -from malib.remote.interface import RemoteInterface from malib.common.strategy_spec import StrategySpec from malib.rollout.rolloutworker import RolloutWorker -class FakeInferenceClient(RemoteInterface): - def __init__( - self, - env_desc: Dict[str, Any], - dataset_server: ray.ObjectRef, - max_env_num: int, - use_subproc_env: bool = False, - batch_mode: str = "time_step", - postprocessor_types: Dict = None, - training_agent_mapping: Any = None, - custom_config: Dict[str, Any] = {}, - ): - self.max_env_num = max_env_num - self.agents = env_desc["possible_agents"] - - def add_envs(self, maxinum: int) -> int: - return self.max_env_num - - def close(self): - time.sleep(0.5) - return - - def run( - self, - agent_interfaces, - rollout_config, - dataset_writer_info_dict=None, - ) -> Dict[str, Any]: - time.sleep(0.5) - return { - "evaluation": [ - {f"agent_reward/{agent}_mean": 1.0 for agent in self.agents} - ], - "total_timesteps": 1000, - "FPS": 100000, - } - - -class FakeInferenceServer(RemoteInterface): - def __init__( - self, - agent_id, - observation_space, - action_space, - parameter_server, - governed_agents, - ) -> None: - pass - - def shutdown(self): - time.sleep(0.5) - return - - def save(self, model_dir: str): - print("called save method") - return - - def compute_action(self, dataframes, runtime_config): - print("called computation action") - return None - - class FakeRolloutWorker(RolloutWorker): def init_agent_interfaces( self, env_desc: Dict[str, Any], runtime_ids: Sequence[AgentID] @@ -190,17 +128,27 @@ def update_payoff( class FakeRolloutManager(RolloutWorkerManager): def __init__( self, - experiment_tag: str, stopping_conditions: Dict[str, Any], num_worker: int, - agent_mapping_func: Callable, - rollout_config: Dict[str, Any], + group_info: Dict[str, Any], + rollout_config: RolloutConfig | Dict[str, Any], env_desc: Dict[str, Any], log_dir: str, resource_config: Dict[str, Any] = None, + ray_actor_namespace: str = "rollout_worker", verbose: bool = True, ): - self.env_desc = env_desc + super().__init__( + stopping_conditions, + num_worker, + group_info, + rollout_config, + env_desc, + log_dir, + resource_config, + ray_actor_namespace, + verbose, + ) def rollout(self, task_list: List[Dict[str, Any]]) -> None: pass @@ -210,97 +158,3 @@ def wait(self) -> List[Any]: def terminate(self): pass - - -from typing import Union, Type -from collections import defaultdict -from malib.learner.manager import TrainingManager - - -class FakeTrainingManager(TrainingManager): - def __init__( - self, - experiment_tag: str, - stopping_conditions: Dict[str, Any], - algorithms: Dict[str, Any], - env_desc: Dict[str, Any], - agent_mapping_func: Callable[[AgentID], str], - training_config: Dict[str, Any], - log_dir: str, - remote_mode: bool = True, - resource_config: Dict[str, Any] = None, - verbose: bool = True, - ): - agent_groups = defaultdict(lambda: set()) - for agent in env_desc["possible_agents"]: - rid = agent_mapping_func(agent) - agent_groups[rid].add(agent) - - self.env_desc = env_desc - self._agent_groups = agent_groups - self._runtime_ids = tuple(self._agent_groups.keys()) - self._agent_mapping_func = agent_mapping_func - self.algorithm = algorithms - self._experiment_tag = experiment_tag - - def add_policies( - self, interface_ids: Sequence[str] = None, n: Union[int, Dict[str, int]] = 1 - ) -> Dict[str, Type[StrategySpec]]: - # return a strategy specs that contains n policy - return { - agent: StrategySpec( - identifier=agent, - policy_ids=[f"policy-{i}" for i in range(n)], - meta_data={ - "policy_cls": self.algorithm["default"][0], - "kwargs": None, - "experiment_tag": self._experiment_tag, - "prob_list": [1 / n] * n, - }, - ) - for agent in self.env_desc["possible_agents"] - } - - @property - def runtime_ids(self) -> Tuple[str]: - return self._runtime_ids - - def run(self, data_request_identifiers: Dict[str, str]): - pass - - def wait(self) -> List[Any]: - time.sleep(0.1) - - def cancel_pending_tasks(self): - pass - - def terminate(self) -> None: - pass - - -import contextlib -import ray - -from malib.runner import start_servers - - -@contextlib.contextmanager -def use_ray_env(namespace: str = None): - """Start a ray cluster and init parameter server and dataset server. - - Yields: - Tuple[Any, Any]: A tuple of parameter_server and dataset_server. - """ - - parameter_server, dataset_server = None, None - try: - if not ray.is_initialized(): - ray.init() - parameter_server, dataset_server = start_servers() - yield (parameter_server, dataset_server) - finally: - if parameter_server is not None: - ray.kill(parameter_server) - if dataset_server is not None: - ray.kill(dataset_server) - ray.shutdown() diff --git a/malib/rl/common/trainer.py b/malib/rl/common/trainer.py index 16897aa0..77b576a4 100644 --- a/malib/rl/common/trainer.py +++ b/malib/rl/common/trainer.py @@ -22,16 +22,12 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -from typing import Dict, Any, Sequence, Union, Type, List - -import torch +from typing import Dict, Any, Sequence, Type, List from abc import ABCMeta, abstractmethod -from functools import reduce from malib.utils.typing import AgentID from malib.rl.common.policy import Policy -from malib.utils.data import to_torch from malib.utils.tianshou_batch import Batch @@ -156,6 +152,3 @@ def reset(self, policy_instance=None, configs=None, learning_mode: str = None): if configs is not None: self.training_config.update(configs) - - -TrainerType = Type[Trainer] diff --git a/malib/rl/config.py b/malib/rl/config.py index 03e4c6d4..730be866 100644 --- a/malib/rl/config.py +++ b/malib/rl/config.py @@ -1,4 +1,4 @@ -from typing import Dict, Any +from typing import Dict, Any, Type from dataclasses import dataclass @@ -9,8 +9,8 @@ @dataclass class Algorithm: - policy: Policy + policy: Type[Policy] - trainer: Trainer + trainer: Type[Trainer] model_config: Dict[str, Any] diff --git a/malib/rollout/inference/env_runner.py b/malib/rollout/inference/env_runner.py index bc28882e..0167f77c 100644 --- a/malib/rollout/inference/env_runner.py +++ b/malib/rollout/inference/env_runner.py @@ -120,6 +120,7 @@ def merge_episodes(self): from malib.utils.timing import Timing +from malib.backend.dataset_server.utils import send_data class BasicEnvRunner(RemoteInterface): @@ -173,7 +174,7 @@ def run( rollout_config: RolloutConfig, strategy_specs: Dict[AgentID, StrategySpec], inference_clients: Dict[AgentID, InferenceClient] = None, - data_entrypoint_mapping: Dict[AgentID, str] = None, + data_entrypoints: Dict[str, str] = None, ): """Single thread env simulation stepping. @@ -181,7 +182,7 @@ def run( rollout_config (RolloutConfig): Rollout configuration, which specifies how many data pieces will rollout. strategy_specs (Dict[AgentID, StrategySpec]): A dict of strategy specs, which rules the behavior policy of each agent. inference_clients (Dict[AgentID, InferenceClient]): A dict of remote inference client. - data_entrypoint_mapping (Dict[AgentID, str], optional): A mapping which defines the data collection trigger, if not None, then return episodes. Defaults to None. + data_entrypoints (Dict[str, str], optional): A mapping which defines the data collection trigger, if not None, then return episodes. Defaults to None. Raises: e: _description_ @@ -263,5 +264,9 @@ def run( # merge agent episodes # FIXME(ming): send data to remote dataset data = agent_manager.merge_episodes() + data_entrypoints = data_entrypoints or {} + for entrypoint in data_entrypoints.values(): + send_data(data, entrypoint=entrypoint) + stats = {"total_timesteps": total_timestep, **timer.todict()} return stats diff --git a/malib/rollout/inference/manager.py b/malib/rollout/inference/manager.py index edac1189..0bbd3a82 100644 --- a/malib/rollout/inference/manager.py +++ b/malib/rollout/inference/manager.py @@ -62,7 +62,13 @@ def inference_clients(self) -> Dict[str, ray.ObjectRef]: return self._infer_clients @property - def inference_entry_points(self) -> str: + def inference_entry_points(self) -> Dict[str, str]: + """Return a mapping of inference client entrypoints. + + Returns: + Dict[str, str]: A dict mapping from runtime id to entrypoints. + """ + return self._inference_entry_points def submit(self, task: Any, wait: bool = False): diff --git a/malib/rollout/pb_rolloutworker.py b/malib/rollout/pb_rolloutworker.py index deb15a83..de7adee3 100644 --- a/malib/rollout/pb_rolloutworker.py +++ b/malib/rollout/pb_rolloutworker.py @@ -40,13 +40,13 @@ def step_rollout( self, eval_step: bool, strategy_specs: Dict[AgentID, StrategySpec], - data_entrypoint_mapping: Dict[AgentID, str], + data_entrypoints: Dict[str, str], ) -> List[Dict[str, Any]]: results = ray.get( self.env_runner.run.remote( rollout_config=self.rollout_config, strategy_specs=strategy_specs, - data_entrypoint_mapping=data_entrypoint_mapping, + data_entrypoints=data_entrypoints, ) ) # check evaluation info diff --git a/malib/rollout/rolloutworker.py b/malib/rollout/rolloutworker.py index aa35ac44..455b8d1c 100644 --- a/malib/rollout/rolloutworker.py +++ b/malib/rollout/rolloutworker.py @@ -91,28 +91,6 @@ def default_simulate_callback(results: Dict[str, Any]): pass -def validate_runtime_configs(configs: Dict[str, Any]): - """Validate runtime configuration. - - Args: - configs (Dict[str, Any]): Raw runtime configuration - - Raises: - AssertionError: Key not in configs - """ - - assert "fragment_length" in configs - assert "max_step" in configs - assert "num_eval_episodes" in configs - assert "num_threads" in configs - assert "num_env_per_thread" in configs - assert "num_eval_threads" in configs - assert "use_subproc_env" in configs - assert "batch_mode" in configs - assert "postprocessor_types" in configs - assert "eval_interval" in configs - - class RolloutWorker(RemoteInterface): def __init__( self, @@ -223,7 +201,7 @@ def rollout(self, task: RolloutTask): results = self.step_rollout( eval_step, task.strategy_specs, - task.data_entrypoint_mapping, + task.data_entrypoints, ) # total_timesteps += results["total_timesteps"] @@ -263,13 +241,13 @@ def step_rollout( self, eval_step: bool, strategy_specs: Dict[AgentID, StrategySpec], - data_entrypoint_mapping: Dict[AgentID, str], + data_entrypoints: Dict[str, str], ) -> List[Dict[str, Any]]: """The logic function to run rollout. Users must implment this method. Args: eval_step (bool): Indicate evaluation or not. - data_entrypoint_mapping: ... + data_entrypoints: ... Raises: NotImplementedError: _description_ diff --git a/tests/backend/test_dataset_server.py b/tests/backend/test_dataset_server.py deleted file mode 100644 index d51a4697..00000000 --- a/tests/backend/test_dataset_server.py +++ /dev/null @@ -1,144 +0,0 @@ -# MIT License - -# Copyright (c) 2021 MARL @ SJTU - -# Author: Ming Zhou - -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. - -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -from threading import Thread - -import time -import pytest -import numpy as np -import ray - -from ray.util.queue import Queue -from readerwriterlock import rwlock - -from malib.utils.episode import Episode -from malib.utils.tianshou_batch import Batch -from malib.utils.replay_buffer import ReplayBuffer -from malib.backend.offline_dataset_server import ( - OfflineDataset, - write_table, - read_table, -) - - -def is_actor_done(actor): - if actor is None: - return True - done_ref = actor.__ray_terminate__.remote() - done, not_done = ray.wait([done_ref], timeout=5) - return len(not_done) == 0 - - -def start_reader_thread(reader, n_round, read_size): - try: - for i in range(n_round): - if not is_actor_done(reader.actor): - batch_info = reader.get() - else: - break - except Exception as e: - print("done for actor has been terminated") - - -def start_writer_thread(writer, n_round, write_size): - obs_array = np.random.random((write_size, 3)) - action_array = np.random.random((write_size, 2)) - rew_array = np.random.random(write_size) - obs_next_array = np.random.random((write_size, 3)) - - batch = Batch( - { - Episode.CUR_OBS: obs_array, - Episode.ACTION: action_array, - Episode.NEXT_OBS: obs_next_array, - Episode.REWARD: rew_array, - } - ) - try: - for _ in range(n_round): - if is_actor_done(writer.actor): - writer.put_nowait_batch([batch]) - else: - break - except Exception as e: - print("done for actor has been terminated") - - -@pytest.mark.parametrize("read_size,write_size", [(64, 64), (64, 128), (128, 64)]) -def test_datatable_read_and_write(read_size: int, write_size: int): - if not ray.is_initialized(): - ray.init() - - buffer_size = 10000 - marker = rwlock.RWLockFair() - writer = Queue(actor_options={"num_cpus": 0.1, "name": str(time.time())}) - reader = Queue(actor_options={"num_cpus": 0.1, "name": str(time.time())}) - buffer = ReplayBuffer(buffer_size) - - write_thread = Thread(target=write_table, args=(marker, buffer, writer)) - read_thread = Thread(target=read_table, args=(marker, buffer, read_size, reader)) - - write_thread.start() - read_thread.start() - - n_round = 1000 - reader_thread = Thread( - target=start_reader_thread, args=(reader, n_round, read_size) - ) - writer_thread = Thread( - target=start_writer_thread, args=(writer, n_round, write_size) - ) - - reader_thread.start() - writer_thread.start() - - reader_thread.join() - writer_thread.join() - - reader.shutdown() - writer.shutdown() - - read_thread.join() - write_thread.join() - - ray.shutdown() - - -def test_offline_dataset(): - if not ray.is_initialized(): - ray.init() - - server = OfflineDataset(table_capacity=10000) - server.start() - - # functionality test - pname, pqueue = server.start_producer_pipe(name="test_offline_dataset") - cname, cqueue = server.start_consumer_pipe( - name="test_offline_dataset", batch_size=64 - ) - - server.end_consumer_pipe(name=cname) - server.end_producer_pipe(name=pname) - - ray.shutdown() diff --git a/tests/backend/test_dynamic_dataset.py b/tests/backend/test_dynamic_dataset.py new file mode 100644 index 00000000..304ecb0c --- /dev/null +++ b/tests/backend/test_dynamic_dataset.py @@ -0,0 +1,164 @@ +from typing import Any, Dict + +import time +import random +import multiprocessing +import pytest +import threading +import numpy as np + +from gym import spaces + +from malib.backend.dataset_server.data_loader import DynamicDataset +from malib.backend.dataset_server.utils import send_data, start_server +from malib.backend.dataset_server.feature import BaseFeature + + +class FakeFeatureHandler(BaseFeature): + def write(self, data: Dict[str, Any], start: int, end: int): + print("[FakeFeatureHandler] write data, size:", self._available_size) + return super().write(data, start, end) + + def get(self, index: int): + print("[FakeFeatureHandler] get data for index={}".format(index)) + return super().get(index) + + @classmethod + def gen_instance(cls): + return cls( + spaces={ + "a": spaces.Box(-1.0, 1.0, shape=(4,)), + "b": spaces.Discrete(2), + }, + block_size=1024, + ) + + +class TestDynamicDataset: + def test_grpc_service_write(self): + grpc_port = 8899 + _spaces = { + "a": spaces.Box(-1.0, 1.0, shape=(4,)), + "b": spaces.Discrete(2), + } + feature_handler = FakeFeatureHandler( + spaces=_spaces, + np_memory={k: np.zeros((1024,) + v.shape) for k, v in _spaces.items()}, + ) + + # start server proc + server_proc = multiprocessing.Process( + target=start_server, + args=( + 2, + 1024, + grpc_port, + feature_handler, + ), + ) + server_proc.start() + + # send data + for _ in range(10): + message = send_data( + feature_handler.generate_timestep(), host="localhost", port=grpc_port + ) + time.sleep(1) + print("returned message:", message) + + server_proc.terminate() + + def test_sync_grpc_service_get(self): + _spaces = { + "a": spaces.Box(-1.0, 1.0, shape=(4,)), + "b": spaces.Discrete(2), + } + dataset = DynamicDataset( + grpc_thread_num_workers=2, + max_message_length=1024, + feature_handler_cls=FakeFeatureHandler, + spaces=_spaces, + np_memory={ + k: np.zeros((1024,) + v.shape, dtype=v.dtype) + for k, v in _spaces.items() + }, + ) + + # send data + print("send 10 piece of data, entrypoint=", dataset.entrypoint) + for _ in range(10): + message = send_data( + dataset.feature_handler.generate_batch(batch_size=1), + entrypoint=dataset.entrypoint, + ) + time.sleep(1) + print("returned message:", message) + + # sample data + assert dataset.readable_block_size == 10, ( + dataset.readable_block_size, + dataset.feature_handler._available_size, + ) + for _ in range(10): + idx = random.randint(0, dataset.readable_block_size - 1) + data = dataset[idx] + assert isinstance(data, dict), type(data) + for k, v in data.items(): + # convert v to numpy + v = v.cpu().numpy() + assert dataset.feature_handler.spaces[k].contains(v), (k, v) + + dataset.close() + + def test_async_grpc_service_get(self): + _spaces = { + "a": spaces.Box(-1.0, 1.0, shape=(4,)), + "b": spaces.Discrete(2), + } + dataset = DynamicDataset( + grpc_thread_num_workers=2, + max_message_length=1024, + feature_handler_cls=FakeFeatureHandler, + spaces=_spaces, + np_memory={ + k: np.zeros((100,) + v.shape, dtype=v.dtype) for k, v in _spaces.items() + }, + ) + + def start_send(batch, entrypoint): + print("send 10 piece of data, entrypoint=", entrypoint) + for data in batch: + message = send_data( + data, + entrypoint=entrypoint, + ) + time.sleep(1) + print("returned message:", message) + + batch = [ + dataset.feature_handler.generate_batch(batch_size=1) for _ in range(10) + ] + entrypoint = dataset.entrypoint + send_proc = threading.Thread(target=start_send, args=(batch, entrypoint)) + + send_proc.start() + + def start_get(): + while dataset.readable_block_size == 0: + time.sleep(0.1) + + for _ in range(10): + idx = random.randint(0, dataset.readable_block_size) + data = dataset[idx] + assert isinstance(data, dict), type(data) + for k, v in data.items(): + # convert v to numpy + v = v.cpu().numpy() + assert dataset.feature_handler.spaces[k].contains(v), (k, v) + + get_proc = threading.Thread(target=start_get) + get_proc.start() + + send_proc.join() + get_proc.join() + dataset.close() diff --git a/tests/backend/test_parameter_server.py b/tests/backend/test_parameter_server.py deleted file mode 100644 index edc52140..00000000 --- a/tests/backend/test_parameter_server.py +++ /dev/null @@ -1,140 +0,0 @@ -# MIT License - -# Copyright (c) 2021 MARL @ SJTU - -# Author: Ming Zhou - -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. - -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -import pytest -import gym -import numpy as np -import torch - -from gym import spaces - -from malib import rl -from malib.backend.parameter_server import Table, ParameterServer -from malib.rl.common.policy import Policy -from malib.common.strategy_spec import StrategySpec - - -@pytest.mark.parametrize("optim_config", [None, {"type": "Adam", "lr": 1e-4}]) -@pytest.mark.parametrize( - "policy_cls,rl_default_config", - [ - [rl.a2c.A2CPolicy, rl.a2c.DEFAULT_CONFIG], - [rl.dqn.DQNPolicy, rl.dqn.DEFAULT_CONFIG], - [rl.pg.PGPolicy, rl.pg.DEFAULT_CONFIG], - ], -) -def test_parameter_table(optim_config, policy_cls, rl_default_config): - observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(4, 3)) - action_space = spaces.Discrete(5) - - policy_kwargs = { - "observation_space": observation_space, - "action_space": action_space, - "model_config": rl_default_config["model_config"], - "custom_config": rl_default_config["custom_config"], - "kwargs": {}, - } - policy_copy: Policy = policy_cls( - observation_space=observation_space, - action_space=action_space, - model_config=rl_default_config["model_config"], - custom_config=rl_default_config["custom_config"], - ) - table = Table( - policy_meta_data={ - "policy_cls": policy_cls, - "optim_config": optim_config, - "kwargs": policy_kwargs, - } - ) - - # set weights from policy - table.set_weights(policy_copy.state_dict()) - - # check weights - table_weights = table.get_weights() - for k, v in policy_copy.state_dict().items(): - if isinstance(v, dict): - for _k, _v in v.items(): - assert torch.all(_v == table_weights[k][_k]), (k, _k) - - # TODO(ming): test gradient apply here, if the method has been implemented - - -@pytest.mark.parametrize("optim_config", [None, {"type": "Adam", "lr": 1e-4}]) -@pytest.mark.parametrize( - "policy_cls,rl_default_config", - [ - [rl.a2c.A2CPolicy, rl.a2c.DEFAULT_CONFIG], - [rl.dqn.DQNPolicy, rl.dqn.DEFAULT_CONFIG], - [rl.pg.PGPolicy, rl.pg.DEFAULT_CONFIG], - ], -) -def test_parameter_server(optim_config, policy_cls, rl_default_config): - server = ParameterServer() - - server.start() - - observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(4, 3)) - action_space = spaces.Discrete(5) - - policy_kwargs = { - "observation_space": observation_space, - "action_space": action_space, - "model_config": rl_default_config["model_config"], - "custom_config": rl_default_config["custom_config"], - "kwargs": {}, - } - - policy_copy: Policy = policy_cls( - observation_space=observation_space, - action_space=action_space, - model_config=rl_default_config["model_config"], - custom_config=rl_default_config["custom_config"], - ) - - # create a parameter table - strategy_spec = StrategySpec( - identifier="test_parameter_server", - policy_ids=[f"policy-{i}" for i in range(10)], - meta_data={ - "policy_cls": policy_cls, - "kwargs": policy_kwargs, - "experiment_tag": "test_parameter_server", - }, - ) - server.create_table(strategy_spec=strategy_spec) - - # try to create repeated table: should jump over - server.create_table(strategy_spec=strategy_spec) - - # set weights - server.set_weights( - spec_id=strategy_spec.id, - spec_policy_id="policy-1", - state_dict=policy_copy.state_dict(), - ) - - # retrive weights - server.get_weights(spec_id=strategy_spec.id, spec_policy_id="policy-1") diff --git a/tests/rollout/test_pb_rollout_worker.py b/tests/rollout/test_pb_rollout_worker.py index a8652705..e880a3ea 100644 --- a/tests/rollout/test_pb_rollout_worker.py +++ b/tests/rollout/test_pb_rollout_worker.py @@ -22,13 +22,11 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -from typing import Dict, Any +from typing import Dict, Any, List, Tuple import pytest import ray - -from pytest_mock import MockerFixture -from gym import spaces +from malib.backend.dataset_server.data_loader import DynamicDataset from malib.common.task import RolloutTask from malib.common.strategy_spec import StrategySpec @@ -39,6 +37,8 @@ from malib.rollout.pb_rolloutworker import PBRolloutWorker from malib.rollout.inference.manager import InferenceManager from malib.scenarios.scenario import form_group_info +from malib.utils.tianshou_batch import Batch +from malib.utils.typing import AgentID def gen_rollout_config(inference_server_type: str): @@ -57,32 +57,56 @@ def gen_rollout_config(inference_server_type: str): } +def gen_common_requirements(n_player: int): + env_desc = env_desc_gen(num_agents=n_player) + + algorithm = Algorithm( + policy=RandomPolicy, + trainer=None, + model_config=None, + ) + + rollout_config = RolloutConfig( + num_workers=1, + eval_interval=1, + n_envs_per_worker=10, + use_subproc_env=False, + timelimit=256, + ) + + group_info = form_group_info(env_desc, lambda agent: "default") + + return env_desc, algorithm, rollout_config, group_info + + +from malib.learner.learner import Learner +from gym import spaces +from malib.learner.learner import Learner +from malib.learner.manager import LearnerManager +from malib.learner.config import TrainingConfig + + +class FakeLearner(Learner): + def multiagent_post_process( + self, + batch_info, + ) -> Dict[str, Any]: + pass + + @pytest.mark.parametrize("n_player", [1, 2]) class TestRolloutWorker: def test_rollout(self, n_player: int): with ray.init(local_mode=True): - env_desc = env_desc_gen(num_agents=n_player) + env_desc, algorithm, rollout_config, group_info = gen_common_requirements( + n_player + ) + obs_spaces = env_desc["observation_spaces"] act_spaces = env_desc["action_spaces"] agents = env_desc["possible_agents"] log_dir = "./logs" - algorithm = Algorithm( - policy=RandomPolicy, - trainer=None, - model_config=None, - ) - - rollout_config = RolloutConfig( - num_workers=1, - eval_interval=1, - n_envs_per_worker=10, - use_subproc_env=False, - timelimit=256, - ) - - group_info = form_group_info(env_desc, lambda agent: "default") - inference_namespace = "test_pb_rolloutworker" infer_manager = InferenceManager( @@ -120,6 +144,63 @@ def test_rollout(self, n_player: int): ) stats = worker.rollout(task) - # def test_rollout_with_data_entrypoint(self, mocker: MockerFixture, n_player: int): - # with ray.init(local_mode=True): - # pass + def test_rollout_with_data_entrypoint(self, n_player: int): + with ray.init(local_mode=True): + env_desc, algorithm, rollout_config, group_info = gen_common_requirements( + n_player + ) + + obs_spaces = env_desc["observation_spaces"] + act_spaces = env_desc["action_spaces"] + agents = env_desc["possible_agents"] + log_dir = "./logs" + + inference_namespace = "test_pb_rolloutworker" + + learner_manager = LearnerManager( + stopping_conditions={"max_iteration": 10}, + algorithm=algorithm, + env_desc=env_desc, + agent_mapping_func=lambda agent: "default", + group_info=group_info, + training_config=TrainingConfig( + trainer_config={}, learner_type=FakeLearner, custom_config=None + ), + log_dir=log_dir, + ) + + infer_manager = InferenceManager( + group_info=group_info, + ray_actor_namespace=inference_namespace, + algorithm=algorithm, + model_entry_point=learner_manager.learner_entrypoints, + ) + + rollout_config.inference_entry_points = infer_manager.inference_entry_points + + strategy_spaces = { + agent: StrategySpec( + policy_cls=algorithm.policy, + observation_space=obs_spaces[agent], + action_space=act_spaces[agent], + identifier=agent, + model_config=algorithm.model_config, + policy_ids=["policy-0"], + ) + for agent in agents + } + + worker = PBRolloutWorker( + env_desc=env_desc, + agent_groups=group_info["agent_groups"], + rollout_config=rollout_config, + log_dir=log_dir, + ) + + task = RolloutTask( + strategy_specs=strategy_spaces, + stopping_conditions={"max_iteration": 10}, + data_entrypoint_mapping=learner_manager.data_entrypoints, + ) + + stats = worker.rollout(task) From 7a77d431a93a31799906b7ef7f6e150f643dd091 Mon Sep 17 00:00:00 2001 From: Ming Zhou Date: Fri, 17 Nov 2023 18:10:10 +0800 Subject: [PATCH 18/24] tmp save: test rolloutm nanager --- examples/sarl/ppo_gym.py | 74 +++---- malib/backend/dataset_server/data_loader.py | 13 +- malib/backend/offline_dataset_server.py | 193 ------------------ malib/backend/parameter_server.py | 158 -------------- malib/learner/config.py | 15 +- malib/learner/learner.py | 54 +++-- malib/learner/manager.py | 27 +-- malib/mocker/mocker_utils.py | 56 +---- malib/registration.py | 90 -------- malib/rl/common/policy.py | 14 +- malib/rl/config.py | 2 + .../rollout/{rollout_config.py => config.py} | 0 malib/rollout/inference/env_runner.py | 2 +- malib/rollout/manager.py | 8 +- malib/rollout/rolloutworker.py | 2 +- malib/scenarios/sarl_scenario.py | 40 ++-- malib/scenarios/scenario.py | 16 +- malib/settings.py | 48 +---- malib/utils/general.py | 2 - tests/backend/test_dynamic_dataset.py | 2 + tests/rollout/test_env_runner.py | 2 +- tests/rollout/test_mdp_env.py | 0 tests/rollout/test_open_spiel.py | 0 tests/rollout/test_pb_rollout_worker.py | 44 +++- tests/rollout/test_rollout_manager.py | 172 ++++++---------- 25 files changed, 242 insertions(+), 792 deletions(-) delete mode 100644 malib/backend/offline_dataset_server.py delete mode 100644 malib/backend/parameter_server.py delete mode 100644 malib/registration.py rename malib/rollout/{rollout_config.py => config.py} (100%) delete mode 100644 tests/rollout/test_mdp_env.py delete mode 100644 tests/rollout/test_open_spiel.py diff --git a/examples/sarl/ppo_gym.py b/examples/sarl/ppo_gym.py index 0a734e9f..ee37ec76 100644 --- a/examples/sarl/ppo_gym.py +++ b/examples/sarl/ppo_gym.py @@ -4,10 +4,11 @@ from argparse import ArgumentParser from malib.learner import IndependentAgent -from malib.scenarios.marl_scenario import MARLScenario - -from malib.runner import run +from malib.scenarios import sarl_scenario +from malib.rl.config import Algorithm from malib.rl.ppo import PPOPolicy, PPOTrainer, DEFAULT_CONFIG +from malib.learner.config import LearnerConfig +from malib.rollout.config import RolloutConfig from malib.rollout.envs.gym import env_desc_gen @@ -23,59 +24,38 @@ trainer_config["total_timesteps"] = int(1e6) trainer_config["use_cuda"] = args.use_cuda - training_config = { - "learner_type": IndependentAgent, - "trainer_config": trainer_config, - "custom_config": {}, - } - - rollout_config = { - "fragment_length": 2000, # determine the size of sended data block - "max_step": 200, - "num_eval_episodes": 10, - "num_threads": 2, - "num_env_per_thread": 10, - "num_eval_threads": 1, - "use_subproc_env": False, - "batch_mode": "time_step", - "postprocessor_types": ["defaults"], - # every # rollout epoch run evaluation. - "eval_interval": 1, - "inference_server": "ray", # three kinds of inference server: `local`, `pipe` and `ray` - } - - # one to one, no sharing, if sharing, implemented as: - # agent_mapping_func = lambda agent: "default" - agent_mapping_func = lambda agent: agent - - algorithms = { - "default": ( - PPOPolicy, - PPOTrainer, - # model configuration, None as default - {}, - {"use_cuda": args.use_cuda}, - ) - } - - env_description = env_desc_gen(env_id=args.env_id, scenario_configs={}) - runtime_logdir = os.path.join(args.log_dir, f"sa_ppo_gym/{time.time()}") + runtime_logdir = os.path.join( + args.log_dir, f"gym/{args.env_id}/independent_ppo/{time.time()}" + ) if not os.path.exists(runtime_logdir): os.makedirs(runtime_logdir) - scenario = MARLScenario( + scenario = sarl_scenario.SARLScenario( name=f"ppo-gym-{args.env_id}", log_dir=runtime_logdir, - algorithms=algorithms, - env_description=env_description, - training_config=training_config, - rollout_config=rollout_config, - agent_mapping_func=agent_mapping_func, + env_desc=env_desc_gen(env_id=args.env_id), + algorithm=Algorithm( + trainer=PPOTrainer, + policy=PPOPolicy, + model_config=None, # use default + trainer_config=trainer_config, + ), + learner_config=LearnerConfig( + learner_type=IndependentAgent, + feature_handler_meta_gen=None, + custom_config={}, + ), + rollout_config=RolloutConfig( + num_workers=1, + ), + agent_mapping_func=lambda agent: agent, stopping_conditions={ "training": {"max_iteration": int(1e10)}, "rollout": {"max_iteration": 1000, "minimum_reward_improvement": 1.0}, }, ) - run(scenario) + results = sarl_scenario.execution_plan( + experiment_tag=scenario.name, scenario=scenario, verbose=True + ) diff --git a/malib/backend/dataset_server/data_loader.py b/malib/backend/dataset_server/data_loader.py index a62a217e..92000c41 100644 --- a/malib/backend/dataset_server/data_loader.py +++ b/malib/backend/dataset_server/data_loader.py @@ -21,19 +21,24 @@ def __init__( self, grpc_thread_num_workers: int, max_message_length: int, - feature_handler_cls: Type[BaseFeature], + feature_handler: BaseFeature = None, + feature_handler_cls: Type[BaseFeature] = None, **feature_handler_kwargs, ) -> None: super().__init__() # start a service as thread - self.feature_handler: BaseFeature = feature_handler_cls( + self.feature_handler: BaseFeature = feature_handler or feature_handler_cls( **feature_handler_kwargs ) + self.grpc_thread_num_workers = grpc_thread_num_workers + self.max_message_length = max_message_length + + def start_server(self): self.server_port = find_free_port() self.server = service_wrapper( - grpc_thread_num_workers, - max_message_length, + self.grpc_thread_num_workers, + self.max_message_length, self.server_port, )(self.feature_handler) self.server.start() diff --git a/malib/backend/offline_dataset_server.py b/malib/backend/offline_dataset_server.py deleted file mode 100644 index b99c9c88..00000000 --- a/malib/backend/offline_dataset_server.py +++ /dev/null @@ -1,193 +0,0 @@ -# MIT License - -# Copyright (c) 2021 MARL @ SJTU - -# Author: Ming Zhou - -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. - -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -from typing import Dict, Any, Tuple, Union, List -from concurrent.futures import ThreadPoolExecutor -from readerwriterlock import rwlock - -import traceback -import time - -import numpy as np -import ray - -from ray.util.queue import Queue - -from malib.remote.interface import RemoteInterface -from malib.utils.logging import Logger -from malib.utils.tianshou_batch import Batch -from malib.utils.replay_buffer import ReplayBuffer, MultiagentReplayBuffer - - -def write_table( - marker: rwlock.RWLockFair, - buffer: Union[MultiagentReplayBuffer, ReplayBuffer], - writer: Queue, -): - wlock = marker.gen_wlock() - while True: - try: - batches: Union[Batch, List[Batch]] = writer.get() - with wlock: - if not isinstance(batches, List): - batches = [batches] - for e in batches: - buffer.add_batch(e) - except Exception as e: - print(traceback.format_exc()) - break - - -def read_table( - marker: rwlock.RWLockFair, - buffer: Union[MultiagentReplayBuffer, ReplayBuffer], - batch_size: int, - reader: Queue, -): - rlock = marker.gen_rlock() - while True: - try: - with rlock: - if len(buffer) >= batch_size: - ret = buffer.sample(batch_size) - # batch, indices = buffer.sample(batch_size) - else: - # batch, indices = [], np.array([], int) - if isinstance(buffer, MultiagentReplayBuffer): - ret = {} - else: - ret = ([], np.array([], int)) - reader.put_nowait(ret) - except Exception as e: - print(traceback.format_exc()) - break - - -class OfflineDataset(RemoteInterface): - def __init__(self, table_capacity: int, max_consumer_size: int = 1024) -> None: - """Construct an offline datataset. It maintans a dict of datatable, each for a training instance. - - Args: - table_capacity (int): Table capacity, it indicates the buffer size of each data table. - max_consumer_size (int, optional): Defines the maximum of concurrency. Defaults to 1024. - """ - - self.tb_capacity = table_capacity - self.reader_queues: Dict[str, Queue] = {} - self.writer_queues: Dict[str, Queue] = {} - self.buffers: Dict[str, ReplayBuffer] = {} - self.markers: Dict[str, rwlock.RWLockFair] = {} - self.thread_pool = ThreadPoolExecutor(max_workers=max_consumer_size) - - def start(self): - Logger.info("Dataset server started") - - def start_producer_pipe( - self, - name: str, - stack_num: int = 1, - ignore_obs_next: bool = False, - save_only_last_obs: bool = False, - sample_avail: bool = False, - **kwargs, - ) -> Tuple[str, Queue]: - """Start a producer pipeline and create a datatable if not exisits. - - Args: - name (str): The name of datatable need to access - stack_num (int, optional): Indicates how many steps are stacked in a single data sample. Defaults to 1. - ignore_obs_next (bool, optional): Ignore the next observation or not. Defaults to False. - save_only_last_obs (bool, optional): Either save only the last observation frame. Defaults to False. - sample_avail (bool, optional): Sample action maks or not. Defaults to False. - - Returns: - Tuple[str, Queue]: A tuple of table name and queue for insert samples. - """ - - if name not in self.buffers: - buffer = ReplayBuffer( - size=self.tb_capacity, - stack_num=stack_num, - ignore_obs_next=ignore_obs_next, - save_only_last_obs=save_only_last_obs, - sample_avail=sample_avail, - **kwargs, - ) - marker = rwlock.RWLockFair() - - self.buffers[name] = buffer - self.markers[name] = marker - - if name not in self.writer_queues: - writer = Queue(actor_options={"num_cpus": 0}) - self.writer_queues[name] = writer - self.thread_pool.submit( - write_table, self.markers[name], self.buffers[name], writer - ) - - return name, self.writer_queues[name] - - def end_producer_pipe(self, name: str): - """Kill a producer pipe with given name. - - Args: - name (str): The name of related data table. - """ - - if name in self.writer_queues: - queue = self.writer_queues.pop(name) - queue.shutdown() - - def start_consumer_pipe(self, name: str, batch_size: int) -> Tuple[str, Queue]: - """Start a consumer pipeline, if there is no such a table that named as `name`, the function will be stucked until the table has been created. - - Args: - name (str): Name of datatable. - batch_size (int): Batch size. - - Returns: - Tuple[str, Queue]: A tuple of table name and queue for retrieving samples. - """ - - queue_id = f"{name}_{time.time()}" - queue = Queue(actor_options={"num_cpus": 0}) - self.reader_queues[queue_id] = queue - # make sure that the buffer is ready - while name not in self.buffers: - time.sleep(1) - self.thread_pool.submit( - read_table, self.markers[name], self.buffers[name], batch_size, queue - ) - return queue_id, queue - - def end_consumer_pipe(self, name: str): - """Kill a consumer pipeline with given table name. - - Args: - name (str): Name of related datatable. - """ - - if name in self.reader_queues: - queue = self.reader_queues.pop(name) - queue.shutdown() diff --git a/malib/backend/parameter_server.py b/malib/backend/parameter_server.py deleted file mode 100644 index 3f93e915..00000000 --- a/malib/backend/parameter_server.py +++ /dev/null @@ -1,158 +0,0 @@ -# MIT License - -# Copyright (c) 2021 MARL @ SJTU - -# Author: Ming Zhou - -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. - -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -from argparse import Namespace -from typing import Dict, Any, Sequence -from threading import Lock - -import itertools -import torch - -from malib.rl.common.policy import Policy -from malib.common.strategy_spec import StrategySpec -from malib.remote.interface import RemoteInterface -from malib.utils.logging import Logger - - -class Table: - def __init__(self, policy_meta_data: Dict[str, Any]): - policy_cls = policy_meta_data["policy_cls"] - optim_config = policy_meta_data.get("optim_config") - policy_init_kwargs = Namespace(**policy_meta_data["kwargs"]) - self.state_dict = None - if optim_config is not None: - self.policy: Policy = policy_cls( - observation_space=policy_init_kwargs.observation_space, - action_space=policy_init_kwargs.action_space, - model_config=policy_init_kwargs.model_config, - custom_config=policy_init_kwargs.custom_config, - **policy_init_kwargs.kwargs, - ) - parameters = [list(v) for v in self.policy.parameters().values()] - parameters = itertools.chain(*parameters) - self.optimizer: torch.optim.Optimizer = getattr( - torch.optim, optim_config["type"] - )(parameters, lr=optim_config["lr"]) - else: - self.optimizer: torch.optim.Optimizer = None - self.lock = Lock() - - def set_weights(self, state_dict: Dict[str, Any]): - """Update weights with given weights. - - Args: - state_dict (Dict[str, Any]): A dict of weights - """ - - with self.lock: - self.state_dict = state_dict - - def apply_gradients(self, *gradients): - raise NotImplementedError - - def get_weights(self) -> Dict[str, Any]: - """Retrive model weights. - - Returns: - Dict[str, Any]: Weights dict - """ - - with self.lock: - return self.state_dict - - -class ParameterServer(RemoteInterface): - def __init__(self, **kwargs): - self.tables: Dict[str, Table] = {} - self.lock = Lock() - - def start(self): - """For debug""" - Logger.info("Parameter server started") - - def apply_gradients(self, table_name: str, gradients: Sequence[Any]): - """Apply gradients to a data table. - - Args: - table_name (str): The specified table name. - gradients (Sequence[Any]): Given gradients to update parameters. - - Raises: - NotImplementedError: Not implemented yet. - """ - - raise NotImplementedError - - def get_weights(self, spec_id: str, spec_policy_id: str) -> Dict[str, Any]: - """Request for weight retrive, return a dict includes keys: `spec_id`, `spec_policy_id` and `weights`. - - Args: - spec_id (str): Strategy spec id. - spec_policy_id (str): Related policy id. - - Returns: - Dict[str, Any]: A dict. - """ - - table_name = f"{spec_id}/{spec_policy_id}" - weights = self.tables[table_name].get_weights() - return { - "spec_id": spec_id, - "spec_policy_id": spec_policy_id, - "weights": weights, - } - - def set_weights( - self, spec_id: str, spec_policy_id: str, state_dict: Dict[str, Any] - ): - """Set weights to a parameter table. The table name will be defined as `{spec_id}/{spec_policy_id}` - - Args: - spec_id (str): StrategySpec id. - spec_policy_id (str): Policy id in the specified strategy spec. - state_dict (Dict[str, Any]): A dict that specify the parameters. - """ - - table_name = f"{spec_id}/{spec_policy_id}" - self.tables[table_name].set_weights(state_dict) - - def create_table(self, strategy_spec: StrategySpec) -> str: - """Create parameter table with given strategy spec. This function will traverse existing policy \ - id in this spec, then generate table for policy ids which have no cooresponding tables. - - Args: - strategy_spec (StrategySpec): A startegy spec instance. - - Returns: - str: Table name formatted as `{startegy_spec_id}/{policy_id}`. - """ - - with self.lock: - for policy_id in strategy_spec.policy_ids: - table_name = f"{strategy_spec.id}/{policy_id}" - if table_name in self.tables: - continue - meta_data = strategy_spec.get_meta_data().copy() - self.tables[table_name] = Table(meta_data) - return table_name diff --git a/malib/learner/config.py b/malib/learner/config.py index ff36d106..afa31367 100644 --- a/malib/learner/config.py +++ b/malib/learner/config.py @@ -1,22 +1,23 @@ -from typing import Dict, Any, Union, Type +from typing import Dict, Any, Union, Type, Callable from dataclasses import dataclass, field from malib.learner.learner import Learner +from malib.backend.dataset_server.feature import BaseFeature # TODO(ming): rename it as LearnerConfig @dataclass -class TrainingConfig: - trainer_config: Dict[str, Any] +class LearnerConfig: learner_type: Type[Learner] + feature_handler_meta_gen: Callable[["EnvDesc", str], Callable[[str], BaseFeature]] custom_config: Dict[str, Any] = field(default_factory=dict()) @classmethod def from_raw( - cls, config: Union["TrainingConfig", Dict[str, Any]] - ) -> "TrainingConfig": - """Cat dict-style configuration to TrainingConfig instance + cls, config: Union["LearnerConfig", Dict[str, Any]] + ) -> "LearnerConfig": + """Cat dict-style configuration to LearnerConfig instance Args: config (Dict[str, Any]): A dict @@ -25,7 +26,7 @@ def from_raw( RuntimeError: Unexpected config type Returns: - TrainingConfig: A training config instance + LearnerConfig: A training config instance """ if isinstance(config, Dict): diff --git a/malib/learner/learner.py b/malib/learner/learner.py index 76d5b103..4beab200 100644 --- a/malib/learner/learner.py +++ b/malib/learner/learner.py @@ -23,10 +23,10 @@ # SOFTWARE. -from typing import Dict, Any, Tuple, Callable, Type, List, Union +from typing import Dict, Any, Tuple, Callable, List, Union, Type from abc import ABC, abstractmethod -from collections import deque +import time import traceback import torch @@ -44,6 +44,7 @@ from malib.common.task import OptimizationTask from malib.common.strategy_spec import StrategySpec from malib.backend.dataset_server.data_loader import DynamicDataset +from malib.backend.dataset_server.feature import BaseFeature from malib.rl.common.trainer import Trainer from malib.rl.common.policy import Policy from malib.rl.config import Algorithm @@ -61,11 +62,10 @@ def __init__( algorithm: Algorithm, agent_mapping_func: Callable[[AgentID], str], governed_agents: Tuple[AgentID], - trainer_config: Dict[str, Any], custom_config: Dict[str, Any] = None, - local_buffer_config: Dict = None, - verbose: bool = True, dataset: DynamicDataset = None, + feature_handler_gen: Callable[[str], BaseFeature] = None, + verbose: bool = True, ): """Construct agent interface for training. @@ -80,14 +80,14 @@ def __init__( Note that it should be a subset of the original set of environment agents. trainer_config (Dict[str, Any]): Trainer configuration. custom_config (Dict[str, Any], optional): A dict of custom configuration. Defaults to None. - local_buffer_config (Dict, optional): A dict for local buffer configuration. Defaults to None. + dataset (DynamicDataset, optional): A dataset instance. Defaults to None. + feature_handler_gen (Callable[[str], BaseFeature], optional): A function that generates feature handler. Defaults to None. verbose (bool, True): Enable logging or not. Defaults to True. """ if verbose: Logger.info("\tAssigned GPUs: {}".format(ray.get_gpu_ids())) - local_buffer_config = local_buffer_config or {} device = torch.device("cuda" if ray.get_gpu_ids() else "cpu") # initialize a strategy spec for policy maintainance. @@ -110,27 +110,31 @@ def __init__( self._summary_writer = tensorboard.SummaryWriter(log_dir=log_dir) # load policy for trainer - self._trainer: Trainer = algorithm.trainer(trainer_config, self._policy) + self._trainer: Trainer = algorithm.trainer( + algorithm.trainer_config, self._policy + ) - dataset = dataset or self.create_dataset() - self._data_loader = DataLoader(dataset, batch_size=trainer_config["batch_size"]) + if dataset is None: + dataset = DynamicDataset( + grpc_thread_num_workers=2, + max_message_length=1024, + feature_handler=feature_handler_gen(device), + ) + else: + if feature_handler_gen is not None: + # XXX(ming): should we replace feature handler ? + dataset.feature_handler = feature_handler_gen(device) + + dataset.start_server() + + self._data_loader = DataLoader( + dataset, batch_size=algorithm.trainer_config["batch_size"] + ) self._total_step = 0 self._total_epoch = 0 self._verbose = verbose - def create_dataset(self) -> DynamicDataset: - """Create dataset - - Returns: - DynamicDataset: Must be an subinstance of DynamicDataset - """ - return DynamicDataset( - grpc_thread_num_workers=1, - max_message_length=1024, - feature_handler_caller=None, - ) - @abstractmethod def multiagent_post_process( self, @@ -223,6 +227,12 @@ def train(self, task: OptimizationTask) -> Dict[str, Any]: self.set_running(True) try: + while ( + self.data_loader.dataset.readable_block_size + < self.data_loader.batch_size + ): + time.sleep(1) + while self.is_running(): for data in self.data_loader: batch_info = self.multiagent_post_process(data) diff --git a/malib/learner/manager.py b/malib/learner/manager.py index 39b5f140..86e784f0 100644 --- a/malib/learner/manager.py +++ b/malib/learner/manager.py @@ -34,11 +34,9 @@ Type, Generator, ) -from collections import defaultdict -from concurrent.futures import ThreadPoolExecutor, Future, CancelledError +from concurrent.futures import ThreadPoolExecutor import os -import traceback import ray from malib.common.task import OptimizationTask @@ -46,10 +44,9 @@ from malib.utils.logging import Logger from malib.utils.exploitability import measure_exploitability from malib.remote.interface import RemoteInterface -from malib.learner.learner import Learner from malib.common.strategy_spec import StrategySpec from malib.common.manager import Manager -from malib.learner.config import TrainingConfig +from malib.learner.config import LearnerConfig from malib.rl.config import Algorithm @@ -66,7 +63,7 @@ def __init__( env_desc: Dict[str, Any], agent_mapping_func: Callable[[AgentID], str], group_info: Dict[str, Any], - training_config: Union[Dict[str, Any], TrainingConfig], + learner_config: LearnerConfig, log_dir: str, resource_config: Dict[str, Any] = None, ray_actor_namespace: str = "learner", @@ -90,19 +87,19 @@ def __init__( super().__init__(verbose=verbose, namespace=ray_actor_namespace) resource_config = resource_config or DEFAULT_RESOURCE_CONFIG - training_config = TrainingConfig.from_raw(training_config) + learner_config = LearnerConfig.from_raw(learner_config) # interface config give the agent type used here and the group mapping if needed # FIXME(ming): resource configuration is not available now, will turn-on in the next version - if training_config.trainer_config.get("use_cuda", False): + if algorithm.trainer_config.get("use_cuda", False): num_gpus = 1 / len(group_info["agent_groups"]) else: num_gpus = 0.0 if not os.path.exists(log_dir): os.makedirs(log_dir) - learner_cls = training_config.learner_type + learner_cls = learner_config.learner_type # update num gpus resource_config["num_gpus"] = num_gpus learner_cls = learner_cls.as_remote(**resource_config) @@ -115,6 +112,7 @@ def __init__( ready_check = [] for rid, agents in group_info["agent_groups"].items(): + agents = tuple(agents) learners[rid] = learner_cls.options( name=f"learner_{rid}", max_concurrency=10, namespace=self.namespace ).remote( @@ -124,9 +122,12 @@ def __init__( action_space=group_info["action_space"][rid], algorithm=algorithm, agent_mapping_func=agent_mapping_func, - governed_agents=tuple(agents), - trainer_config=training_config.trainer_config, - custom_config=training_config.custom_config, + governed_agents=agents, + trainer_config=algorithm.trainer_config, + custom_config=learner_config.custom_config, + feature_handler_gen=learner_config.feature_handler_meta_gen( + env_desc, agents[0] + ), verbose=verbose, ) ready_check.append(learners[rid].ready.remote()) @@ -150,7 +151,7 @@ def __init__( self._group_info = group_info self._runtime_ids = tuple(group_info["agent_groups"].keys()) self._env_description = env_desc - self._training_config = training_config + self._learner_config = learner_config self._log_dir = log_dir self._agent_mapping_func = agent_mapping_func self._learners = learners diff --git a/malib/mocker/mocker_utils.py b/malib/mocker/mocker_utils.py index 1f85a5d2..e4685e30 100644 --- a/malib/mocker/mocker_utils.py +++ b/malib/mocker/mocker_utils.py @@ -22,69 +22,17 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -from typing import Sequence, Dict, Any, Callable, List, Union +from typing import Sequence, Dict, Any, Callable, List, Tuple import time -import ray -from ray.util import ActorPool -from malib.rollout.rollout_config import RolloutConfig +from malib.rollout.config import RolloutConfig from malib.utils.typing import AgentID from malib.common.strategy_spec import StrategySpec -from malib.rollout.rolloutworker import RolloutWorker -class FakeRolloutWorker(RolloutWorker): - def init_agent_interfaces( - self, env_desc: Dict[str, Any], runtime_ids: Sequence[AgentID] - ) -> Dict[AgentID, Any]: - return {} - - def init_actor_pool( - self, - env_desc: Dict[str, Any], - rollout_config: Dict[str, Any], - agent_mapping_func: Callable, - ) -> ActorPool: - return NotImplementedError - - def init_servers(self): - pass - - def rollout( - self, - runtime_strategy_specs: Dict[str, StrategySpec], - stopping_conditions: Dict[str, Any], - data_entrypoints: Dict[str, str], - trainable_agents: List[AgentID] = None, - ): - self.set_running(True) - return {} - - def simulate(self, runtime_strategy_specs: Dict[str, StrategySpec]): - time.sleep(0.5) - return {} - - def step_rollout( - self, - eval_step: bool, - rollout_config: Dict[str, Any], - dataset_writer_info_dict: Dict[str, Any], - ) -> List[Dict[str, Any]]: - pass - - def step_simulation( - self, - runtime_strategy_specs_list: Dict[str, StrategySpec], - rollout_config: Dict[str, Any], - ) -> Dict[str, Any]: - pass - - -from typing import Tuple - from malib.utils.typing import PolicyID from malib.common.payoff_manager import PayoffManager diff --git a/malib/registration.py b/malib/registration.py deleted file mode 100644 index ec7b2bb1..00000000 --- a/malib/registration.py +++ /dev/null @@ -1,90 +0,0 @@ -# MIT License - -# Copyright (c) 2021 MARL @ SJTU - -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. - -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -from typing import Dict, Callable, Union - - -class Registry: - """Global registry of algorithms, models, preprocessors and environments - - Examples: - >>> # register custom model - >>> Registry.register_custom_model("MyCustomModel", model_class) - >>> # register custom policy - >>> Registry.register_custom_policy("MyCustomPolicy", policy_class) - >>> # register custom environment - >>> Registry.register_custom_env("MyCustomEnvironment", environment_class) - >>> # register custom algorithm - >>> Registry.register_custom_algorithm( - ... name="MyCustomAlgo", - ... policy="registered_policy_name_or_cls", - ... trainer="registered_trainer_name_or_cls", - ... loss="registered_loss_name_or_cls") - >>> - """ - - @staticmethod - def register_custom_algorithm( - name: str, - policy: Union[type, str], - trainer: Union[type, str], - loss: Union[type, str] = None, - ) -> None: - """Register a custom algorithm by name. - - :param name: str, Name to register the algorithm under. - :param policy: Union[type, str], Python class or registered name of policy. - :param trainer: Union[type, str], Python class or registered name of trainer. - :param loss: Union[type, str], Python class or registered name of loss function. - :return: - """ - # _global_registry.register(ALGORITHM, name, policy, trainer, loss) - pass - - @staticmethod - def register_custom_model(name: str, model_class: type) -> None: - """Register a custom model by name. - - :param name: str, Name to register the model under. - :param model_class: type, Python class of the model. - :return: - """ - # _global_registry.register(MODEL, name, model_class) - pass - - @staticmethod - def register_custom_policy(name: str, policy_class: type) -> None: - """Register a custom policy by name. - - :param name: str, Name to register the policy under. - :param policy_class: type, Python class of the policy. - """ - pass - - @staticmethod - def register_custom_env(name: str, env_class: type) -> None: - """Register a custom environment by name. - - :param name: str, Name to register the environment under. - :param env_class: type, Python class of the environment. - """ - pass diff --git a/malib/rl/common/policy.py b/malib/rl/common/policy.py index ba346aa7..4cd783dd 100644 --- a/malib/rl/common/policy.py +++ b/malib/rl/common/policy.py @@ -249,16 +249,19 @@ def to(self, device: str = None, use_copy: bool = False) -> "Policy": Policy: A policy instance """ + if isinstance(device, torch.device): + device = device.type + if device is None: - device = "cpu" if not self.use_cuda else "cuda" + device = "cpu" if "cuda" not in self.device else "cuda" - cond1 = "cpu" in device and self.use_cuda - cond2 = "cuda" in device and not self.use_cuda + cond1 = "cpu" in device and "cuda" in self.device + cond2 = "cuda" in device and "cuda" not in self.device if "cpu" in device: - use_cuda = False + _device = device else: - use_cuda = self._custom_config.get("use_cuda", False) + _device = self.device replacement = {} if cond1 or cond2: @@ -273,7 +276,6 @@ def to(self, device: str = None, use_copy: bool = False) -> "Policy": if use_copy: ret = self.copy(self, replacement=replacement) else: - self.use_cuda = use_cuda ret = self return ret diff --git a/malib/rl/config.py b/malib/rl/config.py index 730be866..5935b997 100644 --- a/malib/rl/config.py +++ b/malib/rl/config.py @@ -14,3 +14,5 @@ class Algorithm: trainer: Type[Trainer] model_config: Dict[str, Any] + + trainer_config: Dict[str, Any] diff --git a/malib/rollout/rollout_config.py b/malib/rollout/config.py similarity index 100% rename from malib/rollout/rollout_config.py rename to malib/rollout/config.py diff --git a/malib/rollout/inference/env_runner.py b/malib/rollout/inference/env_runner.py index 0167f77c..1048e787 100644 --- a/malib/rollout/inference/env_runner.py +++ b/malib/rollout/inference/env_runner.py @@ -33,7 +33,7 @@ from malib.utils.timing import Timing from malib.remote.interface import RemoteInterface from malib.rollout.envs.vector_env import VectorEnv, SubprocVecEnv -from malib.rollout.rollout_config import RolloutConfig +from malib.rollout.config import RolloutConfig from malib.rollout.inference.client import InferenceClient, PolicyReturnWithObs from malib.rollout.envs.env import Environment from malib.common.strategy_spec import StrategySpec diff --git a/malib/rollout/manager.py b/malib/rollout/manager.py index e669a426..146d6f4d 100644 --- a/malib/rollout/manager.py +++ b/malib/rollout/manager.py @@ -39,7 +39,7 @@ from malib.common.manager import Manager from malib.remote.interface import RemoteInterface from malib.common.strategy_spec import StrategySpec -from malib.rollout.rollout_config import RolloutConfig +from malib.rollout.config import RolloutConfig from malib.rollout.pb_rolloutworker import PBRolloutWorker @@ -99,7 +99,7 @@ def __init__( super().__init__(verbose=verbose, namespace=ray_actor_namespace) rollout_worker_cls = PBRolloutWorker - worker_cls = rollout_worker_cls.as_remote(num_cpus=0, num_gpus=0).options() + worker_cls = rollout_worker_cls.as_remote(num_cpus=0, num_gpus=0) workers = [] ready_check = [] for i in range(num_worker): @@ -180,7 +180,9 @@ def submit( for _task in task: validate_strategy_specs(_task.strategy_specs) - self._actor_pool.submit(lambda actor, _task: actor.rollout.remote(_task)) + self._actor_pool.submit( + lambda actor, _task: actor.rollout.remote(_task), _task + ) if wait: result_list = self.wait() diff --git a/malib/rollout/rolloutworker.py b/malib/rollout/rolloutworker.py index 455b8d1c..c3f24ec3 100644 --- a/malib/rollout/rolloutworker.py +++ b/malib/rollout/rolloutworker.py @@ -47,7 +47,7 @@ from malib.common.strategy_spec import StrategySpec from malib.common.task import RolloutTask from malib.remote.interface import RemoteInterface -from malib.rollout.rollout_config import RolloutConfig +from malib.rollout.config import RolloutConfig from malib.rollout.inference.client import InferenceClient from malib.rollout.inference.env_runner import BasicEnvRunner diff --git a/malib/scenarios/sarl_scenario.py b/malib/scenarios/sarl_scenario.py index 653f6cdc..0cc36596 100644 --- a/malib/scenarios/sarl_scenario.py +++ b/malib/scenarios/sarl_scenario.py @@ -22,7 +22,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -from typing import Dict, Any +from typing import Dict, Any, Union from malib.common.task import TaskType, OptimizationTask, RolloutTask from malib.scenarios import Scenario @@ -30,8 +30,11 @@ from malib.utils.logging import Logger from malib.backend.league import League from malib.learner.manager import LearnerManager +from malib.learner.config import LearnerConfig +from malib.rollout.config import RolloutConfig from malib.rollout.manager import RolloutWorkerManager from malib.rollout.inference.manager import InferenceManager +from malib.rl.config import Algorithm class SARLScenario(Scenario): @@ -40,9 +43,9 @@ def __init__( name: str, log_dir: str, env_desc: Dict[str, Any], - algorithms: Dict[str, Any], - training_config: Dict[str, Any], - rollout_config: Dict[str, Any], + algorithm: Algorithm, + learner_config: Union[Dict[str, Any], LearnerConfig], + rollout_config: Union[Dict[str, Any], RolloutConfig], stopping_conditions: Dict[str, Any], resource_config: Dict[str, Any] = None, ): @@ -50,9 +53,9 @@ def __init__( name, log_dir, env_desc, - algorithms, + algorithm, lambda agent: "default", - training_config, + learner_config, rollout_config, stopping_conditions, ) @@ -66,15 +69,13 @@ def create_global_stopper(self) -> StoppingCondition: def execution_plan(experiment_tag: str, scenario: SARLScenario, verbose: bool = True): # TODO(ming): simplify the initialization of training and rollout manager with a scenario instance as input learner_manager = LearnerManager( - experiment_tag=experiment_tag, stopping_conditions=scenario.stopping_conditions, - algorithms=scenario.algorithms, + algorithm=scenario.algorithm, env_desc=scenario.env_desc, agent_mapping_func=scenario.agent_mapping_func, group_info=scenario.group_info, - training_config=scenario.training_config, + learner_config=scenario.learner_config, log_dir=scenario.log_dir, - remote_mode=True, resource_config=scenario.resource_config["training"], ray_actor_namespace="learner_{}".format(experiment_tag), verbose=verbose, @@ -84,7 +85,8 @@ def execution_plan(experiment_tag: str, scenario: SARLScenario, verbose: bool = group_info=scenario.group_info, ray_actor_namespace="inference_{}".format(experiment_tag), model_entry_point=learner_manager.learner_entrypoints, - scenario=scenario, + algorithm=scenario.algorithm, + verbose=verbose, ) rollout_manager = RolloutWorkerManager( @@ -99,27 +101,23 @@ def execution_plan(experiment_tag: str, scenario: SARLScenario, verbose: bool = verbose=verbose, ) - league = League( - learner_manager, rollout_manager, inference_manager, namespace=experiment_tag - ) + league = League(learner_manager, rollout_manager, inference_manager) + # TODO(ming): further check is needed optimization_task = OptimizationTask( - active_agents=scenario.env_desc["possible_agents"], stop_conditions=scenario.stopping_conditions["training"], + strategy_specs=None, + active_agents=None, ) - strategy_specs = learner_manager.get_strategy_specs() - rollout_task = RolloutTask( - task_type=TaskType.ROLLOUT, - strategy_specs=strategy_specs, + strategy_specs=None, stopping_conditions=scenario.stopping_conditions["rollout"], data_entrypoint_mapping=learner_manager.data_entrypoints, ) evaluation_task = RolloutTask( - task_type=TaskType.EVALUATION, - strategy_specs=strategy_specs, + strategy_specs=None, ) stopper = scenario.create_global_stopper() diff --git a/malib/scenarios/scenario.py b/malib/scenarios/scenario.py index 03af7069..d5e1ea22 100644 --- a/malib/scenarios/scenario.py +++ b/malib/scenarios/scenario.py @@ -31,6 +31,10 @@ from malib.utils.typing import AgentID from malib.utils.stopping_conditions import StoppingCondition +from malib.rl.config import Algorithm +from malib.learner.config import LearnerConfig +from malib.rollout.config import RolloutConfig + DEFAULT_STOPPING_CONDITIONS = {} @@ -91,16 +95,16 @@ def __init__( name: str, log_dir: str, env_desc: Dict[str, Any], - algorithms: Dict[str, Any], + algorithm: Algorithm, agent_mapping_func: LambdaType, - training_config: Dict[str, Any], - rollout_config: Dict[str, Any], + learner_config: LearnerConfig, + rollout_config: RolloutConfig, stopping_conditions: Dict[str, Any], ): self.name = name self.log_dir = log_dir self.env_desc = env_desc - self.algorithms = algorithms + self.algorithm = algorithm self.agent_mapping_func = agent_mapping_func # then generate grouping information here self.group_info = form_group_info(env_desc, agent_mapping_func) @@ -109,8 +113,8 @@ def __init__( env_desc["observation_spaces"], env_desc["action_spaces"], ) - self.training_config = training_config - self.rollout_config = rollout_config + self.learner_config = LearnerConfig.from_raw(learner_config) + self.rollout_config = RolloutConfig.from_raw(rollout_config) self.stopping_conditions = stopping_conditions or DEFAULT_STOPPING_CONDITIONS def copy(self): diff --git a/malib/settings.py b/malib/settings.py index af9b66c8..428a46b2 100644 --- a/malib/settings.py +++ b/malib/settings.py @@ -1,49 +1,3 @@ -# MIT License - -# Copyright (c) 2021 MARL @ SJTU - -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. - -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - import logging -import os - -BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) - -LOG_DIR = os.path.join(BASE_DIR, "logs") -LOG_LEVEL = logging.INFO -STATISTIC_FEEDBACK = True -DATA_FEEDBACK = False -USE_REMOTE_LOGGER = True -USE_MONGO_LOGGER = False -PROFILING = False - -PARAMETER_SERVER_ACTOR = "ParameterServer" -OFFLINE_DATASET_ACTOR = "OfflineDataset" -COORDINATOR_SERVER_ACTOR = "coordinator" - -# default episode capacity when initializing -DEFAULT_EPISODE_INIT_CAPACITY = int(1e6) -# default episode maximum capacity -DEFAULT_EPISODE_CAPACITY = 30000 # int(1e15) -# related to each group of expr settings -DEFAULT_EPISODE_BLOCK_SIZE = int(75) -PICKLE_PROTOCOL_VER = 4 -PARAM_DIR = os.path.join(BASE_DIR, "../checkpoints") -DATASET_DIR = os.path.join(BASE_DIR, "dataset") +LOG_LEVEL = logging.DEBUG diff --git a/malib/utils/general.py b/malib/utils/general.py index 2bf25ada..778cbcab 100644 --- a/malib/utils/general.py +++ b/malib/utils/general.py @@ -41,8 +41,6 @@ import torch import numpy as np -from malib import settings - T = TypeVar("T") diff --git a/tests/backend/test_dynamic_dataset.py b/tests/backend/test_dynamic_dataset.py index 304ecb0c..5c62387d 100644 --- a/tests/backend/test_dynamic_dataset.py +++ b/tests/backend/test_dynamic_dataset.py @@ -83,6 +83,7 @@ def test_sync_grpc_service_get(self): for k, v in _spaces.items() }, ) + dataset.start_server() # send data print("send 10 piece of data, entrypoint=", dataset.entrypoint) @@ -124,6 +125,7 @@ def test_async_grpc_service_get(self): k: np.zeros((100,) + v.shape, dtype=v.dtype) for k, v in _spaces.items() }, ) + dataset.start_server() def start_send(batch, entrypoint): print("send 10 piece of data, entrypoint=", entrypoint) diff --git a/tests/rollout/test_env_runner.py b/tests/rollout/test_env_runner.py index ddddaa3a..33e69f96 100644 --- a/tests/rollout/test_env_runner.py +++ b/tests/rollout/test_env_runner.py @@ -7,7 +7,7 @@ from malib.rollout.inference import env_runner from malib.rollout.inference.client import InferenceClient from malib.rollout.envs import mdp -from malib.rollout.rollout_config import RolloutConfig +from malib.rollout.config import RolloutConfig from malib.rl.random import RandomPolicy diff --git a/tests/rollout/test_mdp_env.py b/tests/rollout/test_mdp_env.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/rollout/test_open_spiel.py b/tests/rollout/test_open_spiel.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/rollout/test_pb_rollout_worker.py b/tests/rollout/test_pb_rollout_worker.py index e880a3ea..9d55103f 100644 --- a/tests/rollout/test_pb_rollout_worker.py +++ b/tests/rollout/test_pb_rollout_worker.py @@ -26,19 +26,16 @@ import pytest import ray -from malib.backend.dataset_server.data_loader import DynamicDataset from malib.common.task import RolloutTask from malib.common.strategy_spec import StrategySpec from malib.rl.random import RandomPolicy from malib.rl.config import Algorithm from malib.rollout.envs.random import env_desc_gen -from malib.rollout.rollout_config import RolloutConfig +from malib.rollout.config import RolloutConfig from malib.rollout.pb_rolloutworker import PBRolloutWorker from malib.rollout.inference.manager import InferenceManager from malib.scenarios.scenario import form_group_info -from malib.utils.tianshou_batch import Batch -from malib.utils.typing import AgentID def gen_rollout_config(inference_server_type: str): @@ -61,9 +58,7 @@ def gen_common_requirements(n_player: int): env_desc = env_desc_gen(num_agents=n_player) algorithm = Algorithm( - policy=RandomPolicy, - trainer=None, - model_config=None, + policy=RandomPolicy, trainer=None, model_config=None, trainer_config={} ) rollout_config = RolloutConfig( @@ -79,11 +74,14 @@ def gen_common_requirements(n_player: int): return env_desc, algorithm, rollout_config, group_info +import numpy as np + from malib.learner.learner import Learner from gym import spaces -from malib.learner.learner import Learner from malib.learner.manager import LearnerManager -from malib.learner.config import TrainingConfig +from malib.learner.config import LearnerConfig +from malib.utils.episode import Episode +from malib.backend.dataset_server.feature import BaseFeature class FakeLearner(Learner): @@ -94,6 +92,28 @@ def multiagent_post_process( pass +class FakeFeatureHandler(BaseFeature): + + pass + + +def feature_handler_meta_gen(env_desc, agent_id): + def f(device): + _spaces = { + Episode.DONE: spaces.Discrete(1), + Episode.CUR_OBS: env_desc["observation_spaces"][agent_id], + Episode.ACTION: env_desc["action_spaces"][agent_id], + Episode.REWARD: spaces.Box(-np.inf, np.inf, shape=(1,), dtype=np.float32), + Episode.NEXT_OBS: env_desc["observation_spaces"][agent_id], + } + np_memory = { + k: np.zeros((100,) + v.shape, dtype=v.dtype) for k, v in _spaces.items() + } + return FakeFeatureHandler(_spaces, np_memory, device) + + return f + + @pytest.mark.parametrize("n_player", [1, 2]) class TestRolloutWorker: def test_rollout(self, n_player: int): @@ -163,8 +183,10 @@ def test_rollout_with_data_entrypoint(self, n_player: int): env_desc=env_desc, agent_mapping_func=lambda agent: "default", group_info=group_info, - training_config=TrainingConfig( - trainer_config={}, learner_type=FakeLearner, custom_config=None + learner_config=LearnerConfig( + learner_type=FakeLearner, + feature_handler_meta_gen=feature_handler_meta_gen, + custom_config=None, ), log_dir=log_dir, ) diff --git a/tests/rollout/test_rollout_manager.py b/tests/rollout/test_rollout_manager.py index 2802ef49..4258e8ee 100644 --- a/tests/rollout/test_rollout_manager.py +++ b/tests/rollout/test_rollout_manager.py @@ -22,7 +22,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -from typing import Dict, Any +from typing import Dict, Any, Callable import pytest import ray @@ -30,23 +30,33 @@ from gym import spaces from pytest_mock import MockerFixture +from malib.common.task import RolloutTask from malib.common.strategy_spec import StrategySpec +from malib.rollout.config import RolloutConfig from malib.rollout.manager import RolloutWorkerManager -from malib.mocker.mocker_utils import FakeRolloutWorker +from malib.rl.random import RandomPolicy +from malib.scenarios.scenario import form_group_info +from malib.learner.manager import LearnerManager +from malib.learner.config import LearnerConfig +from malib.rollout.inference.manager import InferenceManager +from test_pb_rollout_worker import ( + feature_handler_meta_gen, + FakeFeatureHandler, + FakeLearner, + gen_common_requirements, +) def create_manager( - mocker: MockerFixture, stopping_conditions: Dict[str, Any], rollout_config: Dict[str, Any], env_desc: Dict[str, Any], + agent_mapping_func: Callable, ): - mocker.patch("malib.rollout.manager.PBRolloutWorker", new=FakeRolloutWorker) manager = RolloutWorkerManager( - experiment_tag="test_rollout_manager", stopping_conditions=stopping_conditions, num_worker=1, - agent_mapping_func=lambda agent: agent, + group_info=form_group_info(env_desc, agent_mapping_func), rollout_config=rollout_config, env_desc=env_desc, log_dir="./logs", @@ -55,105 +65,57 @@ def create_manager( @pytest.mark.parametrize("n_players", [1, 2]) -@pytest.mark.parametrize("inference_server_type", ["local", "ray"]) class TestRolloutManager: - def test_rollout_task_send( - self, mocker: MockerFixture, n_players: int, inference_server_type: str - ): - if not ray.is_initialized(): - ray.init() - - agents = [f"player_{i}" for i in range(n_players)] - manager = create_manager( - mocker, - stopping_conditions={"rollout": {"max_iteration": 2}}, - rollout_config={ - "fragment_length": 100, - "max_step": 10, - "num_eval_episodes": 2, - "num_threads": 1, - "num_env_per_thread": 1, - "num_eval_threads": 1, - "use_subproc_env": False, - "batch_mode": "timestep", - "postprocessor_types": None, - "eval_interval": 2, - "inference_server": inference_server_type, - }, - env_desc={ - "possible_agents": agents, - "observation_spaces": { - agent: spaces.Box(-1, 1.0, shape=(2,)) for agent in agents - }, - "action_spaces": { - agent: spaces.Box(-1, 1, shape=(2,)) for agent in agents - }, - }, - ) - - strategy_specs = { - agent: StrategySpec( - identifier=agent, - policy_ids=["policy_0"], - meta_data={ - "prob_list": [1.0], - "policy_cls": None, - "kwargs": None, - "experiment_tag": "test_rollout_manager", - }, + def test_rollout_task_send(self, mocker: MockerFixture, n_players: int): + with ray.init(local_mode=True): + env_desc, algorithm, rollout_config, group_info = gen_common_requirements( + n_players ) - for agent in agents - } - task_list = [ - { - "trainable_agents": agents, - "data_entrypoints": None, - "strategy_specs": strategy_specs, + inference_namespace = "test_pb_rolloutworker" + manager = create_manager( + stopping_conditions={"rollout": {"max_iteration": 2}}, + rollout_config=RolloutConfig(), + env_desc=env_desc, + agent_mapping_func=lambda agent: "default", + ) + + learner_manager = LearnerManager( + stopping_conditions={"max_iteration": 10}, + algorithm=algorithm, + env_desc=env_desc, + agent_mapping_func=lambda agent: "default", + group_info=group_info, + learner_config=LearnerConfig( + learner_type=FakeLearner, + feature_handler_meta_gen=feature_handler_meta_gen, + custom_config=None, + ), + log_dir="./logs", + ) + + infer_manager = InferenceManager( + group_info=group_info, + ray_actor_namespace=inference_namespace, + algorithm=algorithm, + model_entry_point=learner_manager.learner_entrypoints, + ) + + rollout_config.inference_entry_points = infer_manager.inference_entry_points + + strategy_specs = { + agent: StrategySpec( + policy_cls=RandomPolicy, + observation_space=env_desc["observation_spaces"][agent], + action_space=env_desc["action_spaces"][agent], + policy_ids=["policy_0"], + ) + for agent in env_desc["possible_agents"] } - for _ in range(2) - ] - manager.rollout(task_list) - - for result in manager.retrive_results(): - print(result) - - ray.shutdown() - - def test_simulation_task_send( - self, mocker: MockerFixture, n_players: int, inference_server_type: str - ): - if not ray.is_initialized(): - ray.init() - - agents = [f"player_{i}" for i in range(n_players)] - manager = create_manager( - mocker, - stopping_conditions={"rollout": {"max_iteration": 2}}, - rollout_config={ - "fragment_length": 100, - "max_step": 10, - "num_eval_episodes": 2, - "num_threads": 1, - "num_env_per_thread": 1, - "num_eval_threads": 1, - "use_subproc_env": False, - "batch_mode": "timestep", - "postprocessor_types": None, - "eval_interval": 2, - "inference_server": inference_server_type, - }, - env_desc={ - "possible_agents": agents, - "observation_spaces": { - agent: spaces.Box(-1, 1.0, shape=(2,)) for agent in agents - }, - "action_spaces": { - agent: spaces.Box(-1, 1, shape=(2,)) for agent in agents - }, - }, - ) - - manager.simulate([None] * 2) - for result in manager.retrive_results(): - print(result) - ray.shutdown() + + task = RolloutTask( + strategy_specs=strategy_specs, + stopping_conditions={"max_iteration": 10}, + data_entrypoints=None, + ) + + results = manager.submit(task, wait=True) From 207bcaeb10a130bb1881348ee2352b52f1412022 Mon Sep 17 00:00:00 2001 From: Ming Zhou Date: Fri, 24 Nov 2023 20:20:41 +0800 Subject: [PATCH 19/24] tmp save --- malib/learner/manager.py | 3 +++ malib/models/config.py | 1 - malib/rl/config.py | 1 - malib/rollout/config.py | 1 - malib/rollout/envs/mdp/env.py | 1 - malib/scenarios/sarl_scenario.py | 2 +- tests/rollout/test_pb_rollout_worker.py | 1 - 7 files changed, 4 insertions(+), 6 deletions(-) diff --git a/malib/learner/manager.py b/malib/learner/manager.py index 86e784f0..79ccbbd0 100644 --- a/malib/learner/manager.py +++ b/malib/learner/manager.py @@ -158,6 +158,9 @@ def __init__( self._thread_pool = ThreadPoolExecutor(max_workers=len(learners)) self._stopping_conditions = stopping_conditions + # init strategy spec + self.add_policies() + Logger.info( f"training manager launched, {len(self._learners)} learner(s) created" ) diff --git a/malib/models/config.py b/malib/models/config.py index bb1a2108..0683b93b 100644 --- a/malib/models/config.py +++ b/malib/models/config.py @@ -5,7 +5,6 @@ @dataclass class ModelConfig: - model_cls: Type model_args: Dict[str, Any] diff --git a/malib/rl/config.py b/malib/rl/config.py index 5935b997..552543d5 100644 --- a/malib/rl/config.py +++ b/malib/rl/config.py @@ -8,7 +8,6 @@ @dataclass class Algorithm: - policy: Type[Policy] trainer: Type[Trainer] diff --git a/malib/rollout/config.py b/malib/rollout/config.py index 4e462d6b..f576c528 100644 --- a/malib/rollout/config.py +++ b/malib/rollout/config.py @@ -5,7 +5,6 @@ @dataclass class RolloutConfig: - num_workers: int = 1 """Defines how many workers will be used for executing one rollout task, default is 1""" diff --git a/malib/rollout/envs/mdp/env.py b/malib/rollout/envs/mdp/env.py index e97c99ee..ce965152 100644 --- a/malib/rollout/envs/mdp/env.py +++ b/malib/rollout/envs/mdp/env.py @@ -9,7 +9,6 @@ class MDPEnvironment(Environment): def __init__(self, **configs): - try: from blackhc import mdp from blackhc.mdp import example as mdp_examples diff --git a/malib/scenarios/sarl_scenario.py b/malib/scenarios/sarl_scenario.py index 0cc36596..07bc2817 100644 --- a/malib/scenarios/sarl_scenario.py +++ b/malib/scenarios/sarl_scenario.py @@ -67,7 +67,7 @@ def create_global_stopper(self) -> StoppingCondition: def execution_plan(experiment_tag: str, scenario: SARLScenario, verbose: bool = True): - # TODO(ming): simplify the initialization of training and rollout manager with a scenario instance as input + # TODO(ming): simplize the initialization of training and rollout manager with a scenario instance as input learner_manager = LearnerManager( stopping_conditions=scenario.stopping_conditions, algorithm=scenario.algorithm, diff --git a/tests/rollout/test_pb_rollout_worker.py b/tests/rollout/test_pb_rollout_worker.py index 9d55103f..baf1c334 100644 --- a/tests/rollout/test_pb_rollout_worker.py +++ b/tests/rollout/test_pb_rollout_worker.py @@ -93,7 +93,6 @@ def multiagent_post_process( class FakeFeatureHandler(BaseFeature): - pass From 40c50741d1244fa4d8f7a5fa8e43d033d0b016ed Mon Sep 17 00:00:00 2001 From: Ming Zhou Date: Mon, 27 Nov 2023 17:49:40 +0800 Subject: [PATCH 20/24] in progress: pb-rollout-worker in test --- examples/run_gym.py | 98 -------------------- examples/run_psro.py | 6 +- examples/sarl/ppo_gym.py | 39 +++++++- malib/backend/dataset_server/data_loader.py | 6 +- malib/learner/config.py | 2 + malib/learner/learner.py | 5 +- malib/learner/manager.py | 12 +-- malib/models/model_client.py | 6 +- malib/rl/common/policy.py | 23 +++-- malib/rl/pg/config.py | 1 + malib/rl/pg/policy.py | 1 + malib/rl/pg/trainer.py | 12 ++- malib/rl/random/random_trainer.py | 18 +++- malib/rollout/inference/client.py | 9 ++ malib/rollout/rolloutworker.py | 14 +-- malib/scenarios/sarl_scenario.py | 8 +- tests/rollout/test_pb_rollout_worker.py | 99 +++++++++++---------- 17 files changed, 170 insertions(+), 189 deletions(-) delete mode 100644 examples/run_gym.py diff --git a/examples/run_gym.py b/examples/run_gym.py deleted file mode 100644 index 1ad8c40c..00000000 --- a/examples/run_gym.py +++ /dev/null @@ -1,98 +0,0 @@ -# MIT License - -# Copyright (c) 2021 MARL @ SJTU - -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. - -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -# pragma: no cover -from argparse import ArgumentParser - -import os -import time - -from malib.learner import IndependentAgent -from malib.scenarios.marl_scenario import MARLScenario -from malib.rl.dqn import DQNPolicy, DQNTrainer, DEFAULT_CONFIG -from malib.rollout.envs.gym import env_desc_gen - - -if __name__ == "__main__": - parser = ArgumentParser("Multi-agent reinforcement learning for gym cases.") - parser.add_argument("--log-dir", default="./logs/", help="Log directory.") - parser.add_argument("--env-id", default="CartPole-v1", help="gym environment id.") - parser.add_argument("--use-cuda", action="store_true") - - args = parser.parse_args() - - trainer_config = DEFAULT_CONFIG["training_config"].copy() - trainer_config["total_timesteps"] = int(1e6) - trainer_config["use_cuda"] = args.use_cuda - - training_config = { - "type": IndependentAgent, - "trainer_config": trainer_config, - "custom_config": {}, - } - rollout_config = { - "fragment_length": 2000, # determine the size of sended data block - "max_step": 200, - "num_eval_episodes": 10, - "num_threads": 2, - "num_env_per_thread": 10, - "num_eval_threads": 1, - "use_subproc_env": False, - "batch_mode": "time_step", - "postprocessor_types": ["defaults"], - # every # rollout epoch run evaluation. - "eval_interval": 1, - "inference_server": "ray", # three kinds of inference server: `local`, `pipe` and `ray` - } - agent_mapping_func = lambda agent: agent - - algorithms = { - "default": ( - DQNPolicy, - DQNTrainer, - # model configuration, None for default - {}, - {"use_cuda": args.use_cuda}, - ) - } - - env_description = env_desc_gen(env_id=args.env_id, scenario_configs={}) - runtime_logdir = os.path.join(args.log_dir, f"gym/{time.time()}") - - if not os.path.exists(runtime_logdir): - os.makedirs(runtime_logdir) - - scenario = MARLScenario( - name="gym", - log_dir=runtime_logdir, - algorithms=algorithms, - env_description=env_description, - training_config=training_config, - rollout_config=rollout_config, - agent_mapping_func=agent_mapping_func, - stopping_conditions={ - "training": {"max_iteration": int(1e10)}, - "rollout": {"max_iteration": 1000, "minimum_reward_improvement": 1.0}, - }, - ) - - run(scenario) diff --git a/examples/run_psro.py b/examples/run_psro.py index b9382c5f..a8cc9481 100644 --- a/examples/run_psro.py +++ b/examples/run_psro.py @@ -26,7 +26,7 @@ import os import time -from malib.runner import run +from malib.scenarios import psro_scenario from malib.learner import IndependentAgent from malib.scenarios.psro_scenario import PSROScenario from malib.rl.dqn import DQNPolicy, DQNTrainer, DEFAULT_CONFIG @@ -99,4 +99,6 @@ }, ) - run(scenario) + results = psro_scenario.execution_plan(scenario=scenario, verbose=True) + + print(results) diff --git a/examples/sarl/ppo_gym.py b/examples/sarl/ppo_gym.py index ee37ec76..f64ced00 100644 --- a/examples/sarl/ppo_gym.py +++ b/examples/sarl/ppo_gym.py @@ -3,6 +3,11 @@ from argparse import ArgumentParser +from gym import spaces + +import numpy as np + +from malib.utils.episode import Episode from malib.learner import IndependentAgent from malib.scenarios import sarl_scenario from malib.rl.config import Algorithm @@ -10,6 +15,32 @@ from malib.learner.config import LearnerConfig from malib.rollout.config import RolloutConfig from malib.rollout.envs.gym import env_desc_gen +from malib.backend.dataset_server.feature import BaseFeature + + +class FeatureHandler(BaseFeature): + pass + + +def feature_handler_meta_gen(env_desc, agent_id): + def f(device): + # define the data schema + _spaces = { + Episode.DONE: spaces.Discrete(1), + Episode.CUR_OBS: env_desc["observation_spaces"][agent_id], + Episode.ACTION: env_desc["action_spaces"][agent_id], + Episode.REWARD: spaces.Box(-np.inf, np.inf, shape=(1,), dtype=np.float32), + Episode.NEXT_OBS: env_desc["observation_spaces"][agent_id], + } + + # you should know the maximum of replaybuffer before training + np_memory = { + k: np.zeros((100,) + v.shape, dtype=v.dtype) for k, v in _spaces.items() + } + + return FeatureHandler(_spaces, np_memory, device) + + return f if __name__ == "__main__": @@ -43,7 +74,7 @@ ), learner_config=LearnerConfig( learner_type=IndependentAgent, - feature_handler_meta_gen=None, + feature_handler_meta_gen=feature_handler_meta_gen, custom_config={}, ), rollout_config=RolloutConfig( @@ -56,6 +87,6 @@ }, ) - results = sarl_scenario.execution_plan( - experiment_tag=scenario.name, scenario=scenario, verbose=True - ) + results = sarl_scenario.execution_plan(scenario=scenario, verbose=True) + + print(results) diff --git a/malib/backend/dataset_server/data_loader.py b/malib/backend/dataset_server/data_loader.py index 92000c41..e1a2c25a 100644 --- a/malib/backend/dataset_server/data_loader.py +++ b/malib/backend/dataset_server/data_loader.py @@ -28,8 +28,10 @@ def __init__( super().__init__() # start a service as thread - self.feature_handler: BaseFeature = feature_handler or feature_handler_cls( - **feature_handler_kwargs + self.feature_handler: BaseFeature = ( + feature_handler + if feature_handler is not None + else feature_handler_cls(**feature_handler_kwargs) ) self.grpc_thread_num_workers = grpc_thread_num_workers self.max_message_length = max_message_length diff --git a/malib/learner/config.py b/malib/learner/config.py index afa31367..14267842 100644 --- a/malib/learner/config.py +++ b/malib/learner/config.py @@ -11,6 +11,8 @@ class LearnerConfig: learner_type: Type[Learner] feature_handler_meta_gen: Callable[["EnvDesc", str], Callable[[str], BaseFeature]] + """what is it?""" + custom_config: Dict[str, Any] = field(default_factory=dict()) @classmethod diff --git a/malib/learner/learner.py b/malib/learner/learner.py index 4beab200..23f8cc09 100644 --- a/malib/learner/learner.py +++ b/malib/learner/learner.py @@ -114,6 +114,10 @@ def __init__( algorithm.trainer_config, self._policy ) + # since the trainer_config has been updated by the trainer + # thus the algorithm should update its trainer_config + algorithm.trainer_config = self._trainer.training_config + if dataset is None: dataset = DynamicDataset( grpc_thread_num_workers=2, @@ -126,7 +130,6 @@ def __init__( dataset.feature_handler = feature_handler_gen(device) dataset.start_server() - self._data_loader = DataLoader( dataset, batch_size=algorithm.trainer_config["batch_size"] ) diff --git a/malib/learner/manager.py b/malib/learner/manager.py index 79ccbbd0..6738c3ed 100644 --- a/malib/learner/manager.py +++ b/malib/learner/manager.py @@ -105,9 +105,9 @@ def __init__( learner_cls = learner_cls.as_remote(**resource_config) learners: Dict[str, ray.ObjectRef] = {} - assert ( - "training" in stopping_conditions - ), f"Stopping conditions should contains `training` stoppong conditions: {stopping_conditions}" + # assert ( + # "training" in stopping_conditions + # ), f"Stopping conditions should contains `training` stoppong conditions: {stopping_conditions}" ready_check = [] @@ -123,7 +123,6 @@ def __init__( algorithm=algorithm, agent_mapping_func=agent_mapping_func, governed_agents=agents, - trainer_config=algorithm.trainer_config, custom_config=learner_config.custom_config, feature_handler_gen=learner_config.feature_handler_meta_gen( env_desc, agents[0] @@ -236,10 +235,7 @@ def add_policies( policy_nums = dict.fromkeys(interface_ids, n) if isinstance(n, int) else n strategy_spec_list: List[StrategySpec] = ray.get( - [ - self._learners[k].add_policies.remote(n=policy_nums[k]) - for k in interface_ids - ] + [self._learners[k].get_strategy_spec.remote() for k in interface_ids] ) strategy_spec_dict: Dict[str, StrategySpec] = dict( zip(interface_ids, strategy_spec_list) diff --git a/malib/models/model_client.py b/malib/models/model_client.py index e239fd45..fefb524d 100644 --- a/malib/models/model_client.py +++ b/malib/models/model_client.py @@ -1,4 +1,4 @@ -from typing import Dict, Any +from typing import Dict, Any, Union from concurrent import futures import threading @@ -19,7 +19,9 @@ def load_state_dict(client, timeout=10): class ModelClient: - def __init__(self, entry_point: str, model_config: ModelConfig): + def __init__( + self, entry_point: str, model_config: Union[ModelConfig, Dict[str, Any]] + ): """Construct a model client for mantaining a model instance and its update. Args: diff --git a/malib/rl/common/policy.py b/malib/rl/common/policy.py index 4cd783dd..2f0b1d38 100644 --- a/malib/rl/common/policy.py +++ b/malib/rl/common/policy.py @@ -108,7 +108,10 @@ def __init__( self._model = kwargs.get("model_client") if self._model is None: if kwargs.get("model_entry_point"): - self._model = ModelClient(kwargs["model_entry_point"], model_config) + self._model = ModelClient( + kwargs["model_entry_point"], + ModelConfig(lambda **x: self.create_model(), model_config), + ) else: self._model = self.create_model().to(self._device) @@ -147,7 +150,7 @@ def preprocessor(self) -> Preprocessor: return self._preprocessor @property - def device(self) -> str: + def device(self) -> torch.device: return self._device @property @@ -186,7 +189,7 @@ def state_dict( res = self.model.state_dict() else: res = {} - for k, v in self.model.state_dict(): + for k, v in self.model.state_dict().items(): res[k] = v.to(device) return res @@ -249,16 +252,18 @@ def to(self, device: str = None, use_copy: bool = False) -> "Policy": Policy: A policy instance """ - if isinstance(device, torch.device): - device = device.type + if isinstance(device, str): + device = torch.device(device) if device is None: - device = "cpu" if "cuda" not in self.device else "cuda" + device = ( + torch.device("cpu") if "cuda" not in self.device.type else self.device + ) - cond1 = "cpu" in device and "cuda" in self.device - cond2 = "cuda" in device and "cuda" not in self.device + cond1 = "cpu" in device.type and "cuda" in self.device.type + cond2 = "cuda" in device.type and "cuda" not in self.device.type - if "cpu" in device: + if "cpu" in device.type: _device = device else: _device = self.device diff --git a/malib/rl/pg/config.py b/malib/rl/pg/config.py index 1b7b8a35..b60cf81b 100644 --- a/malib/rl/pg/config.py +++ b/malib/rl/pg/config.py @@ -29,6 +29,7 @@ "reward_norm": None, "n_repeat": 2, "minibatch": 2, + "batch_size": 32, "gamma": 0.99, }, "model_config": { diff --git a/malib/rl/pg/policy.py b/malib/rl/pg/policy.py index 20b23a2c..257e281b 100644 --- a/malib/rl/pg/policy.py +++ b/malib/rl/pg/policy.py @@ -80,6 +80,7 @@ def create_model(self): self.model_config["preprocess_net"].get("net_type", None), **self.model_config["preprocess_net"]["config"] ) + if isinstance(self.action_space, spaces.Discrete): return discrete.Actor( preprocess_net=preprocess_net, diff --git a/malib/rl/pg/trainer.py b/malib/rl/pg/trainer.py index 1a8d4091..6f8c9416 100644 --- a/malib/rl/pg/trainer.py +++ b/malib/rl/pg/trainer.py @@ -29,18 +29,28 @@ import numpy as np from torch import optim +from malib.rl.common.policy import Policy from malib.rl.common.trainer import Trainer from malib.utils.data import Postprocessor +from malib.utils.general import merge_dicts from malib.utils.typing import AgentID from malib.utils.tianshou_batch import Batch +from .config import DEFAULT_CONFIG class PGTrainer(Trainer): + def __init__(self, training_config: Dict[str, Any], policy_instance: Policy = None): + # merge from default + training_config = merge_dicts( + DEFAULT_CONFIG["training_config"], training_config or {} + ) + super().__init__(training_config, policy_instance) + def setup(self): self.optimizer: Type[optim.Optimizer] = getattr( optim, self.training_config["optimizer"] - )(self.policy.parameters()["actor"], lr=self.training_config["lr"]) + )(self.policy.actor.parameters(), lr=self.training_config["lr"]) self.lr_scheduler: torch.optim.lr_scheduler.LambdaLR = None self.ret_rms = None diff --git a/malib/rl/random/random_trainer.py b/malib/rl/random/random_trainer.py index 1d09e66f..3020ab69 100644 --- a/malib/rl/random/random_trainer.py +++ b/malib/rl/random/random_trainer.py @@ -1,6 +1,20 @@ -from typing import Any, Dict +from typing import Any, Dict, Type + +import torch + +from torch import optim + +from malib.rl.common.policy import Policy from malib.rl.pg.trainer import PGTrainer class RandomTrainer(PGTrainer): - pass + def __init__(self, training_config: Dict[str, Any], policy_instance: Policy = None): + super().__init__(training_config, policy_instance) + + def setup(self): + self.optimizer: Type[optim.Optimizer] = getattr( + optim, self.training_config["optimizer"] + )(self.policy.parameters(), lr=self.training_config["lr"]) + self.lr_scheduler: torch.optim.lr_scheduler.LambdaLR = None + self.ret_rms = None diff --git a/malib/rollout/inference/client.py b/malib/rollout/inference/client.py index 82715e71..2c41e2ae 100644 --- a/malib/rollout/inference/client.py +++ b/malib/rollout/inference/client.py @@ -77,6 +77,15 @@ def shutdown(self): pass def process_obs(self, raw_observation: Any) -> np.ndarray: + """Convert raw environmental observation to array like. + + Args: + raw_observation (Any): Raw environmental observation. + + Returns: + np.ndarray: Array-like observation. + """ + return self.fixed_policy.preprocessor.transform(raw_observation) def compute_action( diff --git a/malib/rollout/rolloutworker.py b/malib/rollout/rolloutworker.py index c3f24ec3..a7b93ad3 100644 --- a/malib/rollout/rolloutworker.py +++ b/malib/rollout/rolloutworker.py @@ -203,7 +203,7 @@ def rollout(self, task: RolloutTask): task.strategy_specs, task.data_entrypoints, ) - # total_timesteps += results["total_timesteps"] + total_timesteps += results["total_timesteps"] # performance["rollout_iter_rate"] = (epoch + 1) / (time.time() - start_time) # performance["rollout_FPS"] = results["FPS"] @@ -216,12 +216,12 @@ def rollout(self, task: RolloutTask): # formatted_results = pprint.pformat(eval_results) # Logger.info(f"Evaluation at epoch: {epoch}\n{formatted_results}") - # write_to_tensorboard( - # self.tb_writer, - # results, - # global_step=total_timesteps, - # prefix="Evaluation", - # ) + write_to_tensorboard( + self.tb_writer, + results, + global_step=total_timesteps, + prefix="Rollouts", + ) write_to_tensorboard( self.tb_writer, performance, global_step=epoch, prefix="Performance" diff --git a/malib/scenarios/sarl_scenario.py b/malib/scenarios/sarl_scenario.py index 07bc2817..b50ffcfb 100644 --- a/malib/scenarios/sarl_scenario.py +++ b/malib/scenarios/sarl_scenario.py @@ -66,7 +66,7 @@ def create_global_stopper(self) -> StoppingCondition: return get_stopper(self.stopping_conditions) -def execution_plan(experiment_tag: str, scenario: SARLScenario, verbose: bool = True): +def execution_plan(scenario: SARLScenario, verbose: bool = True): # TODO(ming): simplize the initialization of training and rollout manager with a scenario instance as input learner_manager = LearnerManager( stopping_conditions=scenario.stopping_conditions, @@ -77,13 +77,13 @@ def execution_plan(experiment_tag: str, scenario: SARLScenario, verbose: bool = learner_config=scenario.learner_config, log_dir=scenario.log_dir, resource_config=scenario.resource_config["training"], - ray_actor_namespace="learner_{}".format(experiment_tag), + ray_actor_namespace="learner_{}".format(scenario.name), verbose=verbose, ) inference_manager = InferenceManager( group_info=scenario.group_info, - ray_actor_namespace="inference_{}".format(experiment_tag), + ray_actor_namespace="inference_{}".format(scenario.name), model_entry_point=learner_manager.learner_entrypoints, algorithm=scenario.algorithm, verbose=verbose, @@ -97,7 +97,7 @@ def execution_plan(experiment_tag: str, scenario: SARLScenario, verbose: bool = env_desc=scenario.env_desc, log_dir=scenario.log_dir, resource_config=scenario.resource_config["rollout"], - ray_actor_namespace="rollout_{}".format(experiment_tag), + ray_actor_namespace="rollout_{}".format(scenario.name), verbose=verbose, ) diff --git a/tests/rollout/test_pb_rollout_worker.py b/tests/rollout/test_pb_rollout_worker.py index baf1c334..2936a269 100644 --- a/tests/rollout/test_pb_rollout_worker.py +++ b/tests/rollout/test_pb_rollout_worker.py @@ -30,6 +30,7 @@ from malib.common.task import RolloutTask from malib.common.strategy_spec import StrategySpec from malib.rl.random import RandomPolicy +from malib.rl.random import RandomTrainer from malib.rl.config import Algorithm from malib.rollout.envs.random import env_desc_gen from malib.rollout.config import RolloutConfig @@ -58,7 +59,7 @@ def gen_common_requirements(n_player: int): env_desc = env_desc_gen(num_agents=n_player) algorithm = Algorithm( - policy=RandomPolicy, trainer=None, model_config=None, trainer_config={} + policy=RandomPolicy, trainer=RandomTrainer, model_config=None, trainer_config={} ) rollout_config = RolloutConfig( @@ -115,53 +116,53 @@ def f(device): @pytest.mark.parametrize("n_player", [1, 2]) class TestRolloutWorker: - def test_rollout(self, n_player: int): - with ray.init(local_mode=True): - env_desc, algorithm, rollout_config, group_info = gen_common_requirements( - n_player - ) - - obs_spaces = env_desc["observation_spaces"] - act_spaces = env_desc["action_spaces"] - agents = env_desc["possible_agents"] - log_dir = "./logs" - - inference_namespace = "test_pb_rolloutworker" - - infer_manager = InferenceManager( - group_info=group_info, - ray_actor_namespace=inference_namespace, - algorithm=algorithm, - model_entry_point=None, - ) - - rollout_config.inference_entry_points = infer_manager.inference_entry_points - - strategy_specs = { - agent: StrategySpec( - policy_cls=algorithm.policy, - observation_space=obs_spaces[agent], - action_space=act_spaces[agent], - identifier=agent, - model_config=algorithm.model_config, - policy_ids=["policy-0"], - ) - for agent in agents - } - - worker = PBRolloutWorker( - env_desc=env_desc, - agent_groups=group_info["agent_groups"], - rollout_config=rollout_config, - log_dir=log_dir, - ) - - task = RolloutTask( - strategy_specs=strategy_specs, - stopping_conditions={"max_iteration": 10}, - data_entrypoint_mapping=None, # no data collect - ) - stats = worker.rollout(task) + # def test_rollout(self, n_player: int): + # with ray.init(local_mode=True): + # env_desc, algorithm, rollout_config, group_info = gen_common_requirements( + # n_player + # ) + + # obs_spaces = env_desc["observation_spaces"] + # act_spaces = env_desc["action_spaces"] + # agents = env_desc["possible_agents"] + # log_dir = "./logs" + + # inference_namespace = "test_pb_rolloutworker" + + # infer_manager = InferenceManager( + # group_info=group_info, + # ray_actor_namespace=inference_namespace, + # algorithm=algorithm, + # model_entry_point=None, + # ) + + # rollout_config.inference_entry_points = infer_manager.inference_entry_points + + # strategy_specs = { + # agent: StrategySpec( + # policy_cls=algorithm.policy, + # observation_space=obs_spaces[agent], + # action_space=act_spaces[agent], + # identifier=agent, + # model_config=algorithm.model_config, + # policy_ids=["policy-0"], + # ) + # for agent in agents + # } + + # worker = PBRolloutWorker( + # env_desc=env_desc, + # agent_groups=group_info["agent_groups"], + # rollout_config=rollout_config, + # log_dir=log_dir, + # ) + + # task = RolloutTask( + # strategy_specs=strategy_specs, + # stopping_conditions={"max_iteration": 10}, + # data_entrypoints=None, # no data collect + # ) + # stats = worker.rollout(task) def test_rollout_with_data_entrypoint(self, n_player: int): with ray.init(local_mode=True): @@ -221,7 +222,7 @@ def test_rollout_with_data_entrypoint(self, n_player: int): task = RolloutTask( strategy_specs=strategy_spaces, stopping_conditions={"max_iteration": 10}, - data_entrypoint_mapping=learner_manager.data_entrypoints, + data_entrypoints=learner_manager.data_entrypoints, ) stats = worker.rollout(task) From 3e6923046786775d1db11708f8d871384935fb4a Mon Sep 17 00:00:00 2001 From: Ming Zhou Date: Mon, 27 Nov 2023 19:05:35 +0800 Subject: [PATCH 21/24] pb test passed --- malib/backend/dataset_server/feature.py | 21 +++++++++++++++++++-- malib/learner/learner.py | 5 ++++- malib/learner/manager.py | 5 +---- malib/models/model_client.py | 1 - malib/rollout/inference/client.py | 2 +- malib/rollout/inference/env_runner.py | 10 +++++++--- malib/rollout/inference/manager.py | 15 +++++++++------ malib/rollout/rolloutworker.py | 5 +++++ tests/rollout/test_pb_rollout_worker.py | 14 +++++++------- 9 files changed, 53 insertions(+), 25 deletions(-) diff --git a/malib/backend/dataset_server/feature.py b/malib/backend/dataset_server/feature.py index a759258d..27269ce1 100644 --- a/malib/backend/dataset_server/feature.py +++ b/malib/backend/dataset_server/feature.py @@ -35,7 +35,11 @@ def __init__( self.rw_lock = rwlock.RWLockFair() self._device = device self._spaces = spaces - self._block_size = block_size or list(np_memory.values())[0].shape[0] + self._block_size = ( + block_size + if block_size is not None + else list(np_memory.values())[0].shape[0] + ) self._available_size = 0 self._flag = 0 self._shared_memory = { @@ -59,9 +63,22 @@ def get(self, index: int): def write(self, data: Dict[str, Any], start: int, end: int): for k, v in data.items(): - self._shared_memory[k][start:end] = torch.as_tensor(v).to( + # FIXME(ming): should check the size of v + tensor = torch.as_tensor(v).to( self._device, dtype=self._shared_memory[k].dtype ) + split = 0 + if end > self.block_size: + # we now should split the data + split = self.block_size - start + self._shared_memory[k][start:] = tensor[:split] + _start = 0 + _end = tensor.shape[0] - split + else: + _start = start + _end = end + + self._shared_memory[k][_start:_end] = tensor[split:] def generate_timestep(self) -> Dict[str, np.ndarray]: return {k: space.sample() for k, space in self.spaces.items()} diff --git a/malib/learner/learner.py b/malib/learner/learner.py index 23f8cc09..438f5df0 100644 --- a/malib/learner/learner.py +++ b/malib/learner/learner.py @@ -50,6 +50,9 @@ from malib.rl.config import Algorithm +MAX_MESSAGE_LENGTH = 7309898 + + class Learner(RemoteInterface, ABC): """Base class of agent interface, for training""" @@ -121,7 +124,7 @@ def __init__( if dataset is None: dataset = DynamicDataset( grpc_thread_num_workers=2, - max_message_length=1024, + max_message_length=MAX_MESSAGE_LENGTH, feature_handler=feature_handler_gen(device), ) else: diff --git a/malib/learner/manager.py b/malib/learner/manager.py index 6738c3ed..8ed48f09 100644 --- a/malib/learner/manager.py +++ b/malib/learner/manager.py @@ -105,10 +105,6 @@ def __init__( learner_cls = learner_cls.as_remote(**resource_config) learners: Dict[str, ray.ObjectRef] = {} - # assert ( - # "training" in stopping_conditions - # ), f"Stopping conditions should contains `training` stoppong conditions: {stopping_conditions}" - ready_check = [] for rid, agents in group_info["agent_groups"].items(): @@ -135,6 +131,7 @@ def __init__( while len(ready_check): _, ready_check = ray.wait(ready_check, num_returns=1, timeout=1) + Logger.info("All Learners are ready for accepting new tasks.") data_entrypoints = ray.get( [x.get_data_entrypoint.remote() for x in learners.values()] ) diff --git a/malib/models/model_client.py b/malib/models/model_client.py index fefb524d..02e39aee 100644 --- a/malib/models/model_client.py +++ b/malib/models/model_client.py @@ -34,7 +34,6 @@ def __init__( """ namespace, name = entry_point.split(":") - self.client = ray.get_actor(name=name, namespace=namespace) self.thread_pool = futures.ThreadPoolExecutor(max_workers=10) diff --git a/malib/rollout/inference/client.py b/malib/rollout/inference/client.py index 2c41e2ae..93a6ab5f 100644 --- a/malib/rollout/inference/client.py +++ b/malib/rollout/inference/client.py @@ -126,7 +126,7 @@ def compute_action( with torch.inference_mode(): obs = self.fixed_policy.preprocessor.transform(raw_obs) - obs = torch.from_numpy(obs).float() + obs = torch.tensor(obs).float() # FIXME(ming): act mask and hidden state is set to None, # not feasible for cases which require them policy_return = policy.compute_action( diff --git a/malib/rollout/inference/env_runner.py b/malib/rollout/inference/env_runner.py index 1048e787..230f427f 100644 --- a/malib/rollout/inference/env_runner.py +++ b/malib/rollout/inference/env_runner.py @@ -135,7 +135,7 @@ def __init__( agent_groups: Dict[str, Set] = None, inference_entry_points: Dict[str, str] = None, ) -> None: - super().__init__() + super(RemoteInterface, self).__init__() self._use_subproc_env = use_subproc_env self._max_env_num = max_env_num @@ -265,8 +265,12 @@ def run( # FIXME(ming): send data to remote dataset data = agent_manager.merge_episodes() data_entrypoints = data_entrypoints or {} - for entrypoint in data_entrypoints.values(): - send_data(data, entrypoint=entrypoint) + for k, entrypoint in data_entrypoints.items(): + # FIXME(ming): a bug, data: list of agent episode + agent_episode = data[0] + # requires agent group for identification + random_data = list(agent_episode.values())[0] + send_data(random_data, entrypoint=entrypoint) stats = {"total_timesteps": total_timestep, **timer.todict()} return stats diff --git a/malib/rollout/inference/manager.py b/malib/rollout/inference/manager.py index 0bbd3a82..21e07e54 100644 --- a/malib/rollout/inference/manager.py +++ b/malib/rollout/inference/manager.py @@ -4,7 +4,7 @@ from malib.common.manager import Manager from malib.rl.config import Algorithm -from malib.scenarios import Scenario +from malib.utils.logging import Logger from malib.rollout.inference.client import InferenceClient @@ -12,10 +12,10 @@ class InferenceManager(Manager): def __init__( self, group_info: Dict[str, Set], - ray_actor_namespace: str, model_entry_point: Dict[str, str], algorithm: Algorithm, verbose: bool = False, + ray_actor_namespace: str = "inference", ): super().__init__(verbose, namespace=ray_actor_namespace) @@ -26,10 +26,11 @@ def __init__( self._infer_clients = {} self._inference_entry_points = {} - # FIXME(Ming): for debug only - model_entry_point = model_entry_point or { - rid: None for rid in agent_groups.keys() - } + model_entry_point = ( + model_entry_point + if model_entry_point is not None + else {rid: None for rid in agent_groups.keys()} + ) infer_client_ready_check = [] for rid, _ in agent_groups.items(): @@ -54,6 +55,8 @@ def __init__( infer_client_ready_check, num_returns=1, timeout=1 ) + Logger.info("All inference clients are ready for serving") + def get_inference_client(self, runtime_id: str) -> InferenceClient: return self.inference_clients[runtime_id] diff --git a/malib/rollout/rolloutworker.py b/malib/rollout/rolloutworker.py index a7b93ad3..d06bec1f 100644 --- a/malib/rollout/rolloutworker.py +++ b/malib/rollout/rolloutworker.py @@ -170,6 +170,11 @@ def create_env_runner( agent_groups=self.agent_groups, inference_entry_points=rollout_config.inference_entry_points, ) + ready_check = [env_runner.ready.remote()] + + # wait for it be ready + while len(ready_check): + _, ready_check = ray.wait(ready_check, num_returns=1, timeout=1) return env_runner diff --git a/tests/rollout/test_pb_rollout_worker.py b/tests/rollout/test_pb_rollout_worker.py index 2936a269..9bbd45ff 100644 --- a/tests/rollout/test_pb_rollout_worker.py +++ b/tests/rollout/test_pb_rollout_worker.py @@ -103,13 +103,13 @@ def f(device): Episode.DONE: spaces.Discrete(1), Episode.CUR_OBS: env_desc["observation_spaces"][agent_id], Episode.ACTION: env_desc["action_spaces"][agent_id], - Episode.REWARD: spaces.Box(-np.inf, np.inf, shape=(1,), dtype=np.float32), + Episode.REWARD: spaces.Box(-np.inf, np.inf, shape=(), dtype=np.float32), Episode.NEXT_OBS: env_desc["observation_spaces"][agent_id], } np_memory = { - k: np.zeros((100,) + v.shape, dtype=v.dtype) for k, v in _spaces.items() + k: np.zeros((1000,) + v.shape, dtype=v.dtype) for k, v in _spaces.items() } - return FakeFeatureHandler(_spaces, np_memory, device) + return FakeFeatureHandler(_spaces, np_memory, device=device) return f @@ -165,7 +165,7 @@ class TestRolloutWorker: # stats = worker.rollout(task) def test_rollout_with_data_entrypoint(self, n_player: int): - with ray.init(local_mode=True): + with ray.init(): env_desc, algorithm, rollout_config, group_info = gen_common_requirements( n_player ) @@ -175,8 +175,6 @@ def test_rollout_with_data_entrypoint(self, n_player: int): agents = env_desc["possible_agents"] log_dir = "./logs" - inference_namespace = "test_pb_rolloutworker" - learner_manager = LearnerManager( stopping_conditions={"max_iteration": 10}, algorithm=algorithm, @@ -193,7 +191,6 @@ def test_rollout_with_data_entrypoint(self, n_player: int): infer_manager = InferenceManager( group_info=group_info, - ray_actor_namespace=inference_namespace, algorithm=algorithm, model_entry_point=learner_manager.learner_entrypoints, ) @@ -219,6 +216,8 @@ def test_rollout_with_data_entrypoint(self, n_player: int): log_dir=log_dir, ) + print("PBRollout worker is ready to work!!!") + task = RolloutTask( strategy_specs=strategy_spaces, stopping_conditions={"max_iteration": 10}, @@ -226,3 +225,4 @@ def test_rollout_with_data_entrypoint(self, n_player: int): ) stats = worker.rollout(task) + ray.shutdown() From 6c60dc54f1c83eff132686106aae84ec53bb7b8f Mon Sep 17 00:00:00 2001 From: Ming Zhou Date: Fri, 1 Dec 2023 15:14:34 +0800 Subject: [PATCH 22/24] refine some tests for rollout module --- malib/mocker/mocker_utils.py | 18 +++- tests/rollout/test_env.py | 2 + tests/rollout/test_env_runner.py | 2 +- tests/rollout/test_pb_rollout_worker.py | 114 +++++++++++------------- tests/rollout/test_rollout_manager.py | 24 ++--- tests/rollout/test_vector_env.py | 2 + 6 files changed, 85 insertions(+), 77 deletions(-) diff --git a/malib/mocker/mocker_utils.py b/malib/mocker/mocker_utils.py index e4685e30..559cadb6 100644 --- a/malib/mocker/mocker_utils.py +++ b/malib/mocker/mocker_utils.py @@ -22,7 +22,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -from typing import Sequence, Dict, Any, Callable, List, Tuple +from typing import Sequence, Dict, Any, Callable, List, Tuple, Union import time @@ -35,6 +35,20 @@ from malib.utils.typing import PolicyID from malib.common.payoff_manager import PayoffManager +from malib.learner.learner import Learner +from malib.backend.dataset_server.feature import BaseFeature + + +class FakeLearner(Learner): + def multiagent_post_process( + self, + batch_info, + ) -> Dict[str, Any]: + pass + + +class FakeFeatureHandler(BaseFeature): + pass class FakePayoffManager(PayoffManager): @@ -79,7 +93,7 @@ def __init__( stopping_conditions: Dict[str, Any], num_worker: int, group_info: Dict[str, Any], - rollout_config: RolloutConfig | Dict[str, Any], + rollout_config: Union[RolloutConfig, Dict[str, Any]], env_desc: Dict[str, Any], log_dir: str, resource_config: Dict[str, Any] = None, diff --git a/tests/rollout/test_env.py b/tests/rollout/test_env.py index b16502be..ff2a3cde 100644 --- a/tests/rollout/test_env.py +++ b/tests/rollout/test_env.py @@ -22,6 +22,8 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +# NOTE(ming): not been tested yet + import pytest from pytest_mock import MockerFixture diff --git a/tests/rollout/test_env_runner.py b/tests/rollout/test_env_runner.py index 33e69f96..35e9cbef 100644 --- a/tests/rollout/test_env_runner.py +++ b/tests/rollout/test_env_runner.py @@ -63,7 +63,7 @@ def test_env_runner(env_desc: Dict[str, Any], max_env_num: int): rollout_config, strategy_specs, inference_clients=infer_clients, - data_entrypoint_mapping=None, + data_entrypoints=None, ) print(stats) diff --git a/tests/rollout/test_pb_rollout_worker.py b/tests/rollout/test_pb_rollout_worker.py index 9bbd45ff..0b5a4ee3 100644 --- a/tests/rollout/test_pb_rollout_worker.py +++ b/tests/rollout/test_pb_rollout_worker.py @@ -22,8 +22,6 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -from typing import Dict, Any, List, Tuple - import pytest import ray @@ -77,24 +75,11 @@ def gen_common_requirements(n_player: int): import numpy as np -from malib.learner.learner import Learner from gym import spaces from malib.learner.manager import LearnerManager from malib.learner.config import LearnerConfig from malib.utils.episode import Episode -from malib.backend.dataset_server.feature import BaseFeature - - -class FakeLearner(Learner): - def multiagent_post_process( - self, - batch_info, - ) -> Dict[str, Any]: - pass - - -class FakeFeatureHandler(BaseFeature): - pass +from malib.mocker.mocker_utils import FakeLearner, FakeFeatureHandler def feature_handler_meta_gen(env_desc, agent_id): @@ -116,53 +101,53 @@ def f(device): @pytest.mark.parametrize("n_player", [1, 2]) class TestRolloutWorker: - # def test_rollout(self, n_player: int): - # with ray.init(local_mode=True): - # env_desc, algorithm, rollout_config, group_info = gen_common_requirements( - # n_player - # ) - - # obs_spaces = env_desc["observation_spaces"] - # act_spaces = env_desc["action_spaces"] - # agents = env_desc["possible_agents"] - # log_dir = "./logs" - - # inference_namespace = "test_pb_rolloutworker" - - # infer_manager = InferenceManager( - # group_info=group_info, - # ray_actor_namespace=inference_namespace, - # algorithm=algorithm, - # model_entry_point=None, - # ) - - # rollout_config.inference_entry_points = infer_manager.inference_entry_points - - # strategy_specs = { - # agent: StrategySpec( - # policy_cls=algorithm.policy, - # observation_space=obs_spaces[agent], - # action_space=act_spaces[agent], - # identifier=agent, - # model_config=algorithm.model_config, - # policy_ids=["policy-0"], - # ) - # for agent in agents - # } - - # worker = PBRolloutWorker( - # env_desc=env_desc, - # agent_groups=group_info["agent_groups"], - # rollout_config=rollout_config, - # log_dir=log_dir, - # ) - - # task = RolloutTask( - # strategy_specs=strategy_specs, - # stopping_conditions={"max_iteration": 10}, - # data_entrypoints=None, # no data collect - # ) - # stats = worker.rollout(task) + def test_rollout(self, n_player: int): + with ray.init(local_mode=True): + env_desc, algorithm, rollout_config, group_info = gen_common_requirements( + n_player + ) + + obs_spaces = env_desc["observation_spaces"] + act_spaces = env_desc["action_spaces"] + agents = env_desc["possible_agents"] + log_dir = "./logs" + + inference_namespace = "test_pb_rolloutworker" + + infer_manager = InferenceManager( + group_info=group_info, + ray_actor_namespace=inference_namespace, + algorithm=algorithm, + model_entry_point=None, + ) + + rollout_config.inference_entry_points = infer_manager.inference_entry_points + + strategy_specs = { + agent: StrategySpec( + policy_cls=algorithm.policy, + observation_space=obs_spaces[agent], + action_space=act_spaces[agent], + identifier=agent, + model_config=algorithm.model_config, + policy_ids=["policy-0"], + ) + for agent in agents + } + + worker = PBRolloutWorker( + env_desc=env_desc, + agent_groups=group_info["agent_groups"], + rollout_config=rollout_config, + log_dir=log_dir, + ) + + task = RolloutTask( + strategy_specs=strategy_specs, + stopping_conditions={"max_iteration": 10}, + data_entrypoints=None, # no data collect + ) + stats = worker.rollout(task) def test_rollout_with_data_entrypoint(self, n_player: int): with ray.init(): @@ -196,6 +181,9 @@ def test_rollout_with_data_entrypoint(self, n_player: int): ) rollout_config.inference_entry_points = infer_manager.inference_entry_points + assert ( + "default" in rollout_config.inference_entry_points + ), rollout_config.inference_entry_points strategy_spaces = { agent: StrategySpec( diff --git a/tests/rollout/test_rollout_manager.py b/tests/rollout/test_rollout_manager.py index 4258e8ee..e550a80e 100644 --- a/tests/rollout/test_rollout_manager.py +++ b/tests/rollout/test_rollout_manager.py @@ -51,12 +51,12 @@ def create_manager( stopping_conditions: Dict[str, Any], rollout_config: Dict[str, Any], env_desc: Dict[str, Any], - agent_mapping_func: Callable, + group_info: Dict[str, Any], ): manager = RolloutWorkerManager( stopping_conditions=stopping_conditions, num_worker=1, - group_info=form_group_info(env_desc, agent_mapping_func), + group_info=group_info, rollout_config=rollout_config, env_desc=env_desc, log_dir="./logs", @@ -67,17 +67,11 @@ def create_manager( @pytest.mark.parametrize("n_players", [1, 2]) class TestRolloutManager: def test_rollout_task_send(self, mocker: MockerFixture, n_players: int): - with ray.init(local_mode=True): + with ray.init(): env_desc, algorithm, rollout_config, group_info = gen_common_requirements( n_players ) inference_namespace = "test_pb_rolloutworker" - manager = create_manager( - stopping_conditions={"rollout": {"max_iteration": 2}}, - rollout_config=RolloutConfig(), - env_desc=env_desc, - agent_mapping_func=lambda agent: "default", - ) learner_manager = LearnerManager( stopping_conditions={"max_iteration": 10}, @@ -100,7 +94,14 @@ def test_rollout_task_send(self, mocker: MockerFixture, n_players: int): model_entry_point=learner_manager.learner_entrypoints, ) - rollout_config.inference_entry_points = infer_manager.inference_entry_points + rollout_manager = create_manager( + stopping_conditions={"rollout": {"max_iteration": 2}}, + rollout_config=RolloutConfig( + inference_entry_points=infer_manager.inference_entry_points + ), + env_desc=env_desc, + group_info=group_info, + ) strategy_specs = { agent: StrategySpec( @@ -118,4 +119,5 @@ def test_rollout_task_send(self, mocker: MockerFixture, n_players: int): data_entrypoints=None, ) - results = manager.submit(task, wait=True) + results = rollout_manager.submit(task, wait=True) + ray.shutdown() diff --git a/tests/rollout/test_vector_env.py b/tests/rollout/test_vector_env.py index b1534d35..77d20efd 100644 --- a/tests/rollout/test_vector_env.py +++ b/tests/rollout/test_vector_env.py @@ -22,6 +22,8 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +# NOTE(ming): not been tested yet + from typing import List, Dict, Any import pytest From c9840a8942f057e7d2ebeae694d05d40bd0b89dc Mon Sep 17 00:00:00 2001 From: Ming Zhou Date: Wed, 13 Dec 2023 18:49:04 +0800 Subject: [PATCH 23/24] tmp save --- examples/sarl/ppo_gym.py | 2 +- malib/backend/dataset_server/data_loader.py | 2 + malib/backend/dataset_server/feature.py | 26 +--- malib/backend/dataset_server/service.py | 5 +- malib/learner/indepdent_learner.py | 23 +-- malib/learner/learner.py | 57 ++++--- malib/rl/coma/critic.py | 2 +- malib/rl/coma/trainer.py | 2 +- malib/rl/pg/__init__.py | 3 +- malib/rl/pg/config.py | 17 +- malib/rl/pg/policy.py | 6 +- malib/rl/pg/trainer.py | 4 +- malib/rl/random/__init__.py | 8 +- malib/rl/random/config.py | 13 +- malib/rl/random/random_trainer.py | 15 +- malib/rollout/envs/vector_env.py | 2 +- malib/{utils => rollout}/episode.py | 7 + malib/rollout/inference/env_runner.py | 46 ++++-- malib/rollout/inference/utils.py | 2 +- malib/rollout/rolloutworker.py | 11 +- malib/utils/data.py | 44 +++++- tests/agents/test_async_agent.py | 23 --- tests/agents/test_independent_agent.py | 163 ++++++++++++-------- tests/backend/test_dynamic_dataset.py | 27 +++- tests/rl/test_algorithms.py | 2 +- tests/rl/test_coma.py | 2 +- tests/rollout/test_pb_rollout_worker.py | 2 +- tests/structures/test_episode.py | 2 +- 28 files changed, 309 insertions(+), 209 deletions(-) rename malib/{utils => rollout}/episode.py (97%) delete mode 100644 tests/agents/test_async_agent.py diff --git a/examples/sarl/ppo_gym.py b/examples/sarl/ppo_gym.py index f64ced00..795ad8e6 100644 --- a/examples/sarl/ppo_gym.py +++ b/examples/sarl/ppo_gym.py @@ -7,7 +7,7 @@ import numpy as np -from malib.utils.episode import Episode +from malib.rollout.episode import Episode from malib.learner import IndependentAgent from malib.scenarios import sarl_scenario from malib.rl.config import Algorithm diff --git a/malib/backend/dataset_server/data_loader.py b/malib/backend/dataset_server/data_loader.py index e1a2c25a..6cf9ff8d 100644 --- a/malib/backend/dataset_server/data_loader.py +++ b/malib/backend/dataset_server/data_loader.py @@ -16,6 +16,8 @@ class EmptyError(Exception): pass +# TODO(ming): considering to determine the `max_message_length` +# by a FeatureHandler, as it is convinient for it to know the size of data. class DynamicDataset(Dataset): def __init__( self, diff --git a/malib/backend/dataset_server/feature.py b/malib/backend/dataset_server/feature.py index 27269ce1..8df9ce17 100644 --- a/malib/backend/dataset_server/feature.py +++ b/malib/backend/dataset_server/feature.py @@ -1,5 +1,5 @@ from typing import Any, Dict -from abc import ABC, abstractmethod +from abc import ABC import copy import numpy as np @@ -7,21 +7,7 @@ from gym import spaces from readerwriterlock import rwlock - - -numpy_to_torch_dtype_dict = { - np.bool_: torch.bool, - np.uint8: torch.uint8, - np.int8: torch.int8, - np.int16: torch.int16, - np.int32: torch.int32, - np.int64: torch.int64, - np.float16: torch.float16, - np.float32: torch.float32, - np.float64: torch.float64, - np.complex64: torch.complex64, - np.complex128: torch.complex128, -} +from malib.utils.data import numpy_to_torch_dtype_dict class BaseFeature(ABC): @@ -35,15 +21,11 @@ def __init__( self.rw_lock = rwlock.RWLockFair() self._device = device self._spaces = spaces - self._block_size = ( - block_size - if block_size is not None - else list(np_memory.values())[0].shape[0] - ) + self._block_size = min(block_size or np.iinfo(np.longlong).max, list(np_memory.values())[0].shape[0]) self._available_size = 0 self._flag = 0 self._shared_memory = { - k: torch.from_numpy(v).to(device).share_memory_() + k: torch.from_numpy(v[:self._block_size]).to(device).share_memory_() for k, v in np_memory.items() } diff --git a/malib/backend/dataset_server/service.py b/malib/backend/dataset_server/service.py index 2119787a..7f8537c4 100644 --- a/malib/backend/dataset_server/service.py +++ b/malib/backend/dataset_server/service.py @@ -1,6 +1,9 @@ +from typing import Dict + import threading import traceback import pickle +import numpy as np from . import data_pb2_grpc from . import data_pb2 @@ -19,7 +22,7 @@ def __init__( def Collect(self, request, context): try: - data = pickle.loads(request.data) + data: Dict[str, np.ndarray] = pickle.loads(request.data) batch_size = len(list(data.values())[0]) self.feature_handler.safe_put(data, batch_size) message = "success" diff --git a/malib/learner/indepdent_learner.py b/malib/learner/indepdent_learner.py index a21f97de..ae029b48 100644 --- a/malib/learner/indepdent_learner.py +++ b/malib/learner/indepdent_learner.py @@ -22,26 +22,15 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -from typing import Dict, Tuple, Any, List, Union +from typing import Dict, Any + +import torch from malib.utils.typing import AgentID -from malib.utils.tianshou_batch import Batch +from malib.utils.data import to_torch from malib.learner.learner import Learner class IndependentAgent(Learner): - def multiagent_post_process( - self, - batch_info: Union[ - Dict[AgentID, Tuple[Batch, List[int]]], Tuple[Batch, List[int]] - ], - ) -> Dict[str, Any]: - if not isinstance(batch_info, Tuple): - raise TypeError( - "IndependentAgent support only a tuple of batch info as input." - ) - - batch = batch_info[0] - batch.to_torch(device=self.device) - - return batch + def multiagent_post_process(self, batch: Dict[AgentID, Dict[str, torch.Tensor]]) -> Dict[str, Any]: + return to_torch(batch, device=self.device) diff --git a/malib/learner/learner.py b/malib/learner/learner.py index 438f5df0..ffd6bf5c 100644 --- a/malib/learner/learner.py +++ b/malib/learner/learner.py @@ -23,7 +23,7 @@ # SOFTWARE. -from typing import Dict, Any, Tuple, Callable, List, Union, Type +from typing import Dict, Any, Tuple, Callable, List, Union from abc import ABC, abstractmethod import time @@ -50,6 +50,7 @@ from malib.rl.config import Algorithm +# TODO(ming): better to use a feature handler to determine the max_message_length MAX_MESSAGE_LENGTH = 7309898 @@ -63,7 +64,6 @@ def __init__( observation_space: spaces.Space, action_space: spaces.Space, algorithm: Algorithm, - agent_mapping_func: Callable[[AgentID], str], governed_agents: Tuple[AgentID], custom_config: Dict[str, Any] = None, dataset: DynamicDataset = None, @@ -106,7 +106,6 @@ def __init__( self._algorithm = algorithm self._governed_agents = governed_agents self._strategy_spec = strategy_spec - self._agent_mapping_func = agent_mapping_func self._custom_config = custom_config self._policy = strategy_spec.gen_policy(device=device) @@ -144,14 +143,12 @@ def __init__( @abstractmethod def multiagent_post_process( self, - batch_info: Union[ - Dict[AgentID, Tuple[Batch, List[int]]], Tuple[Batch, List[int]] - ], + batch: Dict[AgentID, Dict[str, torch.Tensor]], ) -> Dict[str, Any]: """Merge agent buffer here and return the merged buffer. Args: - batch_info (Union[Dict[AgentID, Tuple[Batch, List[int]]], Tuple[Batch, List[int]]]): Batch info, could be a dict of agent batch info or a tuple. + batch (Dict[AgentID, Dict[str, torch.Tensor]]): A dict of agent batch. Returns: Dict[str, Any]: A merged buffer dict. @@ -218,6 +215,33 @@ def get_interface_state(self) -> Dict[str, Any]: "total_epoch": self._total_epoch, "policy_num": len(self._strategy_spec), } + + def step(self, prints: bool = False): + while ( + self.data_loader.dataset.readable_block_size + < self.data_loader.batch_size + ): + time.sleep(1) + + for data in self.data_loader: + batch_dict = self.multiagent_post_process(data) + batch = Batch(batch_dict) + # call trainer for one update step, and return training info + # since some algorithm may run multistep for one batch, + # then the returned training_info is a list of dict. + step_info_list = self.trainer(batch) + for step_info in step_info_list: + self._total_step += 1 + write_to_tensorboard( + self._summary_writer, + info=step_info, + global_step=self._total_step, + prefix=f"Learner/{self._runtime_id}", + ) + if prints: + print(self._total_step, step_info) + + self._total_epoch += 1 def train(self, task: OptimizationTask) -> Dict[str, Any]: """Executes a optimization task and returns the final interface state. @@ -233,25 +257,8 @@ def train(self, task: OptimizationTask) -> Dict[str, Any]: self.set_running(True) try: - while ( - self.data_loader.dataset.readable_block_size - < self.data_loader.batch_size - ): - time.sleep(1) - while self.is_running(): - for data in self.data_loader: - batch_info = self.multiagent_post_process(data) - step_info_list = self.trainer(batch_info) - for step_info in step_info_list: - self._total_step += 1 - write_to_tensorboard( - self._summary_writer, - info=step_info, - global_step=self._total_step, - prefix=f"Learner/{self._runtime_id}", - ) - self._total_epoch += 1 + self.step() except Exception as e: Logger.warning( f"training pipe is terminated. caused by: {traceback.format_exc()}" diff --git a/malib/rl/coma/critic.py b/malib/rl/coma/critic.py index 47719e36..9b1ed8b1 100644 --- a/malib/rl/coma/critic.py +++ b/malib/rl/coma/critic.py @@ -31,7 +31,7 @@ from torch import nn from gym import spaces -from malib.utils.episode import Episode +from malib.rollout.episode import Episode from malib.utils.tianshou_batch import Batch from malib.models.torch import make_net diff --git a/malib/rl/coma/trainer.py b/malib/rl/coma/trainer.py index abe5591d..24352ee3 100644 --- a/malib/rl/coma/trainer.py +++ b/malib/rl/coma/trainer.py @@ -32,7 +32,7 @@ from malib.utils.typing import AgentID from malib.utils.tianshou_batch import Batch from malib.utils.data import Postprocessor -from malib.utils.episode import Episode +from malib.rollout.episode import Episode from malib.rl.common import misc from malib.rl.common.trainer import Trainer from malib.rl.common.policy import Policy diff --git a/malib/rl/pg/__init__.py b/malib/rl/pg/__init__.py index 6e2aae1d..ee07d3aa 100644 --- a/malib/rl/pg/__init__.py +++ b/malib/rl/pg/__init__.py @@ -24,7 +24,8 @@ from .policy import PGPolicy from .trainer import PGTrainer -from .config import DEFAULT_CONFIG +from .config import Config POLICY = PGPolicy TRAINER = PGTrainer +DEFAULT_CONFIG = Config diff --git a/malib/rl/pg/config.py b/malib/rl/pg/config.py index b60cf81b..72db27f7 100644 --- a/malib/rl/pg/config.py +++ b/malib/rl/pg/config.py @@ -22,8 +22,10 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -DEFAULT_CONFIG = { - "training_config": { + +class Config: + + TRAINING_CONFIG = { "optimizer": "Adam", "lr": 1e-4, "reward_norm": None, @@ -31,10 +33,11 @@ "minibatch": 2, "batch_size": 32, "gamma": 0.99, - }, - "model_config": { + } + + CUSTOM_CONFIG = {} + + MODEL_CONFIG = { "preprocess_net": {"net_type": None, "config": {"hidden_sizes": [64]}}, "hidden_sizes": [64], - }, - "custom_config": {}, -} + } diff --git a/malib/rl/pg/policy.py b/malib/rl/pg/policy.py index 257e281b..a1a075fc 100644 --- a/malib/rl/pg/policy.py +++ b/malib/rl/pg/policy.py @@ -35,7 +35,7 @@ from malib.models.config import ModelConfig from malib.rl.common import misc from malib.rl.common.policy import Policy, PolicyReturn -from .config import DEFAULT_CONFIG +from .config import Config as DEFAULT_CONFIG class PGPolicy(Policy): @@ -60,9 +60,9 @@ def __init__( # update model_config with default ones model_config = merge_dicts( - DEFAULT_CONFIG["model_config"].copy(), model_config or {} + DEFAULT_CONFIG.MODEL_CONFIG.copy(), model_config or {} ) - kwargs = merge_dicts(DEFAULT_CONFIG["custom_config"].copy(), kwargs) + kwargs = merge_dicts(DEFAULT_CONFIG.CUSTOM_CONFIG.copy(), kwargs) super().__init__(observation_space, action_space, model_config, **kwargs) diff --git a/malib/rl/pg/trainer.py b/malib/rl/pg/trainer.py index 6f8c9416..80533ae9 100644 --- a/malib/rl/pg/trainer.py +++ b/malib/rl/pg/trainer.py @@ -36,14 +36,14 @@ from malib.utils.general import merge_dicts from malib.utils.typing import AgentID from malib.utils.tianshou_batch import Batch -from .config import DEFAULT_CONFIG +from .config import Config class PGTrainer(Trainer): def __init__(self, training_config: Dict[str, Any], policy_instance: Policy = None): # merge from default training_config = merge_dicts( - DEFAULT_CONFIG["training_config"], training_config or {} + Config.TRAINING_CONFIG, training_config or {} ) super().__init__(training_config, policy_instance) diff --git a/malib/rl/random/__init__.py b/malib/rl/random/__init__.py index 3a3239b6..a444c138 100644 --- a/malib/rl/random/__init__.py +++ b/malib/rl/random/__init__.py @@ -1,3 +1,9 @@ from .policy import RandomPolicy from .random_trainer import RandomTrainer -from .config import DEFAULT_CONFIG +from .config import Config + +Policy = RandomPolicy +Trainer = RandomTrainer +DEFAULT_CONFIG = Config + +__all__ = ["Policy", "Trainer", "DEFAULT_CONFIG"] diff --git a/malib/rl/random/config.py b/malib/rl/random/config.py index 593d30c8..1d9e3af2 100644 --- a/malib/rl/random/config.py +++ b/malib/rl/random/config.py @@ -1,5 +1,6 @@ -DEFAULT_CONFIG = { - "training_config": { +class Config: + + TRAINING_CONFIG = { "optimizer": "Adam", "lr": 1e-4, "reward_norm": None, @@ -12,9 +13,9 @@ "entropy_coef": 1e-3, "grad_norm": 5.0, "use_cuda": False, - }, - "model_config": { + } + + MODEL_CONFIG = { "preprocess_net": {"net_type": None, "config": {"hidden_sizes": [64]}}, "hidden_sizes": [64], - }, -} + } diff --git a/malib/rl/random/random_trainer.py b/malib/rl/random/random_trainer.py index 3020ab69..fa7c1545 100644 --- a/malib/rl/random/random_trainer.py +++ b/malib/rl/random/random_trainer.py @@ -1,17 +1,30 @@ -from typing import Any, Dict, Type +from typing import Any, Dict, Sequence, Type +import random +import time import torch from torch import optim from malib.rl.common.policy import Policy from malib.rl.pg.trainer import PGTrainer +from malib.utils.tianshou_batch import Batch +from malib.utils.typing import AgentID class RandomTrainer(PGTrainer): def __init__(self, training_config: Dict[str, Any], policy_instance: Policy = None): super().__init__(training_config, policy_instance) + def post_process(self, batch: Batch, agent_filter: Sequence[AgentID]) -> Batch: + return batch + + def train(self, batch: Batch) -> Dict[str, Any]: + time.sleep(random.random()) + return { + "loss": random.random() + } + def setup(self): self.optimizer: Type[optim.Optimizer] = getattr( optim, self.training_config["optimizer"] diff --git a/malib/rollout/envs/vector_env.py b/malib/rollout/envs/vector_env.py index e5ebafe2..d7cb7275 100644 --- a/malib/rollout/envs/vector_env.py +++ b/malib/rollout/envs/vector_env.py @@ -40,7 +40,7 @@ PolicyID, ) from malib.rollout.envs.env import Environment -from malib.utils.episode import Episode +from malib.rollout.episode import Episode EnvironmentType = Type[Environment] diff --git a/malib/utils/episode.py b/malib/rollout/episode.py similarity index 97% rename from malib/utils/episode.py rename to malib/rollout/episode.py index 6bd22023..e4f2587a 100644 --- a/malib/utils/episode.py +++ b/malib/rollout/episode.py @@ -185,6 +185,13 @@ def to_numpy(self) -> Dict[AgentID, Dict[str, np.ndarray]]: class ConventionalEpisodeList: def __init__(self, num: int, agents: List[AgentID]) -> None: + """Construct a list of COnventialEpisode, for trajectory tracking for a bunch of environments. + + Args: + num (int): Episode number. + agents (List[AgentID]): A list of enviornment agent ids, distinguished from runtime ids. + """ + self.episodes = [ConventionalEpisode(agents) for _ in range(num)] def record(self, obs, actions, last_dones, last_rews, states, idx: int = None): diff --git a/malib/rollout/inference/env_runner.py b/malib/rollout/inference/env_runner.py index 230f427f..912b1e4b 100644 --- a/malib/rollout/inference/env_runner.py +++ b/malib/rollout/inference/env_runner.py @@ -28,11 +28,11 @@ import numpy as np from malib.utils.typing import AgentID -from malib.utils.episode import ConventionalEpisodeList +from malib.rollout.episode import ConventionalEpisodeList from malib.utils.timing import Timing +from malib.utils.data import merge_array_by_keys from malib.remote.interface import RemoteInterface -from malib.rollout.envs.vector_env import VectorEnv, SubprocVecEnv from malib.rollout.config import RolloutConfig from malib.rollout.inference.client import InferenceClient, PolicyReturnWithObs from malib.rollout.envs.env import Environment @@ -115,8 +115,29 @@ def collect_and_act( return actions - def merge_episodes(self): - return self.episodes.to_numpy() + def merge_episodes( + self, agent_groups: Dict[str, Tuple] + ) -> Dict[str, Dict[str, np.ndarray]]: + """A dict of merged episodes, which is grouped by agent groups. + + Args: + agent_groups (Dict[str, Tuple]): A dict of agent groups. + + Returns: + Dict[str, Dict[str, np.ndarray]]: A dict of merged episodes, which is grouped by agent groups. + """ + + episodes: List[Dict[AgentID, Dict[str, np.ndarray]]] = self.episodes.to_numpy() + + # then merge this episodes by agent groups + merged = {} + for episode in episodes: + for gid, agents in agent_groups.items(): + filtered = [episode[agent] for agent in agents] + # then merge them by keys + tmp = merge_array_by_keys(filtered) + merged[gid] = tmp + return merged from malib.utils.timing import Timing @@ -132,7 +153,7 @@ def __init__( env_func: Type, max_env_num: int, use_subproc_env: bool = False, - agent_groups: Dict[str, Set] = None, + agent_groups: Dict[str, Tuple] = None, inference_entry_points: Dict[str, str] = None, ) -> None: super(RemoteInterface, self).__init__() @@ -157,6 +178,10 @@ def envs(self) -> Tuple[Environment]: def env_func(self) -> Type: return self._env_func + @property + def agent_groups(self) -> Dict[str, Tuple]: + return self._agent_groups + @property def num_active_envs(self) -> int: return len(self._envs) @@ -181,7 +206,7 @@ def run( Args: rollout_config (RolloutConfig): Rollout configuration, which specifies how many data pieces will rollout. strategy_specs (Dict[AgentID, StrategySpec]): A dict of strategy specs, which rules the behavior policy of each agent. - inference_clients (Dict[AgentID, InferenceClient]): A dict of remote inference client. + inference_clients (Dict[AgentID, InferenceClient]): A dict of remote inference client, mapping from env agents to inference clients. Note that there could be a shared client for multiple agents. data_entrypoints (Dict[str, str], optional): A mapping which defines the data collection trigger, if not None, then return episodes. Defaults to None. Raises: @@ -262,15 +287,10 @@ def run( vec_rews[env_idx] = rews # merge agent episodes - # FIXME(ming): send data to remote dataset - data = agent_manager.merge_episodes() + data = agent_manager.merge_episodes(agent_groups=self.agent_groups) data_entrypoints = data_entrypoints or {} for k, entrypoint in data_entrypoints.items(): - # FIXME(ming): a bug, data: list of agent episode - agent_episode = data[0] - # requires agent group for identification - random_data = list(agent_episode.values())[0] - send_data(random_data, entrypoint=entrypoint) + send_data(data[k], entrypoint=entrypoint) stats = {"total_timesteps": total_timestep, **timer.todict()} return stats diff --git a/malib/rollout/inference/utils.py b/malib/rollout/inference/utils.py index c6aa1ed3..fb8a262f 100644 --- a/malib/rollout/inference/utils.py +++ b/malib/rollout/inference/utils.py @@ -30,7 +30,7 @@ from gym import spaces from malib.utils.typing import AgentID, DataFrame, EnvID -from malib.utils.episode import Episode +from malib.rollout.episode import Episode from malib.utils.preprocessor import Preprocessor from malib.rollout.envs.vector_env import VectorEnv diff --git a/malib/rollout/rolloutworker.py b/malib/rollout/rolloutworker.py index d06bec1f..4137d356 100644 --- a/malib/rollout/rolloutworker.py +++ b/malib/rollout/rolloutworker.py @@ -22,33 +22,26 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -from typing import Dict, Any, List, Callable, Sequence, Tuple, Set, Union +from typing import Dict, Any, List, Callable, Tuple, Union from abc import abstractmethod -from collections import defaultdict import os import time -import traceback import logging -import pprint import ray -import gym -import numpy as np from ray.util import ActorPool from torch.utils import tensorboard from malib import settings -from malib.utils.typing import AgentID, BehaviorMode -from malib.utils.logging import Logger +from malib.utils.typing import AgentID from malib.utils.stopping_conditions import get_stopper from malib.utils.monitor import write_to_tensorboard from malib.common.strategy_spec import StrategySpec from malib.common.task import RolloutTask from malib.remote.interface import RemoteInterface from malib.rollout.config import RolloutConfig -from malib.rollout.inference.client import InferenceClient from malib.rollout.inference.env_runner import BasicEnvRunner diff --git a/malib/utils/data.py b/malib/utils/data.py index 2ba06c84..a223011a 100644 --- a/malib/utils/data.py +++ b/malib/utils/data.py @@ -1,7 +1,7 @@ # reference: https://github.com/thu-ml/tianshou/blob/master/tianshou/data/batch.py -from typing import Any, Union, Optional, Collection, Dict +from typing import Any, Union, Optional, Collection, Dict, List from numbers import Number import torch @@ -10,6 +10,48 @@ from numba import njit +numpy_to_torch_dtype_dict = { + np.bool_: torch.bool, + np.uint8: torch.uint8, + np.int8: torch.int8, + np.int16: torch.int16, + np.int32: torch.int32, + np.int64: torch.int64, + np.float16: torch.float16, + np.float32: torch.float32, + np.float64: torch.float64, + np.complex64: torch.complex64, + np.complex128: torch.complex128, +} + + +def merge_array_by_keys( + candidates: List[Dict[str, np.ndarray]] +) -> Dict[str, np.ndarray]: + """Merge a list of arrays by keys. + + Args: + candidates (List[Dict[str, np.ndarray]]): A list of dict, each element is a dict of arrays. + + Returns: + Dict[str, np.ndarray]: A merged dict of arrays. + """ + + # check whether keys are the same + keys_reference = set(candidates[0].keys()) + for candidate in candidates[1:]: + assert keys_reference == set(candidate.keys()) + + # then merge arrays by keys + merged = {} + for key in keys_reference: + merged[key] = np.concatenate( + [candidate[key] for candidate in candidates], axis=0 + ) + + return merged + + def _is_scalar(value: Any) -> bool: # check if the value is a scalar # 1. python bool object, number object: isinstance(value, Number) diff --git a/tests/agents/test_async_agent.py b/tests/agents/test_async_agent.py deleted file mode 100644 index 47cb69ed..00000000 --- a/tests/agents/test_async_agent.py +++ /dev/null @@ -1,23 +0,0 @@ -# MIT License - -# Copyright (c) 2021 MARL @ SJTU - -# Author: Ming Zhou - -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. - -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. diff --git a/tests/agents/test_independent_agent.py b/tests/agents/test_independent_agent.py index df0f8d79..6d1b300a 100644 --- a/tests/agents/test_independent_agent.py +++ b/tests/agents/test_independent_agent.py @@ -24,77 +24,106 @@ from typing import Any - import pytest +import importlib +import gym +import numpy as np -from malib import rl -from malib.utils.tianshou_batch import Batch -from malib.utils.episode import Episode -from malib.mocker.mocker_utils import use_ray_env -from malib.rollout.envs.mdp import env_desc_gen +from malib.utils.general import merge_dicts +from malib.rl.config import Algorithm +from malib.learner.learner import MAX_MESSAGE_LENGTH from malib.learner.indepdent_learner import IndependentAgent +from malib.backend.dataset_server.data_loader import DynamicDataset + + +def construct_dataset( + feature_handler=None, feature_handler_cls=None, feature_handler_kwargs=None +): + return DynamicDataset( + grpc_thread_num_workers=1, + max_message_length=MAX_MESSAGE_LENGTH, + feature_handler=feature_handler, + feature_handler_cls=feature_handler_cls, + **feature_handler_kwargs + ) + + +def construct_learner( + obs_space, + act_space, + algorithm, + governed_agents, + custom_config=None, + dataset=None, + feature_handler_gen=None, +) -> IndependentAgent: + return IndependentAgent( + runtime_id=None, + log_dir=None, + observation_space=obs_space, + action_space=act_space, + algorithm=algorithm, + governed_agents=governed_agents, + custom_config=custom_config, + dataset=dataset, + feature_handler_gen=feature_handler_gen, + ) + + +def construct_algorithm(module_path, model_config={}, trainer_config={}): + # import policy, trainer and default config from a given module + policy_cls = importlib.import_module(module_path).Policy + trainer_cls = importlib.import_module(module_path).Trainer + default_config = importlib.import_module(module_path).DEFAULT_CONFIG + + return Algorithm( + policy=policy_cls, + trainer=trainer_cls, + model_config=merge_dicts(default_config.MODEL_CONFIG, model_config), + trainer_config=merge_dicts(default_config.TRAINING_CONFIG, trainer_config), + ) + + +from malib.mocker.mocker_utils import FakeFeatureHandler +from malib.rollout.episode import Episode + + +@pytest.mark.parametrize("module_path", [ + 'malib.rl.random' +]) +class TestIndependentAgent: + def test_learner_with_outer_dataset(self, module_path): + obs_space = gym.spaces.Box(low=-1, high=1, shape=(1, 1), dtype=np.float32) + act_space = gym.spaces.Discrete(2) + np_memory = { + Episode.CUR_OBS: np.zeros() + } + governed_agents = ["default"] + + dataset = construct_dataset( + feature_handler=FakeFeatureHandler( + { + Episode.CUR_OBS: obs_space, + Episode.ACTION: act_space, + }, + np_memory, + block_size=100, + device="cpu", + ) + ) + algorithm = construct_algorithm(module_path) + learner = construct_learner( + algorithm, governed_agents, custom_config=None, dataset=dataset + ) + for _ in range(10): + learner.step(prints=True) -def start_learner(env_id: str, algorithm: Any): - experiment_tag = "test_" - agent_mapping_func = lambda agent: agent - env_desc = env_desc_gen(env_id=env_id) - learners = { - agent: IndependentAgent( - experiment_tag=experiment_tag, - runtime_id=agent, - log_dir="./logs", - env_desc=env_desc, - algorithms={ - "default": ( - algorithm.POLICY, - algorithm.TRAINER, - algorithm.DEFAULT_CONFIG["model_config"], - {}, - ) - }, - agent_mapping_func=agent_mapping_func, - governed_agents=[agent], - trainer_config=algorithm.DEFAULT_CONFIG["training_config"], - custom_config={}, - ) - for agent in env_desc["possible_agents"] - } - for learner in learners.values(): - learner.connect(max_tries=2) - return learners + def test_learner_with_outer_feature_handler(self): + pass + def test_learner_with_feature_handler_gen(self): + pass -@pytest.mark.parametrize("env_id", ["two_round_dmdp"]) -@pytest.mark.parametrize("algorithm", [rl.pg, rl.a2c, rl.dqn]) -class TestIndependentAgent: - def test_policy_add(self, env_id, algorithm): - with use_ray_env(): - learners = start_learner(env_id, algorithm) - for learner in learners.values(): - learner.add_policies(n=1) - - def test_parameter_sync(self, env_id, algorithm): - with use_ray_env(): - learners = start_learner(env_id, algorithm) - for learner in learners.values(): - learner.add_policies(n=1) - # then sync parameter to remote parameter server - learner.push() - # also pull down - learner.pull() - - def test_multiagent_post_process(self, env_id, algorithm): - with use_ray_env(): - learners = start_learner(env_id, algorithm) - for learner in learners.values(): - batch = learner.multiagent_post_process((Batch(), None)) - assert isinstance(batch, Batch) - with pytest.raises(TypeError): - learner.multiagent_post_process("fefefefe") - - def test_training_pipeline(self, env_id, algorithm): - with use_ray_env(): - learners = start_learner(env_id, algorithm) - for learner in learners.values(): - learner.add_policies(n=1) + def test_learner_with_dataset_gen(self): + pass diff --git a/tests/backend/test_dynamic_dataset.py b/tests/backend/test_dynamic_dataset.py index 5c62387d..214f11d9 100644 --- a/tests/backend/test_dynamic_dataset.py +++ b/tests/backend/test_dynamic_dataset.py @@ -1,9 +1,34 @@ +# MIT License + +# Copyright (c) 2021 MARL @ SJTU + +# Author: Ming Zhou + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# TEST PASSED + from typing import Any, Dict import time import random import multiprocessing -import pytest import threading import numpy as np diff --git a/tests/rl/test_algorithms.py b/tests/rl/test_algorithms.py index a9eafaaf..c89a38f8 100644 --- a/tests/rl/test_algorithms.py +++ b/tests/rl/test_algorithms.py @@ -30,7 +30,7 @@ from malib import rl from malib.rl.common import policy, trainer -from malib.utils.episode import Episode +from malib.rollout.episode import Episode from malib.utils.tianshou_batch import Batch from malib.rollout.envs.mdp.env import MDPEnvironment diff --git a/tests/rl/test_coma.py b/tests/rl/test_coma.py index d5525475..96b00fe9 100644 --- a/tests/rl/test_coma.py +++ b/tests/rl/test_coma.py @@ -32,7 +32,7 @@ from gym import spaces -from malib.utils.episode import Episode +from malib.rollout.episode import Episode from malib.utils.tianshou_batch import Batch from malib.rl.pg import PGPolicy from malib.rl.coma.critic import COMADiscreteCritic diff --git a/tests/rollout/test_pb_rollout_worker.py b/tests/rollout/test_pb_rollout_worker.py index 0b5a4ee3..c8a68352 100644 --- a/tests/rollout/test_pb_rollout_worker.py +++ b/tests/rollout/test_pb_rollout_worker.py @@ -78,7 +78,7 @@ def gen_common_requirements(n_player: int): from gym import spaces from malib.learner.manager import LearnerManager from malib.learner.config import LearnerConfig -from malib.utils.episode import Episode +from malib.rollout.episode import Episode from malib.mocker.mocker_utils import FakeLearner, FakeFeatureHandler diff --git a/tests/structures/test_episode.py b/tests/structures/test_episode.py index 9a6d0793..6b19f3da 100644 --- a/tests/structures/test_episode.py +++ b/tests/structures/test_episode.py @@ -31,7 +31,7 @@ from gym import spaces from malib.utils.typing import AgentID -from malib.utils.episode import Episode +from malib.rollout.episode import Episode from malib.rollout.envs.env import Environment From 030407c16aa0623ff5777c0d1507b134db5e1aa5 Mon Sep 17 00:00:00 2001 From: Ming Zhou Date: Thu, 14 Dec 2023 17:46:29 +0800 Subject: [PATCH 24/24] pass for sarl --- .gitignore | 3 +- examples/sarl/ppo_gym.py | 25 ++++++--- malib/backend/dataset_server/data_loader.py | 2 + malib/backend/dataset_server/feature.py | 16 +++++- malib/common/task.py | 17 ++++-- malib/learner/indepdent_learner.py | 4 +- malib/learner/learner.py | 50 +++++++++++++++-- malib/learner/manager.py | 33 +++++++---- malib/rl/a2c/policy.py | 43 +++++++------- malib/rl/a2c/trainer.py | 7 +-- malib/rl/common/policy.py | 26 +++++++++ malib/rl/pg/config.py | 9 +++ malib/rl/pg/policy.py | 13 ++++- malib/rl/pg/trainer.py | 4 +- malib/rl/ppo/__init__.py | 6 +- malib/rl/ppo/config.py | 29 +++++++++- malib/rl/ppo/policy.py | 12 +--- malib/rl/ppo/trainer.py | 1 + malib/rl/random/random_trainer.py | 6 +- malib/rollout/envs/gym/env.py | 25 ++++----- malib/rollout/manager.py | 7 +-- malib/rollout/rolloutworker.py | 2 + malib/scenarios/sarl_scenario.py | 41 ++++++++------ malib/settings.py | 2 +- malib/utils/data.py | 5 ++ malib/utils/stopping_conditions.py | 2 + tests/agents/test_independent_agent.py | 62 ++++++++++++++++----- tests/rollout/test_pb_rollout_worker.py | 4 +- 28 files changed, 322 insertions(+), 134 deletions(-) diff --git a/.gitignore b/.gitignore index 211626d1..219e1e5f 100644 --- a/.gitignore +++ b/.gitignore @@ -134,4 +134,5 @@ dmypy.json _build logs demos -prof/ \ No newline at end of file +prof/ +runs \ No newline at end of file diff --git a/examples/sarl/ppo_gym.py b/examples/sarl/ppo_gym.py index 795ad8e6..6776308c 100644 --- a/examples/sarl/ppo_gym.py +++ b/examples/sarl/ppo_gym.py @@ -23,22 +23,29 @@ class FeatureHandler(BaseFeature): def feature_handler_meta_gen(env_desc, agent_id): - def f(device): + """Return a generator of feature handler meta. + + Args: + env_desc (_type_): _description_ + agent_id (_type_): _description_ + """ + + def f(device="cpu"): # define the data schema _spaces = { Episode.DONE: spaces.Discrete(1), Episode.CUR_OBS: env_desc["observation_spaces"][agent_id], Episode.ACTION: env_desc["action_spaces"][agent_id], - Episode.REWARD: spaces.Box(-np.inf, np.inf, shape=(1,), dtype=np.float32), + Episode.REWARD: spaces.Box(-np.inf, np.inf, shape=(), dtype=np.float32), Episode.NEXT_OBS: env_desc["observation_spaces"][agent_id], } # you should know the maximum of replaybuffer before training np_memory = { - k: np.zeros((100,) + v.shape, dtype=v.dtype) for k, v in _spaces.items() + k: np.zeros((10000,) + v.shape, dtype=v.dtype) for k, v in _spaces.items() } - return FeatureHandler(_spaces, np_memory, device) + return FeatureHandler(_spaces, np_memory, device=device) return f @@ -51,7 +58,7 @@ def f(device): args = parser.parse_args() - trainer_config = DEFAULT_CONFIG["training_config"].copy() + trainer_config = DEFAULT_CONFIG.TRAINING_CONIG.copy() trainer_config["total_timesteps"] = int(1e6) trainer_config["use_cuda"] = args.use_cuda @@ -80,13 +87,13 @@ def f(device): rollout_config=RolloutConfig( num_workers=1, ), - agent_mapping_func=lambda agent: agent, stopping_conditions={ - "training": {"max_iteration": int(1e10)}, - "rollout": {"max_iteration": 1000, "minimum_reward_improvement": 1.0}, + "golbal": {"max_iteration": 1000, "minimum_reward_improvement": 1.0}, + "rollout": {"max_iteration": 1}, + "training": {"max_iteration": 1}, }, ) - results = sarl_scenario.execution_plan(scenario=scenario, verbose=True) + results = sarl_scenario.execution_plan(scenario=scenario, verbose=False) print(results) diff --git a/malib/backend/dataset_server/data_loader.py b/malib/backend/dataset_server/data_loader.py index 6cf9ff8d..11412bbb 100644 --- a/malib/backend/dataset_server/data_loader.py +++ b/malib/backend/dataset_server/data_loader.py @@ -39,6 +39,8 @@ def __init__( self.max_message_length = max_message_length def start_server(self): + """Launch a dataset service.""" + self.server_port = find_free_port() self.server = service_wrapper( self.grpc_thread_num_workers, diff --git a/malib/backend/dataset_server/feature.py b/malib/backend/dataset_server/feature.py index 8df9ce17..9ea582dc 100644 --- a/malib/backend/dataset_server/feature.py +++ b/malib/backend/dataset_server/feature.py @@ -18,14 +18,26 @@ def __init__( block_size: int = None, device: str = "cpu", ) -> None: + """Constructing a feature handler for data preprocessing. + + Args: + spaces (Dict[str, spaces.Space]): A dict of spaces + np_memory (Dict[str, np.ndarray]): A dict of memory placeholders + block_size (int, optional): Block size. Defaults to None. + device (str, optional): Device name. Defaults to "cpu". + """ + self.rw_lock = rwlock.RWLockFair() self._device = device self._spaces = spaces - self._block_size = min(block_size or np.iinfo(np.longlong).max, list(np_memory.values())[0].shape[0]) + self._block_size = min( + block_size or np.iinfo(np.longlong).max, + list(np_memory.values())[0].shape[0], + ) self._available_size = 0 self._flag = 0 self._shared_memory = { - k: torch.from_numpy(v[:self._block_size]).to(device).share_memory_() + k: torch.from_numpy(v[: self._block_size]).to(device).share_memory_() for k, v in np_memory.items() } diff --git a/malib/common/task.py b/malib/common/task.py index 15279f6c..0200eea6 100644 --- a/malib/common/task.py +++ b/malib/common/task.py @@ -18,9 +18,9 @@ class Task: @dataclass class RolloutTask(Task): - strategy_specs: Dict[str, Any] = field(default_factory=dict()) - stopping_conditions: Dict[str, Any] = field(default_factory=dict()) - data_entrypoints: Dict[str, Any] = field(default_factory=dict()) + strategy_specs: Dict[str, Any] = field(default_factory=dict) + stopping_conditions: Dict[str, Any] = field(default_factory=dict) + data_entrypoints: Dict[str, Any] = field(default_factory=dict) @classmethod def from_raw( @@ -36,15 +36,20 @@ def from_raw( @dataclass class OptimizationTask(Task): - stop_conditions: Dict[str, Any] + stopping_conditions: Dict[str, Any] """stopping conditions for optimization task, e.g., max iteration, max time, etc.""" - strategy_specs: Dict[str, Any] = field(default_factory=dict()) - """a dict of strategy specs, which defines the strategy spec for each agent.""" + # strategy_specs: Dict[str, Any] = field(default_factory=dict) + # """a dict of strategy specs, which defines the strategy spec for each agent.""" active_agents: List[AgentID] = field(default_factory=list) """a list of active agents, which defines the agents that will be trained in this optimization task. None for all""" + save_interval: int = 2 + """the interval of saving checkpoints""" + + model_dir: str = "" + @classmethod def from_raw( cls, dict_style: Union[Dict[str, Any], "OptimizationTask"], **kwargs diff --git a/malib/learner/indepdent_learner.py b/malib/learner/indepdent_learner.py index ae029b48..2a6b45b8 100644 --- a/malib/learner/indepdent_learner.py +++ b/malib/learner/indepdent_learner.py @@ -32,5 +32,7 @@ class IndependentAgent(Learner): - def multiagent_post_process(self, batch: Dict[AgentID, Dict[str, torch.Tensor]]) -> Dict[str, Any]: + def multiagent_post_process( + self, batch: Dict[AgentID, Dict[str, torch.Tensor]] + ) -> Dict[str, Any]: return to_torch(batch, device=self.device) diff --git a/malib/learner/learner.py b/malib/learner/learner.py index ffd6bf5c..45251e5a 100644 --- a/malib/learner/learner.py +++ b/malib/learner/learner.py @@ -28,7 +28,9 @@ import time import traceback +import os +import json import torch import ray @@ -40,6 +42,7 @@ from malib.utils.logging import Logger from malib.utils.tianshou_batch import Batch from malib.utils.monitor import write_to_tensorboard +from malib.utils.stopping_conditions import get_stopper from malib.remote.interface import RemoteInterface from malib.common.task import OptimizationTask from malib.common.strategy_spec import StrategySpec @@ -107,9 +110,18 @@ def __init__( self._governed_agents = governed_agents self._strategy_spec = strategy_spec self._custom_config = custom_config + # Do not add policy to strategy spec now, since we only update it + # when new checkpoint is ready. self._policy = strategy_spec.gen_policy(device=device) self._summary_writer = tensorboard.SummaryWriter(log_dir=log_dir) + self._model_dir = os.path.join(log_dir, "models") + + if not os.path.exists(self._model_dir): + os.makedirs(self._model_dir) + + # save metastate to current log_dir + self.save_metastate(log_dir) # load policy for trainer self._trainer: Trainer = algorithm.trainer( @@ -140,6 +152,16 @@ def __init__( self._total_epoch = 0 self._verbose = verbose + def save_metastate(self, log_dir): + with open("{}/metastate.json".format(log_dir), "w") as f: + json.dump( + { + "runtime_id": self._runtime_id, + "governed_agents": self.governed_agents, + }, + f, + ) + @abstractmethod def multiagent_post_process( self, @@ -215,13 +237,12 @@ def get_interface_state(self) -> Dict[str, Any]: "total_epoch": self._total_epoch, "policy_num": len(self._strategy_spec), } - + def step(self, prints: bool = False): while ( - self.data_loader.dataset.readable_block_size - < self.data_loader.batch_size + self.data_loader.dataset.readable_block_size < self.data_loader.batch_size ): - time.sleep(1) + return for data in self.data_loader: batch_dict = self.multiagent_post_process(data) @@ -239,10 +260,13 @@ def step(self, prints: bool = False): prefix=f"Learner/{self._runtime_id}", ) if prints: - print(self._total_step, step_info) + print(self._total_epoch, self._total_step, step_info) self._total_epoch += 1 + # TODO(ming): should merge step before return + return step_info_list + def train(self, task: OptimizationTask) -> Dict[str, Any]: """Executes a optimization task and returns the final interface state. @@ -255,10 +279,22 @@ def train(self, task: OptimizationTask) -> Dict[str, Any]: """ self.set_running(True) + stopper = get_stopper(task.stopping_conditions) try: while self.is_running(): - self.step() + results = self.step() + if results is None: # indicates the dataset is not ready + break + if self._total_epoch % task.save_interval == 0: + ck_path = os.path.join( + self._model_dir, f"checkpoint-{self._total_epoch}.ckpt" + ) + torch.save(self.policy.state_dict(), ck_path) + Logger.info("save checkpoint to {}".format(ck_path)) + self.strategy_spec.register_policy_id(ck_path) + if stopper.should_stop(results): + break except Exception as e: Logger.warning( f"training pipe is terminated. caused by: {traceback.format_exc()}" @@ -270,6 +306,8 @@ def train(self, task: OptimizationTask) -> Dict[str, Any]: self._total_epoch, self._total_step ) ) + # hard set False to stop training + self.set_running(False) return self.get_interface_state() def reset(self): diff --git a/malib/learner/manager.py b/malib/learner/manager.py index 8ed48f09..2a90f83a 100644 --- a/malib/learner/manager.py +++ b/malib/learner/manager.py @@ -117,7 +117,6 @@ def __init__( observation_space=group_info["observation_space"][rid], action_space=group_info["action_space"][rid], algorithm=algorithm, - agent_mapping_func=agent_mapping_func, governed_agents=agents, custom_config=learner_config.custom_config, feature_handler_gen=learner_config.feature_handler_meta_gen( @@ -152,6 +151,7 @@ def __init__( self._agent_mapping_func = agent_mapping_func self._learners = learners self._thread_pool = ThreadPoolExecutor(max_workers=len(learners)) + # FIXME(ming): deprecated self._stopping_conditions = stopping_conditions # init strategy spec @@ -211,6 +211,12 @@ def runtime_ids(self) -> Tuple[str]: return self._runtime_ids + def get_strategy_specs(self) -> Dict[str, StrategySpec]: + values = ray.get( + [v.get_strategy_spec.remote() for v in self._learners.values()] + ) + return dict(zip(self._learners.keys(), values)) + def add_policies( self, interface_ids: Sequence[str] = None, n: Union[int, Dict[str, int]] = 1 ) -> Dict[str, Type[StrategySpec]]: @@ -240,7 +246,7 @@ def add_policies( return strategy_spec_dict - def submit(self, task: OptimizationTask): + def submit(self, task: OptimizationTask, wait: bool = False): """Submit a training task, the manager will distribute it to the corresponding learners. Args: @@ -248,14 +254,21 @@ def submit(self, task: OptimizationTask): """ # retrieve learners with active agents - for aid in task.active_agents: - rid = self._agent_mapping_func(aid) - if rid not in self._learners: - raise RuntimeError(f"Agent {aid} is not registered in training manager") - else: - learner = self._learners[rid] - ray_task = learner.train.remote(task) - self.pending_tasks.append(ray_task) + rids = ( + list(self._learners.keys()) + if task.active_agents is None + else [self._agent_mapping_func(aid) for aid in task.active_agents] + ) + + for rid in rids: + learner = self._learners[rid] + ray_task = learner.train.remote(task) + self.pending_tasks.append(ray_task) + if wait: + result_list = self.wait() + return result_list + else: + return None def retrive_results(self) -> Generator: """Return a generator of results. diff --git a/malib/rl/a2c/policy.py b/malib/rl/a2c/policy.py index febd3a9f..053ad872 100644 --- a/malib/rl/a2c/policy.py +++ b/malib/rl/a2c/policy.py @@ -33,40 +33,41 @@ from malib.rl.pg import PGPolicy from malib.models.torch import continuous, discrete +from malib.models.torch.net import ActorCritic class A2CPolicy(PGPolicy): - def __init__( - self, - observation_space: spaces.Space, - action_space: spaces.Space, - model_config: Dict[str, Any], - custom_config: Dict[str, Any], - **kwargs - ): - super().__init__( - observation_space, action_space, model_config, custom_config, **kwargs - ) - - preprocess_net: nn.Module = self.actor.preprocess - if isinstance(action_space, spaces.Discrete): - self.critic = discrete.Critic( + def create_model(self): + # since a PGPolicy creates a model as an Actor. + actor = super().create_model() + + preprocess_net: nn.Module = actor.preprocess + if isinstance(self.action_space, spaces.Discrete): + critic = discrete.Critic( preprocess_net=preprocess_net, - hidden_sizes=model_config["hidden_sizes"], + hidden_sizes=self.model_config["hidden_sizes"], device=self.device, ) - elif isinstance(action_space, spaces.Box): - self.critic = continuous.Critic( + elif isinstance(self.action_space, spaces.Box): + critic = continuous.Critic( preprocess_net=preprocess_net, - hidden_sizes=model_config["hidden_sizes"], + hidden_sizes=self.model_config["hidden_sizes"], device=self.device, ) else: raise TypeError( - "Unexpected action space type: {}".format(type(action_space)) + "Unexpected action space type: {}".format(type(self.action_space)) ) - self.register_state(self.critic, "critic") + return ActorCritic(actor, critic) + + @property + def actor(self): + return self.model.actor + + @property + def critic(self): + return self.model.critic def value_function(self, observation: torch.Tensor, evaluate: bool, **kwargs): """Compute values of critic.""" diff --git a/malib/rl/a2c/trainer.py b/malib/rl/a2c/trainer.py index 9531b3c0..127be650 100644 --- a/malib/rl/a2c/trainer.py +++ b/malib/rl/a2c/trainer.py @@ -43,13 +43,10 @@ class A2CTrainer(Trainer): def setup(self): - parameter_dict = self.policy.parameters() - # concate parameters - parameters = set(itertools.chain(*parameter_dict.values())) self.optimizer = getattr(optim, self.training_config["optimizer"])( - parameters, lr=self.training_config["lr"] + self.policy.parameters(), lr=self.training_config["lr"] ) - self.parameters = parameters + self.parameters = self.policy.parameters() self.lr_scheduler: torch.optim.lr_scheduler.LambdaLR = None # runtime return averaging diff --git a/malib/rl/common/policy.py b/malib/rl/common/policy.py index 2f0b1d38..c30e2d5a 100644 --- a/malib/rl/common/policy.py +++ b/malib/rl/common/policy.py @@ -118,6 +118,32 @@ def __init__( def create_model(self) -> nn.Module: raise NotImplementedError + @property + def actor(self) -> nn.Module: + """Return an Actor network. + + Raises: + NotImplementedError: Not implemented error. + + Returns: + nn.Module: An Actor. + """ + + raise NotImplementedError + + @property + def critic(self) -> nn.Module: + """Return a Critic network. + + Raises: + NotImplementedError: Not implemented error. + + Returns: + nn.Module: A Critic. + """ + + return NotImplementedError + @property def dist_fn(self) -> Distribution: return self._dist_fn diff --git a/malib/rl/pg/config.py b/malib/rl/pg/config.py index 72db27f7..d0f3bdf6 100644 --- a/malib/rl/pg/config.py +++ b/malib/rl/pg/config.py @@ -33,6 +33,15 @@ class Config: "minibatch": 2, "batch_size": 32, "gamma": 0.99, + "repeats": 1, + "ratio_clip": 0.2, + "dual_clip": None, + "vf_ratio": 0.1, + "ent_ratio": 0.01, + "use_adv_norm": False, + "adv_norm_eps": 1e-8, + "use_grad_norm": False, + "use_value_clip": False, } CUSTOM_CONFIG = {} diff --git a/malib/rl/pg/policy.py b/malib/rl/pg/policy.py index a1a075fc..0c39dcab 100644 --- a/malib/rl/pg/policy.py +++ b/malib/rl/pg/policy.py @@ -102,6 +102,17 @@ def create_model(self): "Unexpected action space type: {}".format(type(self.action_space)) ) + @property + def actor(self): + if isinstance(self._model, nn.Module): + return self._model + else: + return self._model.actor + + @property + def critic(self): + raise RuntimeError("PG has no critic network can be called!") + def value_function(self, observation: torch.Tensor, evaluate: bool, **kwargs): """Compute values of critic.""" @@ -116,7 +127,7 @@ def compute_action( **kwargs ) -> PolicyReturn: with torch.inference_mode(): - logits, hidden = self.model(observation, state=hidden_state) + logits, hidden = self.actor(observation, state=hidden_state) if isinstance(logits, tuple): dist = self.dist_fn.proba_distribution(*logits) else: diff --git a/malib/rl/pg/trainer.py b/malib/rl/pg/trainer.py index 80533ae9..b01273cc 100644 --- a/malib/rl/pg/trainer.py +++ b/malib/rl/pg/trainer.py @@ -42,9 +42,7 @@ class PGTrainer(Trainer): def __init__(self, training_config: Dict[str, Any], policy_instance: Policy = None): # merge from default - training_config = merge_dicts( - Config.TRAINING_CONFIG, training_config or {} - ) + training_config = merge_dicts(Config.TRAINING_CONFIG, training_config or {}) super().__init__(training_config, policy_instance) def setup(self): diff --git a/malib/rl/ppo/__init__.py b/malib/rl/ppo/__init__.py index 31ec9130..7ae63672 100644 --- a/malib/rl/ppo/__init__.py +++ b/malib/rl/ppo/__init__.py @@ -24,4 +24,8 @@ from .policy import PPOPolicy from .trainer import PPOTrainer -from .config import DEFAULT_CONFIG +from .config import Config + +POLICY = PPOPolicy +TRAINER = PPOTrainer +DEFAULT_CONFIG = Config diff --git a/malib/rl/ppo/config.py b/malib/rl/ppo/config.py index bde9a653..5ca2d51f 100644 --- a/malib/rl/ppo/config.py +++ b/malib/rl/ppo/config.py @@ -1 +1,28 @@ -DEFAULT_CONFIG = {} +class Config: + + TRAINING_CONIG = { + "gae_lambda": 0.95, + "optimizer": "Adam", + "lr": 1e-4, + "reward_norm": None, + "n_repeat": 2, + "minibatch": 2, + "batch_size": 32, + "gamma": 0.99, + "repeats": 1, + "ratio_clip": 0.2, + "dual_clip": None, + "vf_ratio": 0.1, + "ent_ratio": 0.01, + "use_adv_norm": False, + "adv_norm_eps": 1e-8, + "use_grad_norm": False, + "use_value_clip": False, + } + + CUSTOM_CONFIG = {} + + MODEL_CONFIG = { + "preprocess_net": {"net_type": None, "config": {"hidden_sizes": [64]}}, + "hidden_sizes": [64], + } diff --git a/malib/rl/ppo/policy.py b/malib/rl/ppo/policy.py index 26e06d3e..27c8ff4c 100644 --- a/malib/rl/ppo/policy.py +++ b/malib/rl/ppo/policy.py @@ -32,14 +32,4 @@ class PPOPolicy(A2CPolicy): - def __init__( - self, - observation_space: spaces.Space, - action_space: spaces.Space, - model_config: Dict[str, Any], - custom_config: Dict[str, Any], - **kwargs - ): - super().__init__( - observation_space, action_space, model_config, custom_config, **kwargs - ) + pass diff --git a/malib/rl/ppo/trainer.py b/malib/rl/ppo/trainer.py index 126946c8..9c6a159f 100644 --- a/malib/rl/ppo/trainer.py +++ b/malib/rl/ppo/trainer.py @@ -44,6 +44,7 @@ def train(self, batch: Batch) -> Dict[str, List[float]]: use_grad_norm = self.training_config["use_grad_norm"] use_value_clip = self.training_config["use_value_clip"] + # XXX(ming): or we should keep a list of them losses, clip_losses, vf_losses, ent_losses = 0.0, 0.0, 0.0, 0.0 for step in range(repeats): diff --git a/malib/rl/random/random_trainer.py b/malib/rl/random/random_trainer.py index fa7c1545..0a31dff3 100644 --- a/malib/rl/random/random_trainer.py +++ b/malib/rl/random/random_trainer.py @@ -18,12 +18,10 @@ def __init__(self, training_config: Dict[str, Any], policy_instance: Policy = No def post_process(self, batch: Batch, agent_filter: Sequence[AgentID]) -> Batch: return batch - + def train(self, batch: Batch) -> Dict[str, Any]: time.sleep(random.random()) - return { - "loss": random.random() - } + return {"loss": random.random()} def setup(self): self.optimizer: Type[optim.Optimizer] = getattr( diff --git a/malib/rollout/envs/gym/env.py b/malib/rollout/envs/gym/env.py index 90845dbc..ef1863bc 100644 --- a/malib/rollout/envs/gym/env.py +++ b/malib/rollout/envs/gym/env.py @@ -31,29 +31,26 @@ class GymEnv(Environment): """Single agent gym envrionment""" def __init__(self, **configs): - super(GymEnv, self).__init__(**configs) - - env_id = self._configs["env_id"] - scenario_configs = self._configs.get("scenario_configs", {}) + env_id = configs["env_id"] + scenario_configs = configs.get("scenario_configs", {}) - self.is_sequential = False self._env = gym.make(env_id, **scenario_configs) self._default_agent = "agent" + + super(GymEnv, self).__init__(**configs) + self._observation_spaces = {self._default_agent: self._env.observation_space} self._action_spaces = {self._default_agent: self._env.action_space} self._trainable_agents = [self._default_agent] - @property - def possible_agents(self) -> List[AgentID]: - return [self._default_agent] + def register_action_spaces(self): + return {self._default_agent: self._env.action_space} - @property - def observation_spaces(self) -> Dict[AgentID, gym.Space]: - return self._observation_spaces + def register_agents(self): + return [self._default_agent] - @property - def action_spaces(self) -> Dict[AgentID, gym.Space]: - return self._action_spaces + def register_observation_spaces(self): + return {self._default_agent: self._env.observation_space} def time_step( self, actions: Dict[AgentID, Any] diff --git a/malib/rollout/manager.py b/malib/rollout/manager.py index 146d6f4d..70302245 100644 --- a/malib/rollout/manager.py +++ b/malib/rollout/manager.py @@ -57,7 +57,7 @@ def validate_strategy_specs(specs: Dict[str, StrategySpec]): for rid, spec in specs.items(): if len(spec) < 1: - raise ValueError(f"Empty spec for runtime_id={rid}") + continue # check prob list expected_prob_list = spec.meta_data.get( "prob_list", [1 / len(spec)] * len(spec) @@ -127,10 +127,7 @@ def __init__( self._runtime_ids = tuple(group_info["agent_groups"].keys()) self._group_info = group_info - assert ( - "rollout" in stopping_conditions - ), f"Stopping conditions should contain `rollout`: {stopping_conditions}" - + # FIXME(ming): deprecated self.stopping_conditions = stopping_conditions @property diff --git a/malib/rollout/rolloutworker.py b/malib/rollout/rolloutworker.py index 4137d356..41086193 100644 --- a/malib/rollout/rolloutworker.py +++ b/malib/rollout/rolloutworker.py @@ -31,10 +31,12 @@ import ray +from pprint import pprint from ray.util import ActorPool from torch.utils import tensorboard from malib import settings +from malib.utils.logging import Logger from malib.utils.typing import AgentID from malib.utils.stopping_conditions import get_stopper from malib.utils.monitor import write_to_tensorboard diff --git a/malib/scenarios/sarl_scenario.py b/malib/scenarios/sarl_scenario.py index b50ffcfb..69c042e6 100644 --- a/malib/scenarios/sarl_scenario.py +++ b/malib/scenarios/sarl_scenario.py @@ -60,16 +60,17 @@ def __init__( stopping_conditions, ) self.num_policy_each_interface = 1 + self.num_worker = 1 self.resource_config = resource_config or {"training": None, "rollout": None} def create_global_stopper(self) -> StoppingCondition: - return get_stopper(self.stopping_conditions) + return get_stopper(self.stopping_conditions.get("global", {})) def execution_plan(scenario: SARLScenario, verbose: bool = True): # TODO(ming): simplize the initialization of training and rollout manager with a scenario instance as input learner_manager = LearnerManager( - stopping_conditions=scenario.stopping_conditions, + stopping_conditions=scenario.stopping_conditions["training"], algorithm=scenario.algorithm, env_desc=scenario.env_desc, agent_mapping_func=scenario.agent_mapping_func, @@ -89,8 +90,16 @@ def execution_plan(scenario: SARLScenario, verbose: bool = True): verbose=verbose, ) + startegy_spces = learner_manager.get_strategy_specs() + # uopdate rollout_config with inference client id + scenario.rollout_config.inference_entry_points = ( + inference_manager.inference_entry_points + ) + rollout_manager = RolloutWorkerManager( - stopping_conditions=scenario.stopping_conditions, + stopping_conditions=scenario.stopping_conditions["rollout"], + # FIXME(ming): change num_worker to parallel_task_num, + # or equivalent to the agent number. num_worker=scenario.num_worker, group_info=scenario.group_info, rollout_config=scenario.rollout_config, @@ -103,37 +112,35 @@ def execution_plan(scenario: SARLScenario, verbose: bool = True): league = League(learner_manager, rollout_manager, inference_manager) - # TODO(ming): further check is needed optimization_task = OptimizationTask( - stop_conditions=scenario.stopping_conditions["training"], - strategy_specs=None, + stopping_conditions=scenario.stopping_conditions["training"], active_agents=None, ) rollout_task = RolloutTask( - strategy_specs=None, + strategy_specs=startegy_spces, stopping_conditions=scenario.stopping_conditions["rollout"], - data_entrypoint_mapping=learner_manager.data_entrypoints, + data_entrypoints=learner_manager.data_entrypoints, ) - evaluation_task = RolloutTask( - strategy_specs=None, - ) + evaluation_task = RolloutTask() stopper = scenario.create_global_stopper() epoch_cnt = 0 while True: rollout_results = league.submit(rollout_task, wait=True) + print("Results of Rollout: {}".format(rollout_results)) training_results = league.submit(optimization_task, wait=True) - evaluation_results = league.submit(evaluation_task, wait=True) + print("Results of Training: {}".format(training_results)) + evaluation_results = {} + # league.submit(evaluation_task, wait=True) + # print("Results of evaluation: {}".format(evaluation_results)) epoch_cnt += 1 - if stopper.should_stop( - evaluation_results, training_results, rollout_results, epoch_cnt - ): + if stopper.should_stop(evaluation_results): break - if epoch_cnt % scenario.save_interval == 0: - league.save_checkpoint(global_step=epoch_cnt) + # if epoch_cnt % scenario.save_interval == 0: + # league.save_checkpoint(global_step=epoch_cnt) results = league.get_results() league.terminate() diff --git a/malib/settings.py b/malib/settings.py index 428a46b2..bf594d51 100644 --- a/malib/settings.py +++ b/malib/settings.py @@ -1,3 +1,3 @@ import logging -LOG_LEVEL = logging.DEBUG +LOG_LEVEL = logging.INFO diff --git a/malib/utils/data.py b/malib/utils/data.py index a223011a..8e5ff1d9 100644 --- a/malib/utils/data.py +++ b/malib/utils/data.py @@ -151,6 +151,11 @@ def to_torch( if dtype is not None: x = x.type(dtype) return x + elif isinstance(x, dict): + new_x = {} + for k, v in x.items(): + new_x[k] = to_torch(v, dtype, device) + return new_x elif isinstance(x, torch.Tensor): # second often case if dtype is not None: x = x.type(dtype) diff --git a/malib/utils/stopping_conditions.py b/malib/utils/stopping_conditions.py index 27584ede..14d9dd8b 100644 --- a/malib/utils/stopping_conditions.py +++ b/malib/utils/stopping_conditions.py @@ -104,6 +104,8 @@ def get_stopper(conditions: Dict[str, Any]): ) if "max_iteration" in conditions: stoppings.append(MaxIterationStopping(conditions["max_iteration"])) + if len(conditions) == 0: + stoppings.append(NoStoppingCondition()) if len(stoppings) == 0: raise NotImplementedError(f"unkonw stopping condition type: {conditions}") diff --git a/tests/agents/test_independent_agent.py b/tests/agents/test_independent_agent.py index 6d1b300a..ce15a477 100644 --- a/tests/agents/test_independent_agent.py +++ b/tests/agents/test_independent_agent.py @@ -37,7 +37,7 @@ def construct_dataset( - feature_handler=None, feature_handler_cls=None, feature_handler_kwargs=None + feature_handler=None, feature_handler_cls=None, feature_handler_kwargs={} ): return DynamicDataset( grpc_thread_num_workers=1, @@ -84,22 +84,39 @@ def construct_algorithm(module_path, model_config={}, trainer_config={}): ) +from typing import Dict + +import time + +from gym import spaces +from threading import Thread, Event from malib.mocker.mocker_utils import FakeFeatureHandler from malib.rollout.episode import Episode +from malib.backend.dataset_server.utils import send_data + + +def datasend_thread(entrypoint: str, batch: Dict[str, np.ndarray], event: Event): + while not event.is_set(): + send_data(data=batch, entrypoint=entrypoint) + event.wait(1) -@pytest.mark.parametrize("module_path", [ - 'malib.rl.random' -]) +@pytest.mark.parametrize("module_path", ["malib.rl.random"]) class TestIndependentAgent: def test_learner_with_outer_dataset(self, module_path): - obs_space = gym.spaces.Box(low=-1, high=1, shape=(1, 1), dtype=np.float32) - act_space = gym.spaces.Discrete(2) + obs_space = spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float32) + act_space = spaces.Discrete(2) + block_size = 300 + np_memory = { - Episode.CUR_OBS: np.zeros() + Episode.CUR_OBS: np.zeros( + (block_size,) + obs_space.shape, dtype=np.float32 + ), + Episode.ACTION: np.zeros((block_size,), dtype=np.int32), } governed_agents = ["default"] + # we construct a dataset outside the learner dataset = construct_dataset( feature_handler=FakeFeatureHandler( { @@ -107,23 +124,42 @@ def test_learner_with_outer_dataset(self, module_path): Episode.ACTION: act_space, }, np_memory, - block_size=100, + block_size=block_size, device="cpu", ) ) algorithm = construct_algorithm(module_path) learner = construct_learner( - algorithm, governed_agents, custom_config=None, dataset=dataset + obs_space, + act_space, + algorithm, + governed_agents, + custom_config=None, + dataset=dataset, ) - for _ in range(10): + # start a thread to generate data + dataset.start_server() + batch = dataset.feature_handler.generate_batch(10) + event = Event() + thread = Thread(target=datasend_thread, args=(dataset.entrypoint, batch, event)) + thread.start() + + # run 10 trails + for i in range(10): learner.step(prints=True) + time.sleep(0.2) + print("-------- {}/10 --------".format(i + 1)) + + event.set() + thread.join() + dataset.close() - def test_learner_with_outer_feature_handler(self): + def test_learner_with_outer_feature_handler(self, module_path): pass - def test_learner_with_feature_handler_gen(self): + def test_learner_with_feature_handler_gen(self, module_path): pass - def test_learner_with_dataset_gen(self): + def test_learner_with_dataset_gen(self, module_path): pass diff --git a/tests/rollout/test_pb_rollout_worker.py b/tests/rollout/test_pb_rollout_worker.py index c8a68352..63c0ad9c 100644 --- a/tests/rollout/test_pb_rollout_worker.py +++ b/tests/rollout/test_pb_rollout_worker.py @@ -173,7 +173,7 @@ def test_rollout_with_data_entrypoint(self, n_player: int): ), log_dir=log_dir, ) - + # create a batch of inference servers, serve for rollout workers (shared among them) infer_manager = InferenceManager( group_info=group_info, algorithm=algorithm, @@ -196,7 +196,7 @@ def test_rollout_with_data_entrypoint(self, n_player: int): ) for agent in agents } - + # create a single PB rollout worker, for task execution worker = PBRolloutWorker( env_desc=env_desc, agent_groups=group_info["agent_groups"],