1 change: 1 addition & 0 deletions com.unity.ml-agents/CHANGELOG.md
@@ -31,6 +31,7 @@ and this project adheres to

- Added a new `--deterministic` CLI flag that makes the policy deterministically select its most probable actions. The same
  behavior can be achieved by adding `deterministic: true` under `network_settings` in the run options configuration.
- Extra tensors are now serialized to support deterministic action selection in ONNX. (#5597)
### Bug Fixes
- Fixed a bug where curriculum learning would crash because of incorrect run_options parsing. (#5586)

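For context on what the new flag changes, here is a minimal sketch that contrasts stochastic and deterministic action selection for a Gaussian policy head using plain `torch.distributions`. It is illustrative only and does not reuse any mlagents classes.

```python
import torch
from torch.distributions import Normal

# A stand-in Gaussian policy head: mean and std for two continuous actions.
mu = torch.zeros(1, 2)
sigma = torch.ones(1, 2)
dist = Normal(mu, sigma)

stochastic_action = dist.sample()   # changes from call to call
deterministic_action = dist.mean    # the most probable action, identical every call
```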
33 changes: 33 additions & 0 deletions ml-agents/mlagents/trainers/tests/torch/test_action_model.py
@@ -120,3 +120,36 @@ def test_get_probs_and_entropy():

for ent, val in zip(entropies[0].tolist(), [1.4189, 0.6191, 0.6191]):
assert ent == pytest.approx(val, abs=0.01)


def test_get_onnx_deterministic_tensors():
inp_size = 4
act_size = 2
action_model, masks = create_action_model(inp_size, act_size)
sample_inp = torch.ones((1, inp_size))
out_tensors = action_model.get_action_out(sample_inp, masks=masks)
(
continuous_out,
discrete_out,
action_out_deprecated,
deterministic_continuous_out,
deterministic_discrete_out,
) = out_tensors
assert continuous_out.shape == (1, 2)
assert discrete_out.shape == (1, 2)
assert deterministic_discrete_out.shape == (1, 2)
assert deterministic_continuous_out.shape == (1, 2)

# Second sampling from same distribution
out_tensors2 = action_model.get_action_out(sample_inp, masks=masks)
(
continuous_out_2,
discrete_out_2,
action_out_2_deprecated,
deterministic_continuous_out_2,
deterministic_discrete_out_2,
) = out_tensors2
assert ~torch.all(torch.eq(continuous_out, continuous_out_2))
assert torch.all(
torch.eq(deterministic_continuous_out, deterministic_continuous_out_2)
)
28 changes: 25 additions & 3 deletions ml-agents/mlagents/trainers/torch/action_model.py
@@ -10,6 +10,7 @@
from mlagents.trainers.torch.action_log_probs import ActionLogProbs
from mlagents_envs.base_env import ActionSpec


EPSILON = 1e-7 # Small value to avoid divide by zero


@@ -173,23 +174,44 @@ def get_action_out(self, inputs: torch.Tensor, masks: torch.Tensor) -> torch.Ten
"""
dists = self._get_dists(inputs, masks)
continuous_out, discrete_out, action_out_deprecated = None, None, None
deterministic_continuous_out, deterministic_discrete_out = (
None,
None,
) # deterministic actions
if self.action_spec.continuous_size > 0 and dists.continuous is not None:
continuous_out = dists.continuous.exported_model_output()
action_out_deprecated = dists.continuous.exported_model_output()
action_out_deprecated = continuous_out
deterministic_continuous_out = dists.continuous.deterministic_sample()
if self._clip_action_on_export:
continuous_out = torch.clamp(continuous_out, -3, 3) / 3
action_out_deprecated = torch.clamp(action_out_deprecated, -3, 3) / 3
action_out_deprecated = continuous_out
deterministic_continuous_out = (
torch.clamp(deterministic_continuous_out, -3, 3) / 3
)
if self.action_spec.discrete_size > 0 and dists.discrete is not None:
discrete_out_list = [
discrete_dist.exported_model_output()
for discrete_dist in dists.discrete
]
discrete_out = torch.cat(discrete_out_list, dim=1)
action_out_deprecated = torch.cat(discrete_out_list, dim=1)
deterministic_discrete_out_list = [
discrete_dist.deterministic_sample() for discrete_dist in dists.discrete
]
deterministic_discrete_out = torch.cat(
deterministic_discrete_out_list, dim=1
)

# deprecated action field does not support hybrid action
if self.action_spec.continuous_size > 0 and self.action_spec.discrete_size > 0:
action_out_deprecated = None
return continuous_out, discrete_out, action_out_deprecated
return (
continuous_out,
discrete_out,
action_out_deprecated,
deterministic_continuous_out,
deterministic_discrete_out,
)

def forward(
self, inputs: torch.Tensor, masks: torch.Tensor
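`get_action_out` now relies on each distribution instance exposing a `deterministic_sample()` method alongside `exported_model_output()`. Below is a hedged sketch of that contract, assuming the Gaussian head returns its mean and the categorical head returns the argmax branch; the class names are placeholders, not the actual mlagents distribution classes.

```python
import torch


class FakeGaussianDist:
    """Placeholder continuous head: the deterministic sample is the mean."""

    def __init__(self, mean: torch.Tensor, std: torch.Tensor):
        self.mean, self.std = mean, std

    def exported_model_output(self) -> torch.Tensor:
        return self.mean + self.std * torch.randn_like(self.mean)  # stochastic sample

    def deterministic_sample(self) -> torch.Tensor:
        return self.mean  # mode of a Gaussian is its mean


class FakeCategoricalDist:
    """Placeholder discrete head: the deterministic sample is the argmax branch."""

    def __init__(self, logits: torch.Tensor):
        self.logits = logits

    def exported_model_output(self) -> torch.Tensor:
        return torch.multinomial(torch.softmax(self.logits, dim=1), 1)

    def deterministic_sample(self) -> torch.Tensor:
        return torch.argmax(self.logits, dim=1, keepdim=True)


# Deterministic samples are stable across calls, unlike the stochastic output.
head = FakeGaussianDist(torch.zeros(1, 2), torch.ones(1, 2))
assert torch.equal(head.deterministic_sample(), head.deterministic_sample())
```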
5 changes: 5 additions & 0 deletions ml-agents/mlagents/trainers/torch/model_serialization.py
@@ -56,10 +56,13 @@ class TensorNames:
recurrent_output = "recurrent_out"
memory_size = "memory_size"
version_number = "version_number"

continuous_action_output_shape = "continuous_action_output_shape"
discrete_action_output_shape = "discrete_action_output_shape"
continuous_action_output = "continuous_actions"
discrete_action_output = "discrete_actions"
deterministic_continuous_action_output = "deterministic_continuous_actions"
deterministic_discrete_action_output = "deterministic_discrete_actions"

# Deprecated TensorNames entries for backward compatibility
is_continuous_control_deprecated = "is_continuous_control"
@@ -122,6 +125,7 @@ def __init__(self, policy):
self.output_names += [
TensorNames.continuous_action_output,
TensorNames.continuous_action_output_shape,
TensorNames.deterministic_continuous_action_output,
]
self.dynamic_axes.update(
{TensorNames.continuous_action_output: {0: "batch"}}
@@ -130,6 +134,7 @@ def __init__(self, policy):
self.output_names += [
TensorNames.discrete_action_output,
TensorNames.discrete_action_output_shape,
TensorNames.deterministic_discrete_action_output,
]
self.dynamic_axes.update({TensorNames.discrete_action_output: {0: "batch"}})

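Once the model is exported, the extra outputs can be fetched by name. A hedged sketch using `onnxruntime` follows; the model path, observation shape, and the input name `obs_0` are assumptions, while the output name matches `TensorNames.deterministic_continuous_action_output` above.

```python
import numpy as np
import onnxruntime as ort  # assumes onnxruntime is installed

# Hypothetical exported model and observation shape.
session = ort.InferenceSession("results/run/MyBehavior.onnx")
obs = np.zeros((1, 8), dtype=np.float32)

# "obs_0" is an assumed input name; the output name comes from TensorNames.
(det_actions,) = session.run(
    ["deterministic_continuous_actions"],
    {"obs_0": obs},
)
print(det_actions.shape)
```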
14 changes: 12 additions & 2 deletions ml-agents/mlagents/trainers/torch/networks.py
@@ -676,12 +676,22 @@ def forward(
cont_action_out,
disc_action_out,
action_out_deprecated,
deterministic_cont_action_out,
deterministic_disc_action_out,
) = self.action_model.get_action_out(encoding, masks)
export_out = [self.version_number, self.memory_size_vector]
if self.action_spec.continuous_size > 0:
export_out += [cont_action_out, self.continuous_act_size_vector]
export_out += [
cont_action_out,
self.continuous_act_size_vector,
deterministic_cont_action_out,
]
if self.action_spec.discrete_size > 0:
export_out += [disc_action_out, self.discrete_act_size_vector]
export_out += [
disc_action_out,
self.discrete_act_size_vector,
deterministic_disc_action_out,
]
if self.network_body.memory_size > 0:
export_out += [memories_out]
return tuple(export_out)
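The ONNX exporter (`torch.onnx.export`) pairs the tensors returned from `forward` positionally with the names registered in `model_serialization.py`, which is why the deterministic tensors are appended at the same relative position in both lists. A small runnable check of that pairing for a continuous-only policy is sketched below; the tensor values are dummies and only the ordering matters.

```python
import torch

# Dummy stand-ins for the tensors built in forward() (values are irrelevant here).
version_number = torch.tensor([3.0])
memory_size_vector = torch.tensor([0.0])
cont_action_out = torch.zeros(1, 2)
continuous_act_size_vector = torch.tensor([2], dtype=torch.int64)
deterministic_cont_action_out = torch.zeros(1, 2)

# Names as registered for export; the order must mirror export_out.
output_names = [
    "version_number",
    "memory_size",
    "continuous_actions",
    "continuous_action_output_shape",
    "deterministic_continuous_actions",
]
export_out = (
    version_number,
    memory_size_vector,
    cont_action_out,
    continuous_act_size_vector,
    deterministic_cont_action_out,
)
assert len(export_out) == len(output_names)
```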