diff --git a/com.unity.ml-agents/CHANGELOG.md b/com.unity.ml-agents/CHANGELOG.md index 613d652929..9c9279f530 100755 --- a/com.unity.ml-agents/CHANGELOG.md +++ b/com.unity.ml-agents/CHANGELOG.md @@ -29,9 +29,12 @@ and this project adheres to 2. env_params.restarts_rate_limit_n (--restarts-rate-limit-n) [default=1] 3. env_params.restarts_rate_limit_period_s (--restarts-rate-limit-period-s) [default=60] -- Added a new `--deterministic` cli flag to deterministically select the most probable actions in policy. The same thing can -be achieved by adding `deterministic: true` under `network_settings` of the run options configuration. -- Extra tensors are now serialized to support deterministic action selection in onnx. (#5597) + +- Deterministic action selection is now supported during training and inference + - Added a new `--deterministic` cli flag to deterministically select the most probable actions in policy. The same thing can + be achieved by adding `deterministic: true` under `network_settings` of the run options configuration.(#5619) + - Extra tensors are now serialized to support deterministic action selection in onnx. (#5593) + - Support inference with deterministic action selection in editor (#5599) ### Bug Fixes - Fixed the bug where curriculum learning would crash because of the incorrect run_options parsing. (#5586) diff --git a/com.unity.ml-agents/Editor/BehaviorParametersEditor.cs b/com.unity.ml-agents/Editor/BehaviorParametersEditor.cs index c5e8ddc802..a95b2846f3 100644 --- a/com.unity.ml-agents/Editor/BehaviorParametersEditor.cs +++ b/com.unity.ml-agents/Editor/BehaviorParametersEditor.cs @@ -25,6 +25,7 @@ internal class BehaviorParametersEditor : UnityEditor.Editor const string k_BrainParametersName = "m_BrainParameters"; const string k_ModelName = "m_Model"; const string k_InferenceDeviceName = "m_InferenceDevice"; + const string k_DeterministicInference = "m_DeterministicInference"; const string k_BehaviorTypeName = "m_BehaviorType"; const string k_TeamIdName = "TeamId"; const string k_UseChildSensorsName = "m_UseChildSensors"; @@ -68,6 +69,7 @@ public override void OnInspectorGUI() EditorGUILayout.PropertyField(so.FindProperty(k_ModelName), true); EditorGUI.indentLevel++; EditorGUILayout.PropertyField(so.FindProperty(k_InferenceDeviceName), true); + EditorGUILayout.PropertyField(so.FindProperty(k_DeterministicInference), true); EditorGUI.indentLevel--; } needPolicyUpdate = needPolicyUpdate || EditorGUI.EndChangeCheck(); @@ -156,7 +158,7 @@ void DisplayFailedModelChecks() { var failedChecks = Inference.BarracudaModelParamLoader.CheckModel( barracudaModel, brainParameters, sensors, actuatorComponents, - observableAttributeSensorTotalSize, behaviorParameters.BehaviorType + observableAttributeSensorTotalSize, behaviorParameters.BehaviorType, behaviorParameters.DeterministicInference ); foreach (var check in failedChecks) { diff --git a/com.unity.ml-agents/Runtime/Academy.cs b/com.unity.ml-agents/Runtime/Academy.cs index 85cab21b26..409ccb1dca 100644 --- a/com.unity.ml-agents/Runtime/Academy.cs +++ b/com.unity.ml-agents/Runtime/Academy.cs @@ -616,14 +616,16 @@ void EnvironmentReset() /// /// The inference device (CPU or GPU) the ModelRunner will use. /// + /// Inference only: set to true if the action selection from model should be + /// Deterministic. /// The ModelRunner compatible with the input settings. internal ModelRunner GetOrCreateModelRunner( - NNModel model, ActionSpec actionSpec, InferenceDevice inferenceDevice) + NNModel model, ActionSpec actionSpec, InferenceDevice inferenceDevice, bool deterministicInference = false) { var modelRunner = m_ModelRunners.Find(x => x.HasModel(model, inferenceDevice)); if (modelRunner == null) { - modelRunner = new ModelRunner(model, actionSpec, inferenceDevice, m_InferenceSeed); + modelRunner = new ModelRunner(model, actionSpec, inferenceDevice, m_InferenceSeed, deterministicInference); m_ModelRunners.Add(modelRunner); m_InferenceSeed++; } diff --git a/com.unity.ml-agents/Runtime/Inference/BarracudaModelExtensions.cs b/com.unity.ml-agents/Runtime/Inference/BarracudaModelExtensions.cs index 6fd3872535..5e7338c057 100644 --- a/com.unity.ml-agents/Runtime/Inference/BarracudaModelExtensions.cs +++ b/com.unity.ml-agents/Runtime/Inference/BarracudaModelExtensions.cs @@ -112,8 +112,10 @@ public static int GetNumVisualInputs(this Model model) /// /// The Barracuda engine model for loading static parameters. /// + /// Inference only: set to true if the action selection from model should be + /// deterministic. /// Array of the output tensor names of the model - public static string[] GetOutputNames(this Model model) + public static string[] GetOutputNames(this Model model, bool deterministicInference = false) { var names = new List(); @@ -122,13 +124,13 @@ public static string[] GetOutputNames(this Model model) return names.ToArray(); } - if (model.HasContinuousOutputs()) + if (model.HasContinuousOutputs(deterministicInference)) { - names.Add(model.ContinuousOutputName()); + names.Add(model.ContinuousOutputName(deterministicInference)); } - if (model.HasDiscreteOutputs()) + if (model.HasDiscreteOutputs(deterministicInference)) { - names.Add(model.DiscreteOutputName()); + names.Add(model.DiscreteOutputName(deterministicInference)); } var modelVersion = model.GetVersion(); @@ -149,8 +151,10 @@ public static string[] GetOutputNames(this Model model) /// /// The Barracuda engine model for loading static parameters. /// + /// Inference only: set to true if the action selection from model should be + /// deterministic. /// True if the model has continuous action outputs. - public static bool HasContinuousOutputs(this Model model) + public static bool HasContinuousOutputs(this Model model, bool deterministicInference = false) { if (model == null) return false; @@ -160,8 +164,13 @@ public static bool HasContinuousOutputs(this Model model) } else { - return model.outputs.Contains(TensorNames.ContinuousActionOutput) && - (int)model.GetTensorByName(TensorNames.ContinuousActionOutputShape)[0] > 0; + bool hasStochasticOutput = !deterministicInference && + model.outputs.Contains(TensorNames.ContinuousActionOutput); + bool hasDeterministicOutput = deterministicInference && + model.outputs.Contains(TensorNames.DeterministicContinuousActionOutput); + + return (hasStochasticOutput || hasDeterministicOutput) && + (int)model.GetTensorByName(TensorNames.ContinuousActionOutputShape)[0] > 0; } } @@ -194,8 +203,10 @@ public static int ContinuousOutputSize(this Model model) /// /// The Barracuda engine model for loading static parameters. /// + /// Inference only: set to true if the action selection from model should be + /// deterministic. /// Tensor name of continuous action output. - public static string ContinuousOutputName(this Model model) + public static string ContinuousOutputName(this Model model, bool deterministicInference = false) { if (model == null) return null; @@ -205,7 +216,7 @@ public static string ContinuousOutputName(this Model model) } else { - return TensorNames.ContinuousActionOutput; + return deterministicInference ? TensorNames.DeterministicContinuousActionOutput : TensorNames.ContinuousActionOutput; } } @@ -215,8 +226,10 @@ public static string ContinuousOutputName(this Model model) /// /// The Barracuda engine model for loading static parameters. /// + /// Inference only: set to true if the action selection from model should be + /// deterministic. /// True if the model has discrete action outputs. - public static bool HasDiscreteOutputs(this Model model) + public static bool HasDiscreteOutputs(this Model model, bool deterministicInference = false) { if (model == null) return false; @@ -226,7 +239,12 @@ public static bool HasDiscreteOutputs(this Model model) } else { - return model.outputs.Contains(TensorNames.DiscreteActionOutput) && model.DiscreteOutputSize() > 0; + bool hasStochasticOutput = !deterministicInference && + model.outputs.Contains(TensorNames.DiscreteActionOutput); + bool hasDeterministicOutput = deterministicInference && + model.outputs.Contains(TensorNames.DeterministicDiscreteActionOutput); + return (hasStochasticOutput || hasDeterministicOutput) && + model.DiscreteOutputSize() > 0; } } @@ -279,8 +297,10 @@ public static int DiscreteOutputSize(this Model model) /// /// The Barracuda engine model for loading static parameters. /// + /// Inference only: set to true if the action selection from model should be + /// deterministic. /// Tensor name of discrete action output. - public static string DiscreteOutputName(this Model model) + public static string DiscreteOutputName(this Model model, bool deterministicInference = false) { if (model == null) return null; @@ -290,7 +310,7 @@ public static string DiscreteOutputName(this Model model) } else { - return TensorNames.DiscreteActionOutput; + return deterministicInference ? TensorNames.DeterministicDiscreteActionOutput : TensorNames.DiscreteActionOutput; } } @@ -316,9 +336,11 @@ public static bool SupportsContinuousAndDiscrete(this Model model) /// The Barracuda engine model for loading static parameters. /// /// Output list of failure messages - /// + /// Inference only: set to true if the action selection from model should be + /// deterministic. /// True if the model contains all the expected tensors. - public static bool CheckExpectedTensors(this Model model, List failedModelChecks) + /// TODO: add checks for deterministic actions + public static bool CheckExpectedTensors(this Model model, List failedModelChecks, bool deterministicInference = false) { // Check the presence of model version var modelApiVersionTensor = model.GetTensorByName(TensorNames.VersionNumber); @@ -343,7 +365,9 @@ public static bool CheckExpectedTensors(this Model model, List fail // Check the presence of action output tensor if (!model.outputs.Contains(TensorNames.ActionOutputDeprecated) && !model.outputs.Contains(TensorNames.ContinuousActionOutput) && - !model.outputs.Contains(TensorNames.DiscreteActionOutput)) + !model.outputs.Contains(TensorNames.DiscreteActionOutput) && + !model.outputs.Contains(TensorNames.DeterministicContinuousActionOutput) && + !model.outputs.Contains(TensorNames.DeterministicDiscreteActionOutput)) { failedModelChecks.Add( FailedCheck.Warning("The model does not contain any Action Output Node.") @@ -373,22 +397,51 @@ public static bool CheckExpectedTensors(this Model model, List fail } else { - if (model.outputs.Contains(TensorNames.ContinuousActionOutput) && - model.GetTensorByName(TensorNames.ContinuousActionOutputShape) == null) + if (model.outputs.Contains(TensorNames.ContinuousActionOutput)) { - failedModelChecks.Add( - FailedCheck.Warning("The model uses continuous action but does not contain Continuous Action Output Shape Node.") + if (model.GetTensorByName(TensorNames.ContinuousActionOutputShape) == null) + { + failedModelChecks.Add( + FailedCheck.Warning("The model uses continuous action but does not contain Continuous Action Output Shape Node.") ); - return false; + return false; + } + + else if (!model.HasContinuousOutputs(deterministicInference)) + { + var actionType = deterministicInference ? "deterministic" : "stochastic"; + var actionName = deterministicInference ? "Deterministic" : ""; + failedModelChecks.Add( + FailedCheck.Warning($"The model uses {actionType} inference but does not contain {actionName} Continuous Action Output Tensor. Uncheck `Deterministic inference` flag..") + ); + return false; + } } - if (model.outputs.Contains(TensorNames.DiscreteActionOutput) && - model.GetTensorByName(TensorNames.DiscreteActionOutputShape) == null) + + if (model.outputs.Contains(TensorNames.DiscreteActionOutput)) { - failedModelChecks.Add( - FailedCheck.Warning("The model uses discrete action but does not contain Discrete Action Output Shape Node.") + if (model.GetTensorByName(TensorNames.DiscreteActionOutputShape) == null) + { + failedModelChecks.Add( + FailedCheck.Warning("The model uses discrete action but does not contain Discrete Action Output Shape Node.") ); - return false; + return false; + } + else if (!model.HasDiscreteOutputs(deterministicInference)) + { + var actionType = deterministicInference ? "deterministic" : "stochastic"; + var actionName = deterministicInference ? "Deterministic" : ""; + failedModelChecks.Add( + FailedCheck.Warning($"The model uses {actionType} inference but does not contain {actionName} Discrete Action Output Tensor. Uncheck `Deterministic inference` flag.") + ); + return false; + } + } + + + + } return true; } diff --git a/com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs b/com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs index 21917b303d..6fe10566fd 100644 --- a/com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs +++ b/com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs @@ -122,6 +122,8 @@ public static FailedCheck CheckModelVersion(Model model) /// Attached actuator components /// Sum of the sizes of all ObservableAttributes. /// BehaviorType or the Agent to check. + /// Inference only: set to true if the action selection from model should be + /// deterministic. /// A IEnumerable of the checks that failed public static IEnumerable CheckModel( Model model, @@ -129,7 +131,8 @@ public static IEnumerable CheckModel( ISensor[] sensors, ActuatorComponent[] actuatorComponents, int observableAttributeTotalSize = 0, - BehaviorType behaviorType = BehaviorType.Default + BehaviorType behaviorType = BehaviorType.Default, + bool deterministicInference = false ) { List failedModelChecks = new List(); @@ -148,7 +151,7 @@ public static IEnumerable CheckModel( return failedModelChecks; } - var hasExpectedTensors = model.CheckExpectedTensors(failedModelChecks); + var hasExpectedTensors = model.CheckExpectedTensors(failedModelChecks, deterministicInference); if (!hasExpectedTensors) { return failedModelChecks; @@ -181,7 +184,7 @@ public static IEnumerable CheckModel( else if (modelApiVersion == (int)ModelApiVersion.MLAgents2_0) { failedModelChecks.AddRange( - CheckInputTensorPresence(model, brainParameters, memorySize, sensors) + CheckInputTensorPresence(model, brainParameters, memorySize, sensors, deterministicInference) ); failedModelChecks.AddRange( CheckInputTensorShape(model, brainParameters, sensors, observableAttributeTotalSize) @@ -195,7 +198,7 @@ public static IEnumerable CheckModel( ); failedModelChecks.AddRange( - CheckOutputTensorPresence(model, memorySize) + CheckOutputTensorPresence(model, memorySize, deterministicInference) ); return failedModelChecks; } @@ -318,6 +321,8 @@ ISensor[] sensors /// The memory size that the model is expecting. /// /// Array of attached sensor components + /// Inference only: set to true if the action selection from model should be + /// Deterministic. /// /// A IEnumerable of the checks that failed /// @@ -325,7 +330,8 @@ static IEnumerable CheckInputTensorPresence( Model model, BrainParameters brainParameters, int memory, - ISensor[] sensors + ISensor[] sensors, + bool deterministicInference = false ) { var failedModelChecks = new List(); @@ -356,7 +362,7 @@ ISensor[] sensors } // If the model uses discrete control but does not have an input for action masks - if (model.HasDiscreteOutputs()) + if (model.HasDiscreteOutputs(deterministicInference)) { if (!tensorsNames.Contains(TensorNames.ActionMaskPlaceholder)) { @@ -376,17 +382,19 @@ ISensor[] sensors /// The Barracuda engine model for loading static parameters /// /// The memory size that the model is expecting/ + /// Inference only: set to true if the action selection from model should be + /// deterministic. /// /// A IEnumerable of the checks that failed /// - static IEnumerable CheckOutputTensorPresence(Model model, int memory) + static IEnumerable CheckOutputTensorPresence(Model model, int memory, bool deterministicInference = false) { var failedModelChecks = new List(); // If there is no Recurrent Output but the model is Recurrent. if (memory > 0) { - var allOutputs = model.GetOutputNames().ToList(); + var allOutputs = model.GetOutputNames(deterministicInference).ToList(); if (!allOutputs.Any(x => x == TensorNames.RecurrentOutput)) { failedModelChecks.Add( diff --git a/com.unity.ml-agents/Runtime/Inference/ModelRunner.cs b/com.unity.ml-agents/Runtime/Inference/ModelRunner.cs index 422e0f9744..f59b54ee23 100644 --- a/com.unity.ml-agents/Runtime/Inference/ModelRunner.cs +++ b/com.unity.ml-agents/Runtime/Inference/ModelRunner.cs @@ -28,6 +28,7 @@ internal class ModelRunner InferenceDevice m_InferenceDevice; IWorker m_Engine; bool m_Verbose = false; + bool m_DeterministicInference; string[] m_OutputNames; IReadOnlyList m_InferenceInputs; List m_InferenceOutputs; @@ -48,18 +49,22 @@ internal class ModelRunner /// option for most of ML Agents models. /// The seed that will be used to initialize the RandomNormal /// and Multinomial objects used when running inference. + /// Inference only: set to true if the action selection from model should be + /// deterministic. /// Throws an error when the model is null /// public ModelRunner( NNModel model, ActionSpec actionSpec, InferenceDevice inferenceDevice, - int seed = 0) + int seed = 0, + bool deterministicInference = false) { Model barracudaModel; m_Model = model; m_ModelName = model.name; m_InferenceDevice = inferenceDevice; + m_DeterministicInference = deterministicInference; m_TensorAllocator = new TensorCachingAllocator(); if (model != null) { @@ -108,11 +113,12 @@ public ModelRunner( } m_InferenceInputs = barracudaModel.GetInputTensors(); - m_OutputNames = barracudaModel.GetOutputNames(); + m_OutputNames = barracudaModel.GetOutputNames(m_DeterministicInference); + m_TensorGenerator = new TensorGenerator( - seed, m_TensorAllocator, m_Memories, barracudaModel); + seed, m_TensorAllocator, m_Memories, barracudaModel, m_DeterministicInference); m_TensorApplier = new TensorApplier( - actionSpec, seed, m_TensorAllocator, m_Memories, barracudaModel); + actionSpec, seed, m_TensorAllocator, m_Memories, barracudaModel, m_DeterministicInference); m_InputsByName = new Dictionary(); m_InferenceOutputs = new List(); } diff --git a/com.unity.ml-agents/Runtime/Inference/TensorApplier.cs b/com.unity.ml-agents/Runtime/Inference/TensorApplier.cs index d311010aae..a03b3d927e 100644 --- a/com.unity.ml-agents/Runtime/Inference/TensorApplier.cs +++ b/com.unity.ml-agents/Runtime/Inference/TensorApplier.cs @@ -44,12 +44,15 @@ public interface IApplier /// Tensor allocator /// Dictionary of AgentInfo.id to memory used to pass to the inference model. /// + /// Inference only: set to true if the action selection from model should be + /// deterministic. public TensorApplier( ActionSpec actionSpec, int seed, ITensorAllocator allocator, Dictionary> memories, - object barracudaModel = null) + object barracudaModel = null, + bool deterministicInference = false) { // If model is null, no inference to run and exception is thrown before reaching here. if (barracudaModel == null) @@ -64,13 +67,13 @@ public TensorApplier( } if (actionSpec.NumContinuousActions > 0) { - var tensorName = model.ContinuousOutputName(); + var tensorName = model.ContinuousOutputName(deterministicInference); m_Dict[tensorName] = new ContinuousActionOutputApplier(actionSpec); } var modelVersion = model.GetVersion(); if (actionSpec.NumDiscreteActions > 0) { - var tensorName = model.DiscreteOutputName(); + var tensorName = model.DiscreteOutputName(deterministicInference); if (modelVersion == (int)BarracudaModelParamLoader.ModelApiVersion.MLAgents1_0) { m_Dict[tensorName] = new LegacyDiscreteActionOutputApplier(actionSpec, seed, allocator); diff --git a/com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs b/com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs index feb521ebd8..39bed85792 100644 --- a/com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs +++ b/com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs @@ -44,11 +44,14 @@ void Generate( /// Tensor allocator. /// Dictionary of AgentInfo.id to memory for use in the inference model. /// + /// Inference only: set to true if the action selection from model should be + /// deterministic. public TensorGenerator( int seed, ITensorAllocator allocator, Dictionary> memories, - object barracudaModel = null) + object barracudaModel = null, + bool deterministicInference = false) { // If model is null, no inference to run and exception is thrown before reaching here. if (barracudaModel == null) @@ -76,13 +79,13 @@ public TensorGenerator( // Generators for Outputs - if (model.HasContinuousOutputs()) + if (model.HasContinuousOutputs(deterministicInference)) { - m_Dict[model.ContinuousOutputName()] = new BiDimensionalOutputGenerator(allocator); + m_Dict[model.ContinuousOutputName(deterministicInference)] = new BiDimensionalOutputGenerator(allocator); } - if (model.HasDiscreteOutputs()) + if (model.HasDiscreteOutputs(deterministicInference)) { - m_Dict[model.DiscreteOutputName()] = new BiDimensionalOutputGenerator(allocator); + m_Dict[model.DiscreteOutputName(deterministicInference)] = new BiDimensionalOutputGenerator(allocator); } m_Dict[TensorNames.RecurrentOutput] = new BiDimensionalOutputGenerator(allocator); m_Dict[TensorNames.ValueEstimateOutput] = new BiDimensionalOutputGenerator(allocator); diff --git a/com.unity.ml-agents/Runtime/Inference/TensorNames.cs b/com.unity.ml-agents/Runtime/Inference/TensorNames.cs index dc20e1f8f3..48ae04b5f6 100644 --- a/com.unity.ml-agents/Runtime/Inference/TensorNames.cs +++ b/com.unity.ml-agents/Runtime/Inference/TensorNames.cs @@ -23,6 +23,8 @@ internal static class TensorNames public const string DiscreteActionOutputShape = "discrete_action_output_shape"; public const string ContinuousActionOutput = "continuous_actions"; public const string DiscreteActionOutput = "discrete_actions"; + public const string DeterministicContinuousActionOutput = "deterministic_continuous_actions"; + public const string DeterministicDiscreteActionOutput = "deterministic_discrete_actions"; // Deprecated TensorNames entries for backward compatibility public const string IsContinuousControlDeprecated = "is_continuous_control"; diff --git a/com.unity.ml-agents/Runtime/Policies/BarracudaPolicy.cs b/com.unity.ml-agents/Runtime/Policies/BarracudaPolicy.cs index 5e76084b20..96a15b50d8 100644 --- a/com.unity.ml-agents/Runtime/Policies/BarracudaPolicy.cs +++ b/com.unity.ml-agents/Runtime/Policies/BarracudaPolicy.cs @@ -45,6 +45,11 @@ internal class BarracudaPolicy : IPolicy ActionBuffers m_LastActionBuffer; int m_AgentId; + /// + /// Inference only: set to true if the action selection from model should be + /// deterministic. + /// + bool m_DeterministicInference; /// /// Sensor shapes for the associated Agents. All Agents must have the same shapes for their Sensors. @@ -73,19 +78,23 @@ internal class BarracudaPolicy : IPolicy /// The Neural Network to use. /// Which device Barracuda will run on. /// The name of the behavior. + /// Inference only: set to true if the action selection from model should be + /// deterministic. public BarracudaPolicy( ActionSpec actionSpec, IList actuators, NNModel model, InferenceDevice inferenceDevice, - string behaviorName + string behaviorName, + bool deterministicInference = false ) { - var modelRunner = Academy.Instance.GetOrCreateModelRunner(model, actionSpec, inferenceDevice); + var modelRunner = Academy.Instance.GetOrCreateModelRunner(model, actionSpec, inferenceDevice, deterministicInference); m_ModelRunner = modelRunner; m_BehaviorName = behaviorName; m_ActionSpec = actionSpec; m_Actuators = actuators; + m_DeterministicInference = deterministicInference; } /// diff --git a/com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs b/com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs index ae05284d50..b0d369b910 100644 --- a/com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs +++ b/com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs @@ -177,6 +177,20 @@ public bool UseChildSensors set { m_UseChildSensors = value; } } + [HideInInspector] + [SerializeField] + [Tooltip("Set action selection to deterministic, Only applies to inference from within unity.")] + private bool m_DeterministicInference = false; + + /// + /// Whether to select actions deterministically during inference from the provided neural network. + /// + public bool DeterministicInference + { + get { return m_DeterministicInference; } + set { m_DeterministicInference = value; } + } + /// /// Whether or not to use all the actuator components attached to child GameObjects of the agent. /// Note that changing this after the Agent has been initialized will not have any effect. @@ -228,7 +242,7 @@ internal IPolicy GeneratePolicy(ActionSpec actionSpec, ActuatorManager actuatorM "Either assign a model, or change to a different Behavior Type." ); } - return new BarracudaPolicy(actionSpec, actuatorManager, m_Model, m_InferenceDevice, m_BehaviorName); + return new BarracudaPolicy(actionSpec, actuatorManager, m_Model, m_InferenceDevice, m_BehaviorName, m_DeterministicInference); } case BehaviorType.Default: if (Academy.Instance.IsCommunicatorOn) @@ -237,7 +251,7 @@ internal IPolicy GeneratePolicy(ActionSpec actionSpec, ActuatorManager actuatorM } if (m_Model != null) { - return new BarracudaPolicy(actionSpec, actuatorManager, m_Model, m_InferenceDevice, m_BehaviorName); + return new BarracudaPolicy(actionSpec, actuatorManager, m_Model, m_InferenceDevice, m_BehaviorName, m_DeterministicInference); } else { diff --git a/com.unity.ml-agents/Tests/Editor/Inference/ModelRunnerTest.cs b/com.unity.ml-agents/Tests/Editor/Inference/ModelRunnerTest.cs index 0e81c4f8ad..da802a38d5 100644 --- a/com.unity.ml-agents/Tests/Editor/Inference/ModelRunnerTest.cs +++ b/com.unity.ml-agents/Tests/Editor/Inference/ModelRunnerTest.cs @@ -1,3 +1,4 @@ +using System; using System.Linq; using NUnit.Framework; using UnityEngine; @@ -6,9 +7,29 @@ using Unity.MLAgents.Actuators; using Unity.MLAgents.Inference; using Unity.MLAgents.Policies; +using System.Collections.Generic; namespace Unity.MLAgents.Tests { + public class FloatThresholdComparer : IEqualityComparer + { + private readonly float _threshold; + public FloatThresholdComparer(float threshold) + { + _threshold = threshold; + } + + public bool Equals(float x, float y) + { + return Math.Abs(x - y) < _threshold; + } + + public int GetHashCode(float f) + { + throw new NotImplementedException("Unable to generate a hash code for threshold floats, do not use this method"); + } + } + [TestFixture] public class ModelRunnerTest { @@ -19,6 +40,9 @@ public class ModelRunnerTest const string k_hybridONNXPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/hybrid0vis53vec_3c_2daction_v1_0.onnx"; const string k_continuousNNPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action_deprecated_v1_0.nn"; const string k_discreteNNPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_recurr_deprecated_v1_0.nn"; + // models with deterministic action tensors + private const string k_deterministic_discreteNNPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/deterDiscrete1obs3action_v2_0.onnx"; + private const string k_deterministic_continuousNNPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/deterContinuous2vis8vec2action_v2_0.onnx"; NNModel hybridONNXModelV2; NNModel continuousONNXModel; @@ -26,6 +50,8 @@ public class ModelRunnerTest NNModel hybridONNXModel; NNModel continuousNNModel; NNModel discreteNNModel; + NNModel deterministicDiscreteNNModel; + NNModel deterministicContinuousNNModel; Test3DSensorComponent sensor_21_20_3; Test3DSensorComponent sensor_20_22_3; @@ -55,6 +81,8 @@ public void SetUp() hybridONNXModel = (NNModel)AssetDatabase.LoadAssetAtPath(k_hybridONNXPath, typeof(NNModel)); continuousNNModel = (NNModel)AssetDatabase.LoadAssetAtPath(k_continuousNNPath, typeof(NNModel)); discreteNNModel = (NNModel)AssetDatabase.LoadAssetAtPath(k_discreteNNPath, typeof(NNModel)); + deterministicDiscreteNNModel = (NNModel)AssetDatabase.LoadAssetAtPath(k_deterministic_discreteNNPath, typeof(NNModel)); + deterministicContinuousNNModel = (NNModel)AssetDatabase.LoadAssetAtPath(k_deterministic_continuousNNPath, typeof(NNModel)); var go = new GameObject("SensorA"); sensor_21_20_3 = go.AddComponent(); sensor_21_20_3.Sensor = new Test3DSensor("SensorA", 21, 20, 3); @@ -71,6 +99,8 @@ public void TestModelExist() Assert.IsNotNull(continuousNNModel); Assert.IsNotNull(discreteNNModel); Assert.IsNotNull(hybridONNXModelV2); + Assert.IsNotNull(deterministicDiscreteNNModel); + Assert.IsNotNull(deterministicContinuousNNModel); } [Test] @@ -99,6 +129,15 @@ public void TestCreation() // This one was trained with 2.0 so it should not raise an error: modelRunner = new ModelRunner(hybridONNXModelV2, new ActionSpec(2, new[] { 2, 3 }), inferenceDevice); modelRunner.Dispose(); + + // V2.0 Model that has serialized deterministic action tensors, discrete + modelRunner = new ModelRunner(deterministicDiscreteNNModel, new ActionSpec(0, new[] { 7 }), inferenceDevice); + modelRunner.Dispose(); + // V2.0 Model that has serialized deterministic action tensors, continuous + modelRunner = new ModelRunner(deterministicContinuousNNModel, + GetContinuous2vis8vec2actionActionSpec(), inferenceDevice, + deterministicInference: true); + modelRunner.Dispose(); } [Test] @@ -138,5 +177,60 @@ public void TestRunModel() Assert.AreEqual(actionSpec.NumDiscreteActions, modelRunner.GetAction(1).DiscreteActions.Length); modelRunner.Dispose(); } + + + [Test] + public void TestRunModel_stochastic() + { + var actionSpec = GetContinuous2vis8vec2actionActionSpec(); + // deterministicInference = false by default + var modelRunner = new ModelRunner(deterministicContinuousNNModel, actionSpec, InferenceDevice.Burst); + var sensor_8 = new Sensors.VectorSensor(8, "VectorSensor8"); + var info1 = new AgentInfo(); + var obs = new[] + { + sensor_8, + sensor_21_20_3.CreateSensors()[0], + sensor_20_22_3.CreateSensors()[0] + }.ToList(); + info1.episodeId = 1; + modelRunner.PutObservations(info1, obs); + modelRunner.DecideBatch(); + var stochAction1 = (float[])modelRunner.GetAction(1).ContinuousActions.Array.Clone(); + + modelRunner.PutObservations(info1, obs); + modelRunner.DecideBatch(); + var stochAction2 = (float[])modelRunner.GetAction(1).ContinuousActions.Array.Clone(); + // Stochastic action selection should output randomly different action values with same obs + Assert.IsFalse(Enumerable.SequenceEqual(stochAction1, stochAction2, new FloatThresholdComparer(0.001f))); + modelRunner.Dispose(); + } + [Test] + public void TestRunModel_deterministic() + { + var actionSpec = GetContinuous2vis8vec2actionActionSpec(); + var modelRunner = new ModelRunner(deterministicContinuousNNModel, actionSpec, InferenceDevice.Burst); + var sensor_8 = new Sensors.VectorSensor(8, "VectorSensor8"); + var info1 = new AgentInfo(); + var obs = new[] + { + sensor_8, + sensor_21_20_3.CreateSensors()[0], + sensor_20_22_3.CreateSensors()[0] + }.ToList(); + var deterministicModelRunner = new ModelRunner(deterministicContinuousNNModel, actionSpec, InferenceDevice.Burst, + deterministicInference: true); + info1.episodeId = 1; + deterministicModelRunner.PutObservations(info1, obs); + deterministicModelRunner.DecideBatch(); + var deterministicAction1 = (float[])deterministicModelRunner.GetAction(1).ContinuousActions.Array.Clone(); + + deterministicModelRunner.PutObservations(info1, obs); + deterministicModelRunner.DecideBatch(); + var deterministicAction2 = (float[])deterministicModelRunner.GetAction(1).ContinuousActions.Array.Clone(); + // Deterministic action selection should output same action everytime + Assert.IsTrue(Enumerable.SequenceEqual(deterministicAction1, deterministicAction2, new FloatThresholdComparer(0.001f))); + modelRunner.Dispose(); + } } } diff --git a/com.unity.ml-agents/Tests/Editor/TestModels/deterContinuous2vis8vec2action_v2_0.onnx b/com.unity.ml-agents/Tests/Editor/TestModels/deterContinuous2vis8vec2action_v2_0.onnx new file mode 100644 index 0000000000..56c1cd4355 Binary files /dev/null and b/com.unity.ml-agents/Tests/Editor/TestModels/deterContinuous2vis8vec2action_v2_0.onnx differ diff --git a/com.unity.ml-agents/Tests/Editor/TestModels/deterContinuous2vis8vec2action_v2_0.onnx.meta b/com.unity.ml-agents/Tests/Editor/TestModels/deterContinuous2vis8vec2action_v2_0.onnx.meta new file mode 100644 index 0000000000..cc92cc94b8 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/TestModels/deterContinuous2vis8vec2action_v2_0.onnx.meta @@ -0,0 +1,14 @@ +fileFormatVersion: 2 +guid: e905d8f9eadcf45aa8c485594fecba6d +ScriptedImporter: + internalIDToNameTable: [] + externalObjects: {} + serializedVersion: 2 + userData: + assetBundleName: + assetBundleVariant: + script: {fileID: 11500000, guid: 683b6cb6d0a474744822c888b46772c9, type: 3} + optimizeModel: 1 + forceArbitraryBatchSize: 1 + treatErrorsAsWarnings: 0 + importMode: 1 diff --git a/com.unity.ml-agents/Tests/Editor/TestModels/deterDiscrete1obs3action_v2_0.onnx b/com.unity.ml-agents/Tests/Editor/TestModels/deterDiscrete1obs3action_v2_0.onnx new file mode 100644 index 0000000000..3aa846e204 Binary files /dev/null and b/com.unity.ml-agents/Tests/Editor/TestModels/deterDiscrete1obs3action_v2_0.onnx differ diff --git a/com.unity.ml-agents/Tests/Editor/TestModels/deterDiscrete1obs3action_v2_0.onnx.meta b/com.unity.ml-agents/Tests/Editor/TestModels/deterDiscrete1obs3action_v2_0.onnx.meta new file mode 100644 index 0000000000..a141a55235 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/TestModels/deterDiscrete1obs3action_v2_0.onnx.meta @@ -0,0 +1,14 @@ +fileFormatVersion: 2 +guid: d132cc9c934a54fdc99758427373e038 +ScriptedImporter: + internalIDToNameTable: [] + externalObjects: {} + serializedVersion: 2 + userData: + assetBundleName: + assetBundleVariant: + script: {fileID: 11500000, guid: 683b6cb6d0a474744822c888b46772c9, type: 3} + optimizeModel: 1 + forceArbitraryBatchSize: 1 + treatErrorsAsWarnings: 0 + importMode: 1 diff --git a/com.unity.ml-agents/Tests/Runtime/RuntimeAPITest.cs b/com.unity.ml-agents/Tests/Runtime/RuntimeAPITest.cs index 0c3b6312b4..743212a69b 100644 --- a/com.unity.ml-agents/Tests/Runtime/RuntimeAPITest.cs +++ b/com.unity.ml-agents/Tests/Runtime/RuntimeAPITest.cs @@ -70,6 +70,7 @@ public IEnumerator RuntimeApiTestWithEnumeratorPasses() behaviorParams.BehaviorName = "TestBehavior"; behaviorParams.TeamId = 42; behaviorParams.UseChildSensors = true; + behaviorParams.DeterministicInference = false; behaviorParams.ObservableAttributeHandling = ObservableAttributeOptions.ExamineAll; diff --git a/docs/Getting-Started.md b/docs/Getting-Started.md index c4ec5332ad..7baf461ebd 100644 --- a/docs/Getting-Started.md +++ b/docs/Getting-Started.md @@ -119,6 +119,9 @@ example. **Note** : You can modify multiple game objects in a scene by selecting them all at once using the search bar in the Scene Hierarchy. 1. Set the **Inference Device** to use for this model as `CPU`. +1. If the model is trained with Release 19 or later, you can select + `Deterministic Inference` to choose actions deterministically from the model. + Works only for inference within unity with no python process involved. 1. Click the **Play** button in the Unity Editor and you will see the platforms balance the balls using the pre-trained model. diff --git a/docs/Learning-Environment-Design-Agents.md b/docs/Learning-Environment-Design-Agents.md index 8953894957..dc7022189f 100644 --- a/docs/Learning-Environment-Design-Agents.md +++ b/docs/Learning-Environment-Design-Agents.md @@ -987,6 +987,9 @@ be called independently of the `Max Step` property. training) - `Inference Device` - Whether to use CPU or GPU to run the model during inference + - `Deterministic Inference` - Weather to set action selection to deterministic, + Only applies to inference from within unity (with no python process involved) and + Release 19 or later. - `Behavior Type` - Determines whether the Agent will do training, inference, or use its Heuristic() method: - `Default` - the Agent will train if they connect to a python trainer, diff --git a/docs/images/3dball_learning_brain.png b/docs/images/3dball_learning_brain.png index c133bf2779..68757fa6ce 100644 Binary files a/docs/images/3dball_learning_brain.png and b/docs/images/3dball_learning_brain.png differ