From b340c85fd246a4ec85f3c33407f521b024d0de3c Mon Sep 17 00:00:00 2001 From: mahon94 Date: Fri, 22 Oct 2021 10:34:56 -0700 Subject: [PATCH 1/6] Init: actor.forward outputs separate deterministic actions --- .../mlagents/trainers/torch/action_model.py | 29 +++++++++++++++++-- .../mlagents/trainers/torch/distributions.py | 13 +++++++++ .../trainers/torch/model_serialization.py | 6 ++++ ml-agents/mlagents/trainers/torch/networks.py | 18 ++++++++++-- 4 files changed, 61 insertions(+), 5 deletions(-) diff --git a/ml-agents/mlagents/trainers/torch/action_model.py b/ml-agents/mlagents/trainers/torch/action_model.py index c5de586e4d..a27b1ffdef 100644 --- a/ml-agents/mlagents/trainers/torch/action_model.py +++ b/ml-agents/mlagents/trainers/torch/action_model.py @@ -9,6 +9,10 @@ from mlagents.trainers.torch.agent_action import AgentAction from mlagents.trainers.torch.action_log_probs import ActionLogProbs from mlagents_envs.base_env import ActionSpec +from mlagents_envs import logging_util + +logger = logging_util.get_logger(__name__) + EPSILON = 1e-7 # Small value to avoid divide by zero @@ -161,23 +165,42 @@ def get_action_out(self, inputs: torch.Tensor, masks: torch.Tensor) -> torch.Ten """ dists = self._get_dists(inputs, masks) continuous_out, discrete_out, action_out_deprecated = None, None, None + deter_continuous_out, deter_discrete_out = None, None # deterministic actions if self.action_spec.continuous_size > 0 and dists.continuous is not None: continuous_out = dists.continuous.exported_model_output() - action_out_deprecated = dists.continuous.exported_model_output() + action_out_deprecated = continuous_out + deter_continuous_out = dists.continuous.deterministic_sample() if self._clip_action_on_export: continuous_out = torch.clamp(continuous_out, -3, 3) / 3 - action_out_deprecated = torch.clamp(action_out_deprecated, -3, 3) / 3 + action_out_deprecated = continuous_out + deter_continuous_out = torch.clamp(deter_continuous_out, -3, 3) / 3 if self.action_spec.discrete_size > 0 and dists.discrete is not None: + logger.info( + f"dist: {[discrete_dist.probs for discrete_dist in dists.discrete]}" + ) # TODO: remove discrete_out_list = [ discrete_dist.exported_model_output() for discrete_dist in dists.discrete ] + logger.info(f"discretelist {discrete_out_list}") # TODO: remove discrete_out = torch.cat(discrete_out_list, dim=1) action_out_deprecated = torch.cat(discrete_out_list, dim=1) + deter_discrete_out_list = [ + discrete_dist.deterministic_sample() for discrete_dist in dists.discrete + ] + logger.info(f"deterlist {deter_discrete_out_list}") # TODO: remove + deter_discrete_out = torch.cat(deter_discrete_out_list, dim=1) + # deprecated action field does not support hybrid action if self.action_spec.continuous_size > 0 and self.action_spec.discrete_size > 0: action_out_deprecated = None - return continuous_out, discrete_out, action_out_deprecated + return ( + continuous_out, + discrete_out, + action_out_deprecated, + deter_continuous_out, + deter_discrete_out, + ) def forward( self, inputs: torch.Tensor, masks: torch.Tensor diff --git a/ml-agents/mlagents/trainers/torch/distributions.py b/ml-agents/mlagents/trainers/torch/distributions.py index 1f5960d10b..00aee3aaa0 100644 --- a/ml-agents/mlagents/trainers/torch/distributions.py +++ b/ml-agents/mlagents/trainers/torch/distributions.py @@ -16,6 +16,13 @@ def sample(self) -> torch.Tensor: """ pass + @abc.abstractmethod + def deterministic_sample(self) -> torch.Tensor: + """ + Return the most probable sample from this distribution. + """ + pass + @abc.abstractmethod def log_prob(self, value: torch.Tensor) -> torch.Tensor: """ @@ -59,6 +66,9 @@ def sample(self): sample = self.mean + torch.randn_like(self.mean) * self.std return sample + def deterministic_sample(self): + return self.mean + def log_prob(self, value): var = self.std ** 2 log_scale = torch.log(self.std + EPSILON) @@ -113,6 +123,9 @@ def __init__(self, logits): def sample(self): return torch.multinomial(self.probs, 1) + def deterministic_sample(self): + return torch.argmax(self.probs).reshape((1, 1)) + def pdf(self, value): # This function is equivalent to torch.diag(self.probs.T[value.flatten().long()]), # but torch.diag is not supported by ONNX export. diff --git a/ml-agents/mlagents/trainers/torch/model_serialization.py b/ml-agents/mlagents/trainers/torch/model_serialization.py index 0fa946280c..6924f3a5be 100644 --- a/ml-agents/mlagents/trainers/torch/model_serialization.py +++ b/ml-agents/mlagents/trainers/torch/model_serialization.py @@ -56,10 +56,13 @@ class TensorNames: recurrent_output = "recurrent_out" memory_size = "memory_size" version_number = "version_number" + continuous_action_output_shape = "continuous_action_output_shape" discrete_action_output_shape = "discrete_action_output_shape" continuous_action_output = "continuous_actions" discrete_action_output = "discrete_actions" + deterministic_continuous_action_output = "deterministic_continuous_actions" + deterministic_discrete_action_output = "deterministic_discrete_actions" # Deprecated TensorNames entries for backward compatibility is_continuous_control_deprecated = "is_continuous_control" @@ -122,6 +125,7 @@ def __init__(self, policy): self.output_names += [ TensorNames.continuous_action_output, TensorNames.continuous_action_output_shape, + TensorNames.deterministic_continuous_action_output, ] self.dynamic_axes.update( {TensorNames.continuous_action_output: {0: "batch"}} @@ -130,6 +134,7 @@ def __init__(self, policy): self.output_names += [ TensorNames.discrete_action_output, TensorNames.discrete_action_output_shape, + TensorNames.deterministic_discrete_action_output, ] self.dynamic_axes.update({TensorNames.discrete_action_output: {0: "batch"}}) @@ -164,5 +169,6 @@ def export_policy_model(self, output_filepath: str) -> None: input_names=self.input_names, output_names=self.output_names, dynamic_axes=self.dynamic_axes, + verbose=True, # TODO: remove ) logger.info(f"Exported {onnx_output_path}") diff --git a/ml-agents/mlagents/trainers/torch/networks.py b/ml-agents/mlagents/trainers/torch/networks.py index 4a2e1dafc6..a79f5238f0 100644 --- a/ml-agents/mlagents/trainers/torch/networks.py +++ b/ml-agents/mlagents/trainers/torch/networks.py @@ -571,6 +571,9 @@ def forward( class SimpleActor(nn.Module, Actor): MODEL_EXPORT_VERSION = 3 # Corresponds to ModelApiVersion.MLAgents2_0 + is_stochastic_action_sampling = ( + True + ) # TODO: this should be a user input both for training and inference def __init__( self, @@ -582,6 +585,7 @@ def __init__( ): super().__init__() self.action_spec = action_spec + # self.is_continuous_int_deprecated = is_stochastic_action_sampling # TODO: self.version_number = torch.nn.Parameter( torch.Tensor([self.MODEL_EXPORT_VERSION]), requires_grad=False ) @@ -675,12 +679,22 @@ def forward( cont_action_out, disc_action_out, action_out_deprecated, + deter_cont_action_out, + deter_disc_action_out, ) = self.action_model.get_action_out(encoding, masks) export_out = [self.version_number, self.memory_size_vector] if self.action_spec.continuous_size > 0: - export_out += [cont_action_out, self.continuous_act_size_vector] + export_out += [ + cont_action_out, + self.continuous_act_size_vector, + deter_cont_action_out, + ] if self.action_spec.discrete_size > 0: - export_out += [disc_action_out, self.discrete_act_size_vector] + export_out += [ + disc_action_out, + self.discrete_act_size_vector, + deter_disc_action_out, + ] if self.network_body.memory_size > 0: export_out += [memories_out] return tuple(export_out) From 07c11d8c52cadc49622a3e3cfa08dd916c0f2035 Mon Sep 17 00:00:00 2001 From: mahon94 Date: Thu, 28 Oct 2021 12:30:46 -0700 Subject: [PATCH 2/6] fix tensor shape for discrete actions --- ml-agents/mlagents/trainers/torch/distributions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml-agents/mlagents/trainers/torch/distributions.py b/ml-agents/mlagents/trainers/torch/distributions.py index 00aee3aaa0..c60426c998 100644 --- a/ml-agents/mlagents/trainers/torch/distributions.py +++ b/ml-agents/mlagents/trainers/torch/distributions.py @@ -124,7 +124,7 @@ def sample(self): return torch.multinomial(self.probs, 1) def deterministic_sample(self): - return torch.argmax(self.probs).reshape((1, 1)) + return torch.argmax(self.probs, dim=1, keepdim=True) def pdf(self, value): # This function is equivalent to torch.diag(self.probs.T[value.flatten().long()]), From b047e5d84ad237e03ab9ee6f642c4095fe19a4ce Mon Sep 17 00:00:00 2001 From: mahon94 Date: Thu, 28 Oct 2021 12:38:42 -0700 Subject: [PATCH 3/6] clean up --- ml-agents/mlagents/trainers/torch/action_model.py | 8 -------- ml-agents/mlagents/trainers/torch/model_serialization.py | 1 - ml-agents/mlagents/trainers/torch/networks.py | 4 ---- 3 files changed, 13 deletions(-) diff --git a/ml-agents/mlagents/trainers/torch/action_model.py b/ml-agents/mlagents/trainers/torch/action_model.py index a27b1ffdef..7eb1fae0ef 100644 --- a/ml-agents/mlagents/trainers/torch/action_model.py +++ b/ml-agents/mlagents/trainers/torch/action_model.py @@ -9,9 +9,6 @@ from mlagents.trainers.torch.agent_action import AgentAction from mlagents.trainers.torch.action_log_probs import ActionLogProbs from mlagents_envs.base_env import ActionSpec -from mlagents_envs import logging_util - -logger = logging_util.get_logger(__name__) EPSILON = 1e-7 # Small value to avoid divide by zero @@ -175,20 +172,15 @@ def get_action_out(self, inputs: torch.Tensor, masks: torch.Tensor) -> torch.Ten action_out_deprecated = continuous_out deter_continuous_out = torch.clamp(deter_continuous_out, -3, 3) / 3 if self.action_spec.discrete_size > 0 and dists.discrete is not None: - logger.info( - f"dist: {[discrete_dist.probs for discrete_dist in dists.discrete]}" - ) # TODO: remove discrete_out_list = [ discrete_dist.exported_model_output() for discrete_dist in dists.discrete ] - logger.info(f"discretelist {discrete_out_list}") # TODO: remove discrete_out = torch.cat(discrete_out_list, dim=1) action_out_deprecated = torch.cat(discrete_out_list, dim=1) deter_discrete_out_list = [ discrete_dist.deterministic_sample() for discrete_dist in dists.discrete ] - logger.info(f"deterlist {deter_discrete_out_list}") # TODO: remove deter_discrete_out = torch.cat(deter_discrete_out_list, dim=1) # deprecated action field does not support hybrid action diff --git a/ml-agents/mlagents/trainers/torch/model_serialization.py b/ml-agents/mlagents/trainers/torch/model_serialization.py index 6924f3a5be..f204b52445 100644 --- a/ml-agents/mlagents/trainers/torch/model_serialization.py +++ b/ml-agents/mlagents/trainers/torch/model_serialization.py @@ -169,6 +169,5 @@ def export_policy_model(self, output_filepath: str) -> None: input_names=self.input_names, output_names=self.output_names, dynamic_axes=self.dynamic_axes, - verbose=True, # TODO: remove ) logger.info(f"Exported {onnx_output_path}") diff --git a/ml-agents/mlagents/trainers/torch/networks.py b/ml-agents/mlagents/trainers/torch/networks.py index a79f5238f0..b041148c24 100644 --- a/ml-agents/mlagents/trainers/torch/networks.py +++ b/ml-agents/mlagents/trainers/torch/networks.py @@ -571,9 +571,6 @@ def forward( class SimpleActor(nn.Module, Actor): MODEL_EXPORT_VERSION = 3 # Corresponds to ModelApiVersion.MLAgents2_0 - is_stochastic_action_sampling = ( - True - ) # TODO: this should be a user input both for training and inference def __init__( self, @@ -585,7 +582,6 @@ def __init__( ): super().__init__() self.action_spec = action_spec - # self.is_continuous_int_deprecated = is_stochastic_action_sampling # TODO: self.version_number = torch.nn.Parameter( torch.Tensor([self.MODEL_EXPORT_VERSION]), requires_grad=False ) From 085e56e4836cf94bdde1e5e3befa1307c2097509 Mon Sep 17 00:00:00 2001 From: mahon94 Date: Thu, 28 Oct 2021 12:50:23 -0700 Subject: [PATCH 4/6] changelog --- com.unity.ml-agents/CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/com.unity.ml-agents/CHANGELOG.md b/com.unity.ml-agents/CHANGELOG.md index dab7cb200f..4af61bc356 100755 --- a/com.unity.ml-agents/CHANGELOG.md +++ b/com.unity.ml-agents/CHANGELOG.md @@ -28,6 +28,7 @@ and this project adheres to 1. env_params.max_lifetime_restarts (--max-lifetime-restarts) [default=10] 2. env_params.restarts_rate_limit_n (--restarts-rate-limit-n) [default=1] 3. env_params.restarts_rate_limit_period_s (--restarts-rate-limit-period-s) [default=60] +- Extra tensors are now serialized to support deterministic action selection in onnx. (#5597) ### Bug Fixes - Fixed the bug where curriculum learning would crash because of the incorrect run_options parsing. (#5586) From fb7849f6cc841b2e831eafc002310ef0b9c9acac Mon Sep 17 00:00:00 2001 From: mahon94 Date: Tue, 2 Nov 2021 16:19:08 -0700 Subject: [PATCH 5/6] Renaming --- .../mlagents/trainers/torch/action_model.py | 21 ++++++++++++------- ml-agents/mlagents/trainers/torch/networks.py | 8 +++---- 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/ml-agents/mlagents/trainers/torch/action_model.py b/ml-agents/mlagents/trainers/torch/action_model.py index 7eb1fae0ef..e8c6577e92 100644 --- a/ml-agents/mlagents/trainers/torch/action_model.py +++ b/ml-agents/mlagents/trainers/torch/action_model.py @@ -162,15 +162,20 @@ def get_action_out(self, inputs: torch.Tensor, masks: torch.Tensor) -> torch.Ten """ dists = self._get_dists(inputs, masks) continuous_out, discrete_out, action_out_deprecated = None, None, None - deter_continuous_out, deter_discrete_out = None, None # deterministic actions + deterministic_continuous_out, deterministic_discrete_out = ( + None, + None, + ) # deterministic actions if self.action_spec.continuous_size > 0 and dists.continuous is not None: continuous_out = dists.continuous.exported_model_output() action_out_deprecated = continuous_out - deter_continuous_out = dists.continuous.deterministic_sample() + deterministic_continuous_out = dists.continuous.deterministic_sample() if self._clip_action_on_export: continuous_out = torch.clamp(continuous_out, -3, 3) / 3 action_out_deprecated = continuous_out - deter_continuous_out = torch.clamp(deter_continuous_out, -3, 3) / 3 + deterministic_continuous_out = ( + torch.clamp(deterministic_continuous_out, -3, 3) / 3 + ) if self.action_spec.discrete_size > 0 and dists.discrete is not None: discrete_out_list = [ discrete_dist.exported_model_output() @@ -178,10 +183,12 @@ def get_action_out(self, inputs: torch.Tensor, masks: torch.Tensor) -> torch.Ten ] discrete_out = torch.cat(discrete_out_list, dim=1) action_out_deprecated = torch.cat(discrete_out_list, dim=1) - deter_discrete_out_list = [ + deterministic_discrete_out_list = [ discrete_dist.deterministic_sample() for discrete_dist in dists.discrete ] - deter_discrete_out = torch.cat(deter_discrete_out_list, dim=1) + deterministic_discrete_out = torch.cat( + deterministic_discrete_out_list, dim=1 + ) # deprecated action field does not support hybrid action if self.action_spec.continuous_size > 0 and self.action_spec.discrete_size > 0: @@ -190,8 +197,8 @@ def get_action_out(self, inputs: torch.Tensor, masks: torch.Tensor) -> torch.Ten continuous_out, discrete_out, action_out_deprecated, - deter_continuous_out, - deter_discrete_out, + deterministic_continuous_out, + deterministic_discrete_out, ) def forward( diff --git a/ml-agents/mlagents/trainers/torch/networks.py b/ml-agents/mlagents/trainers/torch/networks.py index b041148c24..7e94be9133 100644 --- a/ml-agents/mlagents/trainers/torch/networks.py +++ b/ml-agents/mlagents/trainers/torch/networks.py @@ -675,21 +675,21 @@ def forward( cont_action_out, disc_action_out, action_out_deprecated, - deter_cont_action_out, - deter_disc_action_out, + deterministic_cont_action_out, + deterministic_disc_action_out, ) = self.action_model.get_action_out(encoding, masks) export_out = [self.version_number, self.memory_size_vector] if self.action_spec.continuous_size > 0: export_out += [ cont_action_out, self.continuous_act_size_vector, - deter_cont_action_out, + deterministic_cont_action_out, ] if self.action_spec.discrete_size > 0: export_out += [ disc_action_out, self.discrete_act_size_vector, - deter_disc_action_out, + deterministic_disc_action_out, ] if self.network_body.memory_size > 0: export_out += [memories_out] From afa4d8375d6d165c41bf2c319a87fa135828f5ac Mon Sep 17 00:00:00 2001 From: mahon94 Date: Wed, 3 Nov 2021 09:22:24 -0700 Subject: [PATCH 6/6] Add more tests --- .../trainers/tests/torch/test_action_model.py | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/ml-agents/mlagents/trainers/tests/torch/test_action_model.py b/ml-agents/mlagents/trainers/tests/torch/test_action_model.py index 9722931446..facd612755 100644 --- a/ml-agents/mlagents/trainers/tests/torch/test_action_model.py +++ b/ml-agents/mlagents/trainers/tests/torch/test_action_model.py @@ -79,3 +79,36 @@ def test_get_probs_and_entropy(): for ent, val in zip(entropies[0].tolist(), [1.4189, 0.6191, 0.6191]): assert ent == pytest.approx(val, abs=0.01) + + +def test_get_onnx_deterministic_tensors(): + inp_size = 4 + act_size = 2 + action_model, masks = create_action_model(inp_size, act_size) + sample_inp = torch.ones((1, inp_size)) + out_tensors = action_model.get_action_out(sample_inp, masks=masks) + ( + continuous_out, + discrete_out, + action_out_deprecated, + deterministic_continuous_out, + deterministic_discrete_out, + ) = out_tensors + assert continuous_out.shape == (1, 2) + assert discrete_out.shape == (1, 2) + assert deterministic_discrete_out.shape == (1, 2) + assert deterministic_continuous_out.shape == (1, 2) + + # Second sampling from same distribution + out_tensors2 = action_model.get_action_out(sample_inp, masks=masks) + ( + continuous_out_2, + discrete_out_2, + action_out_2_deprecated, + deterministic_continuous_out_2, + deterministic_discrete_out_2, + ) = out_tensors2 + assert ~torch.all(torch.eq(continuous_out, continuous_out_2)) + assert torch.all( + torch.eq(deterministic_continuous_out, deterministic_continuous_out_2) + )