diff --git a/com.unity.ml-agents/CHANGELOG.md b/com.unity.ml-agents/CHANGELOG.md
index 613d652929..9c9279f530 100755
--- a/com.unity.ml-agents/CHANGELOG.md
+++ b/com.unity.ml-agents/CHANGELOG.md
@@ -29,9 +29,12 @@ and this project adheres to
2. env_params.restarts_rate_limit_n (--restarts-rate-limit-n) [default=1]
3. env_params.restarts_rate_limit_period_s (--restarts-rate-limit-period-s) [default=60]
-- Added a new `--deterministic` cli flag to deterministically select the most probable actions in policy. The same thing can
-be achieved by adding `deterministic: true` under `network_settings` of the run options configuration.
-- Extra tensors are now serialized to support deterministic action selection in onnx. (#5597)
+
+- Deterministic action selection is now supported during training and inference.
+ - Added a new `--deterministic` CLI flag to deterministically select the most probable actions from the policy. The same thing can
+ be achieved by adding `deterministic: true` under `network_settings` of the run options configuration. (#5619)
+ - Extra tensors are now serialized to support deterministic action selection in ONNX. (#5593)
+ - Added support for inference with deterministic action selection in the editor. (#5599)
### Bug Fixes
- Fixed the bug where curriculum learning would crash because of the incorrect run_options parsing. (#5586)
diff --git a/com.unity.ml-agents/Editor/BehaviorParametersEditor.cs b/com.unity.ml-agents/Editor/BehaviorParametersEditor.cs
index c5e8ddc802..a95b2846f3 100644
--- a/com.unity.ml-agents/Editor/BehaviorParametersEditor.cs
+++ b/com.unity.ml-agents/Editor/BehaviorParametersEditor.cs
@@ -25,6 +25,7 @@ internal class BehaviorParametersEditor : UnityEditor.Editor
const string k_BrainParametersName = "m_BrainParameters";
const string k_ModelName = "m_Model";
const string k_InferenceDeviceName = "m_InferenceDevice";
+ const string k_DeterministicInference = "m_DeterministicInference";
const string k_BehaviorTypeName = "m_BehaviorType";
const string k_TeamIdName = "TeamId";
const string k_UseChildSensorsName = "m_UseChildSensors";
@@ -68,6 +69,7 @@ public override void OnInspectorGUI()
EditorGUILayout.PropertyField(so.FindProperty(k_ModelName), true);
EditorGUI.indentLevel++;
EditorGUILayout.PropertyField(so.FindProperty(k_InferenceDeviceName), true);
+ EditorGUILayout.PropertyField(so.FindProperty(k_DeterministicInference), true);
EditorGUI.indentLevel--;
}
needPolicyUpdate = needPolicyUpdate || EditorGUI.EndChangeCheck();
@@ -156,7 +158,7 @@ void DisplayFailedModelChecks()
{
var failedChecks = Inference.BarracudaModelParamLoader.CheckModel(
barracudaModel, brainParameters, sensors, actuatorComponents,
- observableAttributeSensorTotalSize, behaviorParameters.BehaviorType
+ observableAttributeSensorTotalSize, behaviorParameters.BehaviorType, behaviorParameters.DeterministicInference
);
foreach (var check in failedChecks)
{
diff --git a/com.unity.ml-agents/Runtime/Academy.cs b/com.unity.ml-agents/Runtime/Academy.cs
index 85cab21b26..409ccb1dca 100644
--- a/com.unity.ml-agents/Runtime/Academy.cs
+++ b/com.unity.ml-agents/Runtime/Academy.cs
@@ -616,14 +616,16 @@ void EnvironmentReset()
///
/// The inference device (CPU or GPU) the ModelRunner will use.
///
+ /// Inference only: set to true if the action selection from the model should be
+ /// deterministic.
/// The ModelRunner compatible with the input settings.
internal ModelRunner GetOrCreateModelRunner(
- NNModel model, ActionSpec actionSpec, InferenceDevice inferenceDevice)
+ NNModel model, ActionSpec actionSpec, InferenceDevice inferenceDevice, bool deterministicInference = false)
{
var modelRunner = m_ModelRunners.Find(x => x.HasModel(model, inferenceDevice));
if (modelRunner == null)
{
- modelRunner = new ModelRunner(model, actionSpec, inferenceDevice, m_InferenceSeed);
+ modelRunner = new ModelRunner(model, actionSpec, inferenceDevice, m_InferenceSeed, deterministicInference);
m_ModelRunners.Add(modelRunner);
m_InferenceSeed++;
}
diff --git a/com.unity.ml-agents/Runtime/Inference/BarracudaModelExtensions.cs b/com.unity.ml-agents/Runtime/Inference/BarracudaModelExtensions.cs
index 6fd3872535..5e7338c057 100644
--- a/com.unity.ml-agents/Runtime/Inference/BarracudaModelExtensions.cs
+++ b/com.unity.ml-agents/Runtime/Inference/BarracudaModelExtensions.cs
@@ -112,8 +112,10 @@ public static int GetNumVisualInputs(this Model model)
///
/// The Barracuda engine model for loading static parameters.
///
+ /// Inference only: set to true if the action selection from model should be
+ /// deterministic.
/// Array of the output tensor names of the model
- public static string[] GetOutputNames(this Model model)
+ public static string[] GetOutputNames(this Model model, bool deterministicInference = false)
{
var names = new List();
@@ -122,13 +124,13 @@ public static string[] GetOutputNames(this Model model)
return names.ToArray();
}
- if (model.HasContinuousOutputs())
+ if (model.HasContinuousOutputs(deterministicInference))
{
- names.Add(model.ContinuousOutputName());
+ names.Add(model.ContinuousOutputName(deterministicInference));
}
- if (model.HasDiscreteOutputs())
+ if (model.HasDiscreteOutputs(deterministicInference))
{
- names.Add(model.DiscreteOutputName());
+ names.Add(model.DiscreteOutputName(deterministicInference));
}
var modelVersion = model.GetVersion();
@@ -149,8 +151,10 @@ public static string[] GetOutputNames(this Model model)
///
/// The Barracuda engine model for loading static parameters.
///
+ /// Inference only: set to true if the action selection from model should be
+ /// deterministic.
/// True if the model has continuous action outputs.
- public static bool HasContinuousOutputs(this Model model)
+ public static bool HasContinuousOutputs(this Model model, bool deterministicInference = false)
{
if (model == null)
return false;
@@ -160,8 +164,13 @@ public static bool HasContinuousOutputs(this Model model)
}
else
{
- return model.outputs.Contains(TensorNames.ContinuousActionOutput) &&
- (int)model.GetTensorByName(TensorNames.ContinuousActionOutputShape)[0] > 0;
+ bool hasStochasticOutput = !deterministicInference &&
+ model.outputs.Contains(TensorNames.ContinuousActionOutput);
+ bool hasDeterministicOutput = deterministicInference &&
+ model.outputs.Contains(TensorNames.DeterministicContinuousActionOutput);
+
+ return (hasStochasticOutput || hasDeterministicOutput) &&
+ (int)model.GetTensorByName(TensorNames.ContinuousActionOutputShape)[0] > 0;
}
}
@@ -194,8 +203,10 @@ public static int ContinuousOutputSize(this Model model)
///
/// The Barracuda engine model for loading static parameters.
///
+ /// Inference only: set to true if the action selection from model should be
+ /// deterministic.
/// Tensor name of continuous action output.
- public static string ContinuousOutputName(this Model model)
+ public static string ContinuousOutputName(this Model model, bool deterministicInference = false)
{
if (model == null)
return null;
@@ -205,7 +216,7 @@ public static string ContinuousOutputName(this Model model)
}
else
{
- return TensorNames.ContinuousActionOutput;
+ return deterministicInference ? TensorNames.DeterministicContinuousActionOutput : TensorNames.ContinuousActionOutput;
}
}
@@ -215,8 +226,10 @@ public static string ContinuousOutputName(this Model model)
///
/// The Barracuda engine model for loading static parameters.
///
+ /// Inference only: set to true if the action selection from model should be
+ /// deterministic.
/// True if the model has discrete action outputs.
- public static bool HasDiscreteOutputs(this Model model)
+ public static bool HasDiscreteOutputs(this Model model, bool deterministicInference = false)
{
if (model == null)
return false;
@@ -226,7 +239,12 @@ public static bool HasDiscreteOutputs(this Model model)
}
else
{
- return model.outputs.Contains(TensorNames.DiscreteActionOutput) && model.DiscreteOutputSize() > 0;
+ bool hasStochasticOutput = !deterministicInference &&
+ model.outputs.Contains(TensorNames.DiscreteActionOutput);
+ bool hasDeterministicOutput = deterministicInference &&
+ model.outputs.Contains(TensorNames.DeterministicDiscreteActionOutput);
+ return (hasStochasticOutput || hasDeterministicOutput) &&
+ model.DiscreteOutputSize() > 0;
}
}
@@ -279,8 +297,10 @@ public static int DiscreteOutputSize(this Model model)
///
/// The Barracuda engine model for loading static parameters.
///
+ /// Inference only: set to true if the action selection from model should be
+ /// deterministic.
/// Tensor name of discrete action output.
- public static string DiscreteOutputName(this Model model)
+ public static string DiscreteOutputName(this Model model, bool deterministicInference = false)
{
if (model == null)
return null;
@@ -290,7 +310,7 @@ public static string DiscreteOutputName(this Model model)
}
else
{
- return TensorNames.DiscreteActionOutput;
+ return deterministicInference ? TensorNames.DeterministicDiscreteActionOutput : TensorNames.DiscreteActionOutput;
}
}
@@ -316,9 +336,11 @@ public static bool SupportsContinuousAndDiscrete(this Model model)
/// The Barracuda engine model for loading static parameters.
///
/// Output list of failure messages
- ///
+ /// Inference only: set to true if the action selection from model should be
+ /// deterministic.
/// True if the model contains all the expected tensors.
- public static bool CheckExpectedTensors(this Model model, List failedModelChecks)
+ /// TODO: add checks for deterministic actions
+ public static bool CheckExpectedTensors(this Model model, List failedModelChecks, bool deterministicInference = false)
{
// Check the presence of model version
var modelApiVersionTensor = model.GetTensorByName(TensorNames.VersionNumber);
@@ -343,7 +365,9 @@ public static bool CheckExpectedTensors(this Model model, List fail
// Check the presence of action output tensor
if (!model.outputs.Contains(TensorNames.ActionOutputDeprecated) &&
!model.outputs.Contains(TensorNames.ContinuousActionOutput) &&
- !model.outputs.Contains(TensorNames.DiscreteActionOutput))
+ !model.outputs.Contains(TensorNames.DiscreteActionOutput) &&
+ !model.outputs.Contains(TensorNames.DeterministicContinuousActionOutput) &&
+ !model.outputs.Contains(TensorNames.DeterministicDiscreteActionOutput))
{
failedModelChecks.Add(
FailedCheck.Warning("The model does not contain any Action Output Node.")
@@ -373,22 +397,51 @@ public static bool CheckExpectedTensors(this Model model, List fail
}
else
{
- if (model.outputs.Contains(TensorNames.ContinuousActionOutput) &&
- model.GetTensorByName(TensorNames.ContinuousActionOutputShape) == null)
+ if (model.outputs.Contains(TensorNames.ContinuousActionOutput))
{
- failedModelChecks.Add(
- FailedCheck.Warning("The model uses continuous action but does not contain Continuous Action Output Shape Node.")
+ if (model.GetTensorByName(TensorNames.ContinuousActionOutputShape) == null)
+ {
+ failedModelChecks.Add(
+ FailedCheck.Warning("The model uses continuous action but does not contain Continuous Action Output Shape Node.")
);
- return false;
+ return false;
+ }
+
+ else if (!model.HasContinuousOutputs(deterministicInference))
+ {
+ var actionType = deterministicInference ? "deterministic" : "stochastic";
+ var actionName = deterministicInference ? "Deterministic" : "";
+ failedModelChecks.Add(
+ FailedCheck.Warning($"The model uses {actionType} inference but does not contain {actionName} Continuous Action Output Tensor. Uncheck `Deterministic inference` flag..")
+ );
+ return false;
+ }
}
- if (model.outputs.Contains(TensorNames.DiscreteActionOutput) &&
- model.GetTensorByName(TensorNames.DiscreteActionOutputShape) == null)
+
+ if (model.outputs.Contains(TensorNames.DiscreteActionOutput))
{
- failedModelChecks.Add(
- FailedCheck.Warning("The model uses discrete action but does not contain Discrete Action Output Shape Node.")
+ if (model.GetTensorByName(TensorNames.DiscreteActionOutputShape) == null)
+ {
+ failedModelChecks.Add(
+ FailedCheck.Warning("The model uses discrete action but does not contain Discrete Action Output Shape Node.")
);
- return false;
+ return false;
+ }
+ else if (!model.HasDiscreteOutputs(deterministicInference))
+ {
+ var actionType = deterministicInference ? "deterministic" : "stochastic";
+ var actionName = deterministicInference ? "Deterministic" : "";
+ failedModelChecks.Add(
+ FailedCheck.Warning($"The model uses {actionType} inference but does not contain {actionName} Discrete Action Output Tensor. Uncheck `Deterministic inference` flag.")
+ );
+ return false;
+ }
+
}
+
+
+
+
}
return true;
}
diff --git a/com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs b/com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs
index 21917b303d..6fe10566fd 100644
--- a/com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs
+++ b/com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs
@@ -122,6 +122,8 @@ public static FailedCheck CheckModelVersion(Model model)
/// Attached actuator components
/// Sum of the sizes of all ObservableAttributes.
/// BehaviorType or the Agent to check.
+ /// Inference only: set to true if the action selection from model should be
+ /// deterministic.
/// A IEnumerable of the checks that failed
public static IEnumerable CheckModel(
Model model,
@@ -129,7 +131,8 @@ public static IEnumerable CheckModel(
ISensor[] sensors,
ActuatorComponent[] actuatorComponents,
int observableAttributeTotalSize = 0,
- BehaviorType behaviorType = BehaviorType.Default
+ BehaviorType behaviorType = BehaviorType.Default,
+ bool deterministicInference = false
)
{
List failedModelChecks = new List();
@@ -148,7 +151,7 @@ public static IEnumerable CheckModel(
return failedModelChecks;
}
- var hasExpectedTensors = model.CheckExpectedTensors(failedModelChecks);
+ var hasExpectedTensors = model.CheckExpectedTensors(failedModelChecks, deterministicInference);
if (!hasExpectedTensors)
{
return failedModelChecks;
@@ -181,7 +184,7 @@ public static IEnumerable CheckModel(
else if (modelApiVersion == (int)ModelApiVersion.MLAgents2_0)
{
failedModelChecks.AddRange(
- CheckInputTensorPresence(model, brainParameters, memorySize, sensors)
+ CheckInputTensorPresence(model, brainParameters, memorySize, sensors, deterministicInference)
);
failedModelChecks.AddRange(
CheckInputTensorShape(model, brainParameters, sensors, observableAttributeTotalSize)
@@ -195,7 +198,7 @@ public static IEnumerable CheckModel(
);
failedModelChecks.AddRange(
- CheckOutputTensorPresence(model, memorySize)
+ CheckOutputTensorPresence(model, memorySize, deterministicInference)
);
return failedModelChecks;
}
@@ -318,6 +321,8 @@ ISensor[] sensors
/// The memory size that the model is expecting.
///
/// Array of attached sensor components
+ /// Inference only: set to true if the action selection from the model should be
+ /// deterministic.
///
/// A IEnumerable of the checks that failed
///
@@ -325,7 +330,8 @@ static IEnumerable CheckInputTensorPresence(
Model model,
BrainParameters brainParameters,
int memory,
- ISensor[] sensors
+ ISensor[] sensors,
+ bool deterministicInference = false
)
{
var failedModelChecks = new List();
@@ -356,7 +362,7 @@ ISensor[] sensors
}
// If the model uses discrete control but does not have an input for action masks
- if (model.HasDiscreteOutputs())
+ if (model.HasDiscreteOutputs(deterministicInference))
{
if (!tensorsNames.Contains(TensorNames.ActionMaskPlaceholder))
{
@@ -376,17 +382,19 @@ ISensor[] sensors
/// The Barracuda engine model for loading static parameters
///
/// The memory size that the model is expecting/
+ /// Inference only: set to true if the action selection from model should be
+ /// deterministic.
///
/// A IEnumerable of the checks that failed
///
- static IEnumerable CheckOutputTensorPresence(Model model, int memory)
+ static IEnumerable CheckOutputTensorPresence(Model model, int memory, bool deterministicInference = false)
{
var failedModelChecks = new List();
// If there is no Recurrent Output but the model is Recurrent.
if (memory > 0)
{
- var allOutputs = model.GetOutputNames().ToList();
+ var allOutputs = model.GetOutputNames(deterministicInference).ToList();
if (!allOutputs.Any(x => x == TensorNames.RecurrentOutput))
{
failedModelChecks.Add(
diff --git a/com.unity.ml-agents/Runtime/Inference/ModelRunner.cs b/com.unity.ml-agents/Runtime/Inference/ModelRunner.cs
index 422e0f9744..f59b54ee23 100644
--- a/com.unity.ml-agents/Runtime/Inference/ModelRunner.cs
+++ b/com.unity.ml-agents/Runtime/Inference/ModelRunner.cs
@@ -28,6 +28,7 @@ internal class ModelRunner
InferenceDevice m_InferenceDevice;
IWorker m_Engine;
bool m_Verbose = false;
+ bool m_DeterministicInference;
string[] m_OutputNames;
IReadOnlyList m_InferenceInputs;
List m_InferenceOutputs;
@@ -48,18 +49,22 @@ internal class ModelRunner
/// option for most of ML Agents models.
/// The seed that will be used to initialize the RandomNormal
/// and Multinomial objects used when running inference.
+ /// Inference only: set to true if the action selection from model should be
+ /// deterministic.
/// Throws an error when the model is null
///
public ModelRunner(
NNModel model,
ActionSpec actionSpec,
InferenceDevice inferenceDevice,
- int seed = 0)
+ int seed = 0,
+ bool deterministicInference = false)
{
Model barracudaModel;
m_Model = model;
m_ModelName = model.name;
m_InferenceDevice = inferenceDevice;
+ m_DeterministicInference = deterministicInference;
m_TensorAllocator = new TensorCachingAllocator();
if (model != null)
{
@@ -108,11 +113,12 @@ public ModelRunner(
}
m_InferenceInputs = barracudaModel.GetInputTensors();
- m_OutputNames = barracudaModel.GetOutputNames();
+ m_OutputNames = barracudaModel.GetOutputNames(m_DeterministicInference);
+
m_TensorGenerator = new TensorGenerator(
- seed, m_TensorAllocator, m_Memories, barracudaModel);
+ seed, m_TensorAllocator, m_Memories, barracudaModel, m_DeterministicInference);
m_TensorApplier = new TensorApplier(
- actionSpec, seed, m_TensorAllocator, m_Memories, barracudaModel);
+ actionSpec, seed, m_TensorAllocator, m_Memories, barracudaModel, m_DeterministicInference);
m_InputsByName = new Dictionary();
m_InferenceOutputs = new List();
}
diff --git a/com.unity.ml-agents/Runtime/Inference/TensorApplier.cs b/com.unity.ml-agents/Runtime/Inference/TensorApplier.cs
index d311010aae..a03b3d927e 100644
--- a/com.unity.ml-agents/Runtime/Inference/TensorApplier.cs
+++ b/com.unity.ml-agents/Runtime/Inference/TensorApplier.cs
@@ -44,12 +44,15 @@ public interface IApplier
/// Tensor allocator
/// Dictionary of AgentInfo.id to memory used to pass to the inference model.
///
+ /// Inference only: set to true if the action selection from model should be
+ /// deterministic.
public TensorApplier(
ActionSpec actionSpec,
int seed,
ITensorAllocator allocator,
Dictionary> memories,
- object barracudaModel = null)
+ object barracudaModel = null,
+ bool deterministicInference = false)
{
// If model is null, no inference to run and exception is thrown before reaching here.
if (barracudaModel == null)
@@ -64,13 +67,13 @@ public TensorApplier(
}
if (actionSpec.NumContinuousActions > 0)
{
- var tensorName = model.ContinuousOutputName();
+ var tensorName = model.ContinuousOutputName(deterministicInference);
m_Dict[tensorName] = new ContinuousActionOutputApplier(actionSpec);
}
var modelVersion = model.GetVersion();
if (actionSpec.NumDiscreteActions > 0)
{
- var tensorName = model.DiscreteOutputName();
+ var tensorName = model.DiscreteOutputName(deterministicInference);
if (modelVersion == (int)BarracudaModelParamLoader.ModelApiVersion.MLAgents1_0)
{
m_Dict[tensorName] = new LegacyDiscreteActionOutputApplier(actionSpec, seed, allocator);
diff --git a/com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs b/com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs
index feb521ebd8..39bed85792 100644
--- a/com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs
+++ b/com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs
@@ -44,11 +44,14 @@ void Generate(
/// Tensor allocator.
/// Dictionary of AgentInfo.id to memory for use in the inference model.
///
+ /// Inference only: set to true if the action selection from model should be
+ /// deterministic.
public TensorGenerator(
int seed,
ITensorAllocator allocator,
Dictionary> memories,
- object barracudaModel = null)
+ object barracudaModel = null,
+ bool deterministicInference = false)
{
// If model is null, no inference to run and exception is thrown before reaching here.
if (barracudaModel == null)
@@ -76,13 +79,13 @@ public TensorGenerator(
// Generators for Outputs
- if (model.HasContinuousOutputs())
+ if (model.HasContinuousOutputs(deterministicInference))
{
- m_Dict[model.ContinuousOutputName()] = new BiDimensionalOutputGenerator(allocator);
+ m_Dict[model.ContinuousOutputName(deterministicInference)] = new BiDimensionalOutputGenerator(allocator);
}
- if (model.HasDiscreteOutputs())
+ if (model.HasDiscreteOutputs(deterministicInference))
{
- m_Dict[model.DiscreteOutputName()] = new BiDimensionalOutputGenerator(allocator);
+ m_Dict[model.DiscreteOutputName(deterministicInference)] = new BiDimensionalOutputGenerator(allocator);
}
m_Dict[TensorNames.RecurrentOutput] = new BiDimensionalOutputGenerator(allocator);
m_Dict[TensorNames.ValueEstimateOutput] = new BiDimensionalOutputGenerator(allocator);
diff --git a/com.unity.ml-agents/Runtime/Inference/TensorNames.cs b/com.unity.ml-agents/Runtime/Inference/TensorNames.cs
index dc20e1f8f3..48ae04b5f6 100644
--- a/com.unity.ml-agents/Runtime/Inference/TensorNames.cs
+++ b/com.unity.ml-agents/Runtime/Inference/TensorNames.cs
@@ -23,6 +23,8 @@ internal static class TensorNames
public const string DiscreteActionOutputShape = "discrete_action_output_shape";
public const string ContinuousActionOutput = "continuous_actions";
public const string DiscreteActionOutput = "discrete_actions";
+ public const string DeterministicContinuousActionOutput = "deterministic_continuous_actions";
+ public const string DeterministicDiscreteActionOutput = "deterministic_discrete_actions";
// Deprecated TensorNames entries for backward compatibility
public const string IsContinuousControlDeprecated = "is_continuous_control";
diff --git a/com.unity.ml-agents/Runtime/Policies/BarracudaPolicy.cs b/com.unity.ml-agents/Runtime/Policies/BarracudaPolicy.cs
index 5e76084b20..96a15b50d8 100644
--- a/com.unity.ml-agents/Runtime/Policies/BarracudaPolicy.cs
+++ b/com.unity.ml-agents/Runtime/Policies/BarracudaPolicy.cs
@@ -45,6 +45,11 @@ internal class BarracudaPolicy : IPolicy
ActionBuffers m_LastActionBuffer;
int m_AgentId;
+ ///
+ /// Inference only: set to true if the action selection from model should be
+ /// deterministic.
+ ///
+ bool m_DeterministicInference;
///
/// Sensor shapes for the associated Agents. All Agents must have the same shapes for their Sensors.
@@ -73,19 +78,23 @@ internal class BarracudaPolicy : IPolicy
/// The Neural Network to use.
/// Which device Barracuda will run on.
/// The name of the behavior.
+ /// Inference only: set to true if the action selection from model should be
+ /// deterministic.
public BarracudaPolicy(
ActionSpec actionSpec,
IList actuators,
NNModel model,
InferenceDevice inferenceDevice,
- string behaviorName
+ string behaviorName,
+ bool deterministicInference = false
)
{
- var modelRunner = Academy.Instance.GetOrCreateModelRunner(model, actionSpec, inferenceDevice);
+ var modelRunner = Academy.Instance.GetOrCreateModelRunner(model, actionSpec, inferenceDevice, deterministicInference);
m_ModelRunner = modelRunner;
m_BehaviorName = behaviorName;
m_ActionSpec = actionSpec;
m_Actuators = actuators;
+ m_DeterministicInference = deterministicInference;
}
///
diff --git a/com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs b/com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs
index ae05284d50..b0d369b910 100644
--- a/com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs
+++ b/com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs
@@ -177,6 +177,20 @@ public bool UseChildSensors
set { m_UseChildSensors = value; }
}
+ [SerializeField]
+ [HideInInspector]
+ [Tooltip("Set action selection to deterministic, Only applies to inference from within unity.")]
+ bool m_DeterministicInference;
+
+ /// <summary>
+ /// Gets or sets whether actions are chosen deterministically when inference is run from the assigned neural network.
+ /// </summary>
+ public bool DeterministicInference
+ {
+     get => m_DeterministicInference;
+     set => m_DeterministicInference = value;
+ }
+
///
/// Whether or not to use all the actuator components attached to child GameObjects of the agent.
/// Note that changing this after the Agent has been initialized will not have any effect.
@@ -228,7 +242,7 @@ internal IPolicy GeneratePolicy(ActionSpec actionSpec, ActuatorManager actuatorM
"Either assign a model, or change to a different Behavior Type."
);
}
- return new BarracudaPolicy(actionSpec, actuatorManager, m_Model, m_InferenceDevice, m_BehaviorName);
+ return new BarracudaPolicy(actionSpec, actuatorManager, m_Model, m_InferenceDevice, m_BehaviorName, m_DeterministicInference);
}
case BehaviorType.Default:
if (Academy.Instance.IsCommunicatorOn)
@@ -237,7 +251,7 @@ internal IPolicy GeneratePolicy(ActionSpec actionSpec, ActuatorManager actuatorM
}
if (m_Model != null)
{
- return new BarracudaPolicy(actionSpec, actuatorManager, m_Model, m_InferenceDevice, m_BehaviorName);
+ return new BarracudaPolicy(actionSpec, actuatorManager, m_Model, m_InferenceDevice, m_BehaviorName, m_DeterministicInference);
}
else
{
diff --git a/com.unity.ml-agents/Tests/Editor/Inference/ModelRunnerTest.cs b/com.unity.ml-agents/Tests/Editor/Inference/ModelRunnerTest.cs
index 0e81c4f8ad..da802a38d5 100644
--- a/com.unity.ml-agents/Tests/Editor/Inference/ModelRunnerTest.cs
+++ b/com.unity.ml-agents/Tests/Editor/Inference/ModelRunnerTest.cs
@@ -1,3 +1,4 @@
+using System;
using System.Linq;
using NUnit.Framework;
using UnityEngine;
@@ -6,9 +7,29 @@
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Inference;
using Unity.MLAgents.Policies;
+using System.Collections.Generic;
namespace Unity.MLAgents.Tests
{
+ /// <summary>
+ /// Equality comparer that treats two floats as equal when their absolute
+ /// difference is strictly smaller than a fixed threshold. Intended for
+ /// element-wise sequence comparison; hashing is deliberately unsupported.
+ /// </summary>
+ public class FloatThresholdComparer : IEqualityComparer<float>
+ {
+     private readonly float _threshold;
+
+     /// <param name="threshold">Exclusive upper bound on |x - y| for two values to compare equal.</param>
+     public FloatThresholdComparer(float threshold)
+     {
+         _threshold = threshold;
+     }
+
+     /// <returns>True when |x - y| is strictly below the configured threshold.</returns>
+     public bool Equals(float x, float y)
+     {
+         return Math.Abs(x - y) < _threshold;
+     }
+
+     /// <summary>
+     /// Always throws: values that compare equal under a threshold cannot be
+     /// given a consistent hash code, so this comparer must not be used for hashing.
+     /// </summary>
+     public int GetHashCode(float f)
+     {
+         throw new NotImplementedException("Unable to generate a hash code for threshold floats, do not use this method");
+     }
+ }
+
[TestFixture]
public class ModelRunnerTest
{
@@ -19,6 +40,9 @@ public class ModelRunnerTest
const string k_hybridONNXPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/hybrid0vis53vec_3c_2daction_v1_0.onnx";
const string k_continuousNNPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action_deprecated_v1_0.nn";
const string k_discreteNNPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_recurr_deprecated_v1_0.nn";
+ // models with deterministic action tensors
+ private const string k_deterministic_discreteNNPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/deterDiscrete1obs3action_v2_0.onnx";
+ private const string k_deterministic_continuousNNPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/deterContinuous2vis8vec2action_v2_0.onnx";
NNModel hybridONNXModelV2;
NNModel continuousONNXModel;
@@ -26,6 +50,8 @@ public class ModelRunnerTest
NNModel hybridONNXModel;
NNModel continuousNNModel;
NNModel discreteNNModel;
+ NNModel deterministicDiscreteNNModel;
+ NNModel deterministicContinuousNNModel;
Test3DSensorComponent sensor_21_20_3;
Test3DSensorComponent sensor_20_22_3;
@@ -55,6 +81,8 @@ public void SetUp()
hybridONNXModel = (NNModel)AssetDatabase.LoadAssetAtPath(k_hybridONNXPath, typeof(NNModel));
continuousNNModel = (NNModel)AssetDatabase.LoadAssetAtPath(k_continuousNNPath, typeof(NNModel));
discreteNNModel = (NNModel)AssetDatabase.LoadAssetAtPath(k_discreteNNPath, typeof(NNModel));
+ deterministicDiscreteNNModel = (NNModel)AssetDatabase.LoadAssetAtPath(k_deterministic_discreteNNPath, typeof(NNModel));
+ deterministicContinuousNNModel = (NNModel)AssetDatabase.LoadAssetAtPath(k_deterministic_continuousNNPath, typeof(NNModel));
var go = new GameObject("SensorA");
sensor_21_20_3 = go.AddComponent();
sensor_21_20_3.Sensor = new Test3DSensor("SensorA", 21, 20, 3);
@@ -71,6 +99,8 @@ public void TestModelExist()
Assert.IsNotNull(continuousNNModel);
Assert.IsNotNull(discreteNNModel);
Assert.IsNotNull(hybridONNXModelV2);
+ Assert.IsNotNull(deterministicDiscreteNNModel);
+ Assert.IsNotNull(deterministicContinuousNNModel);
}
[Test]
@@ -99,6 +129,15 @@ public void TestCreation()
// This one was trained with 2.0 so it should not raise an error:
modelRunner = new ModelRunner(hybridONNXModelV2, new ActionSpec(2, new[] { 2, 3 }), inferenceDevice);
modelRunner.Dispose();
+
+ // V2.0 Model that has serialized deterministic action tensors, discrete
+ modelRunner = new ModelRunner(deterministicDiscreteNNModel, new ActionSpec(0, new[] { 7 }), inferenceDevice);
+ modelRunner.Dispose();
+ // V2.0 Model that has serialized deterministic action tensors, continuous
+ modelRunner = new ModelRunner(deterministicContinuousNNModel,
+ GetContinuous2vis8vec2actionActionSpec(), inferenceDevice,
+ deterministicInference: true);
+ modelRunner.Dispose();
}
[Test]
@@ -138,5 +177,60 @@ public void TestRunModel()
Assert.AreEqual(actionSpec.NumDiscreteActions, modelRunner.GetAction(1).DiscreteActions.Length);
modelRunner.Dispose();
}
+
+
+ [Test]
+ public void TestRunModel_stochastic()
+ {
+     // No deterministicInference argument is passed, so the runner keeps the default (false).
+     var spec = GetContinuous2vis8vec2actionActionSpec();
+     var runner = new ModelRunner(deterministicContinuousNNModel, spec, InferenceDevice.Burst);
+
+     var vectorSensor = new Sensors.VectorSensor(8, "VectorSensor8");
+     var agentInfo = new AgentInfo();
+     var sensors = new[]
+     {
+         vectorSensor,
+         sensor_21_20_3.CreateSensors()[0],
+         sensor_20_22_3.CreateSensors()[0]
+     }.ToList();
+     agentInfo.episodeId = 1;
+
+     // First decision for the agent.
+     runner.PutObservations(agentInfo, sensors);
+     runner.DecideBatch();
+     var firstSample = (float[])runner.GetAction(1).ContinuousActions.Array.Clone();
+
+     // Second decision with exactly the same observations.
+     runner.PutObservations(agentInfo, sensors);
+     runner.DecideBatch();
+     var secondSample = (float[])runner.GetAction(1).ContinuousActions.Array.Clone();
+
+     // With stochastic selection the two samples are expected to differ beyond the tolerance.
+     var tolerance = new FloatThresholdComparer(0.001f);
+     var samplesMatch = firstSample.SequenceEqual(secondSample, tolerance);
+     Assert.IsFalse(samplesMatch);
+     runner.Dispose();
+ }
+ [Test]
+ public void TestRunModel_deterministic()
+ {
+     // Runs the same observations through a deterministic runner twice and
+     // checks that both decisions produce the same continuous actions.
+     var actionSpec = GetContinuous2vis8vec2actionActionSpec();
+     var sensor_8 = new Sensors.VectorSensor(8, "VectorSensor8");
+     var info1 = new AgentInfo();
+     var obs = new[]
+     {
+         sensor_8,
+         sensor_21_20_3.CreateSensors()[0],
+         sensor_20_22_3.CreateSensors()[0]
+     }.ToList();
+     var deterministicModelRunner = new ModelRunner(deterministicContinuousNNModel, actionSpec, InferenceDevice.Burst,
+         deterministicInference: true);
+     try
+     {
+         info1.episodeId = 1;
+         deterministicModelRunner.PutObservations(info1, obs);
+         deterministicModelRunner.DecideBatch();
+         var deterministicAction1 = (float[])deterministicModelRunner.GetAction(1).ContinuousActions.Array.Clone();
+
+         deterministicModelRunner.PutObservations(info1, obs);
+         deterministicModelRunner.DecideBatch();
+         var deterministicAction2 = (float[])deterministicModelRunner.GetAction(1).ContinuousActions.Array.Clone();
+         // Deterministic action selection should output the same action every time
+         Assert.IsTrue(Enumerable.SequenceEqual(deterministicAction1, deterministicAction2, new FloatThresholdComparer(0.001f)));
+     }
+     finally
+     {
+         // Dispose the runner that was actually used, even when the assertion fails.
+         // (Previously an unused stochastic runner was disposed and this one was never released.)
+         deterministicModelRunner.Dispose();
+     }
+ }
}
}
diff --git a/com.unity.ml-agents/Tests/Editor/TestModels/deterContinuous2vis8vec2action_v2_0.onnx b/com.unity.ml-agents/Tests/Editor/TestModels/deterContinuous2vis8vec2action_v2_0.onnx
new file mode 100644
index 0000000000..56c1cd4355
Binary files /dev/null and b/com.unity.ml-agents/Tests/Editor/TestModels/deterContinuous2vis8vec2action_v2_0.onnx differ
diff --git a/com.unity.ml-agents/Tests/Editor/TestModels/deterContinuous2vis8vec2action_v2_0.onnx.meta b/com.unity.ml-agents/Tests/Editor/TestModels/deterContinuous2vis8vec2action_v2_0.onnx.meta
new file mode 100644
index 0000000000..cc92cc94b8
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Editor/TestModels/deterContinuous2vis8vec2action_v2_0.onnx.meta
@@ -0,0 +1,14 @@
+fileFormatVersion: 2
+guid: e905d8f9eadcf45aa8c485594fecba6d
+ScriptedImporter:
+ internalIDToNameTable: []
+ externalObjects: {}
+ serializedVersion: 2
+ userData:
+ assetBundleName:
+ assetBundleVariant:
+ script: {fileID: 11500000, guid: 683b6cb6d0a474744822c888b46772c9, type: 3}
+ optimizeModel: 1
+ forceArbitraryBatchSize: 1
+ treatErrorsAsWarnings: 0
+ importMode: 1
diff --git a/com.unity.ml-agents/Tests/Editor/TestModels/deterDiscrete1obs3action_v2_0.onnx b/com.unity.ml-agents/Tests/Editor/TestModels/deterDiscrete1obs3action_v2_0.onnx
new file mode 100644
index 0000000000..3aa846e204
Binary files /dev/null and b/com.unity.ml-agents/Tests/Editor/TestModels/deterDiscrete1obs3action_v2_0.onnx differ
diff --git a/com.unity.ml-agents/Tests/Editor/TestModels/deterDiscrete1obs3action_v2_0.onnx.meta b/com.unity.ml-agents/Tests/Editor/TestModels/deterDiscrete1obs3action_v2_0.onnx.meta
new file mode 100644
index 0000000000..a141a55235
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Editor/TestModels/deterDiscrete1obs3action_v2_0.onnx.meta
@@ -0,0 +1,14 @@
+fileFormatVersion: 2
+guid: d132cc9c934a54fdc99758427373e038
+ScriptedImporter:
+ internalIDToNameTable: []
+ externalObjects: {}
+ serializedVersion: 2
+ userData:
+ assetBundleName:
+ assetBundleVariant:
+ script: {fileID: 11500000, guid: 683b6cb6d0a474744822c888b46772c9, type: 3}
+ optimizeModel: 1
+ forceArbitraryBatchSize: 1
+ treatErrorsAsWarnings: 0
+ importMode: 1
diff --git a/com.unity.ml-agents/Tests/Runtime/RuntimeAPITest.cs b/com.unity.ml-agents/Tests/Runtime/RuntimeAPITest.cs
index 0c3b6312b4..743212a69b 100644
--- a/com.unity.ml-agents/Tests/Runtime/RuntimeAPITest.cs
+++ b/com.unity.ml-agents/Tests/Runtime/RuntimeAPITest.cs
@@ -70,6 +70,7 @@ public IEnumerator RuntimeApiTestWithEnumeratorPasses()
behaviorParams.BehaviorName = "TestBehavior";
behaviorParams.TeamId = 42;
behaviorParams.UseChildSensors = true;
+ behaviorParams.DeterministicInference = false;
behaviorParams.ObservableAttributeHandling = ObservableAttributeOptions.ExamineAll;
diff --git a/docs/Getting-Started.md b/docs/Getting-Started.md
index c4ec5332ad..7baf461ebd 100644
--- a/docs/Getting-Started.md
+++ b/docs/Getting-Started.md
@@ -119,6 +119,9 @@ example.
**Note** : You can modify multiple game objects in a scene by selecting them
all at once using the search bar in the Scene Hierarchy.
1. Set the **Inference Device** to use for this model as `CPU`.
+1. If the model is trained with Release 19 or later, you can select
+ `Deterministic Inference` to choose actions deterministically from the model.
+ This only works for inference within Unity with no Python process involved.
1. Click the **Play** button in the Unity Editor and you will see the platforms
balance the balls using the pre-trained model.
diff --git a/docs/Learning-Environment-Design-Agents.md b/docs/Learning-Environment-Design-Agents.md
index 8953894957..dc7022189f 100644
--- a/docs/Learning-Environment-Design-Agents.md
+++ b/docs/Learning-Environment-Design-Agents.md
@@ -987,6 +987,9 @@ be called independently of the `Max Step` property.
training)
- `Inference Device` - Whether to use CPU or GPU to run the model during
inference
+ - `Deterministic Inference` - Whether to set action selection to deterministic.
+ Only applies to inference from within Unity (with no Python process involved) and
+ to models trained with Release 19 or later.
- `Behavior Type` - Determines whether the Agent will do training, inference,
or use its Heuristic() method:
- `Default` - the Agent will train if they connect to a python trainer,
diff --git a/docs/images/3dball_learning_brain.png b/docs/images/3dball_learning_brain.png
index c133bf2779..68757fa6ce 100644
Binary files a/docs/images/3dball_learning_brain.png and b/docs/images/3dball_learning_brain.png differ