Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

V2 staging new model version #5080

Merged
merged 24 commits into from
Mar 15, 2021
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
2130c39
Make modelCheck have flavors of error messages
vincentpierre Mar 10, 2021
a03e8a7
ONNX exporter v3
vincentpierre Mar 11, 2021
ad7fc3a
Using a better CheckType and a switch statement
vincentpierre Mar 11, 2021
59cd5f9
Removing unused message
vincentpierre Mar 11, 2021
4dec29b
More tests
vincentpierre Mar 11, 2021
abfb268
Use an enum for valid versions and use GetVersion on model directly
vincentpierre Mar 11, 2021
93e18b5
Maybe the model export version a static constant in Python
vincentpierre Mar 11, 2021
4afde10
Use static constructor for FailedCheck
vincentpierre Mar 11, 2021
393f146
Use static constructor for FailedCheck
vincentpierre Mar 11, 2021
d305b5c
Modifying the docstrings
vincentpierre Mar 11, 2021
067933e
renaming LegacyDiscreteActionOutputApplier
vincentpierre Mar 11, 2021
24af8ab
removing testing code
vincentpierre Mar 11, 2021
8a4cadd
better warning message
vincentpierre Mar 11, 2021
bb9ddd6
Nest the CheckTypeEnum into the FailedCheck class
vincentpierre Mar 11, 2021
dc4865b
Update com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoade…
vincentpierre Mar 11, 2021
405b18a
Adding a line explaining that legacy tensor checks are for versions 1…
vincentpierre Mar 11, 2021
fb459a7
Modifying the changelog
vincentpierre Mar 11, 2021
b79d44e
Exporting all the branches size instead of omly the sum (#5092)
vincentpierre Mar 12, 2021
fa22688
addressing comments
vincentpierre Mar 12, 2021
f6317fd
Update com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoade…
vincentpierre Mar 12, 2021
1b48e4b
readding tests
vincentpierre Mar 12, 2021
7d1d336
Adding a comment around the new DiscreteOutputSize method
vincentpierre Mar 15, 2021
a124c96
Clearer warning : Model contains unexpected input > Model requires un…
vincentpierre Mar 15, 2021
61412ad
Fixing a bug in the case where the discrete action tensor does not exist
vincentpierre Mar 15, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions com.unity.ml-agents/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ details. (#5060)

### Minor Changes
#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#)
- The `.onnx` models input names have changed. All input placeholders will now use the prefix `obs_` removing the distinction between visual and vector observations. Models created with this version will not be usable with previous versions of the package (#5080)
- The `.onnx` models discrete action output now contains the discrete actions values and not the logits. Models created with this version will not be usable with previous versions of the package (#5080)
#### ml-agents / ml-agents-envs / gym-unity (Python)

### Bug Fixes
Expand Down
16 changes: 15 additions & 1 deletion com.unity.ml-agents/Editor/BehaviorParametersEditor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using Unity.MLAgents.Policies;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Sensors.Reflection;
using CheckTypeEnum = Unity.MLAgents.Inference.BarracudaModelParamLoader.FailedCheck.CheckTypeEnum;

namespace Unity.MLAgents.Editor
{
Expand Down Expand Up @@ -147,7 +148,20 @@ void DisplayFailedModelChecks()
{
if (check != null)
{
EditorGUILayout.HelpBox(check, MessageType.Warning);
switch (check.CheckType)
{
case CheckTypeEnum.Info:
EditorGUILayout.HelpBox(check.Message, MessageType.Info);
break;
case CheckTypeEnum.Warning:
EditorGUILayout.HelpBox(check.Message, MessageType.Warning);
break;
case CheckTypeEnum.Error:
EditorGUILayout.HelpBox(check.Message, MessageType.Error);
break;
default:
break;
}
}
}
}
Expand Down
44 changes: 42 additions & 2 deletions com.unity.ml-agents/Runtime/Inference/ApplierImpl.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,51 @@ public void Apply(TensorProxy tensorProxy, IList<int> actionIds, Dictionary<int,
}
}

/// <summary>
/// The Applier for the Discrete Action output tensor.
/// </summary>
internal class DiscreteActionOutputApplier : TensorApplier.IApplier
{
readonly ActionSpec m_ActionSpec;


public DiscreteActionOutputApplier(ActionSpec actionSpec, int seed, ITensorAllocator allocator)
{
m_ActionSpec = actionSpec;
}

public void Apply(TensorProxy tensorProxy, IList<int> actionIds, Dictionary<int, ActionBuffers> lastActions)
{
var agentIndex = 0;
var actionSize = tensorProxy.shape[tensorProxy.shape.Length - 1];
for (var i = 0; i < actionIds.Count; i++)
{
var agentId = actionIds[i];
if (lastActions.ContainsKey(agentId))
{
var actionBuffer = lastActions[agentId];
if (actionBuffer.IsEmpty())
{
actionBuffer = new ActionBuffers(m_ActionSpec);
lastActions[agentId] = actionBuffer;
}
var discreteBuffer = actionBuffer.DiscreteActions;
for (var j = 0; j < actionSize; j++)
{
discreteBuffer[j] = (int)tensorProxy.data[agentIndex, j];
}
}
agentIndex++;
}
}
}


/// <summary>
/// The Applier for the Discrete Action output tensor. Uses multinomial to sample discrete
/// actions from the logits contained in the tensor.
/// </summary>
internal class DiscreteActionOutputApplier : TensorApplier.IApplier
internal class LegacyDiscreteActionOutputApplier : TensorApplier.IApplier
{
readonly int[] m_ActionSize;
readonly Multinomial m_Multinomial;
Expand All @@ -59,7 +99,7 @@ internal class DiscreteActionOutputApplier : TensorApplier.IApplier
readonly float[] m_CdfBuffer;


public DiscreteActionOutputApplier(ActionSpec actionSpec, int seed, ITensorAllocator allocator)
public LegacyDiscreteActionOutputApplier(ActionSpec actionSpec, int seed, ITensorAllocator allocator)
{
m_ActionSize = actionSpec.BranchSizes;
m_Multinomial = new Multinomial(seed);
Expand Down
62 changes: 51 additions & 11 deletions com.unity.ml-agents/Runtime/Inference/BarracudaModelExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System.Collections.Generic;
using System.Linq;
using Unity.Barracuda;
using FailedCheck = Unity.MLAgents.Inference.BarracudaModelParamLoader.FailedCheck;

namespace Unity.MLAgents.Inference
{
Expand Down Expand Up @@ -38,6 +39,18 @@ public static string[] GetInputNames(this Model model)
return names.ToArray();
}

/// <summary>
/// Get the version of the model.
/// </summary>
/// <param name="model">
/// The Barracuda engine model for loading static parameters.
/// </param>
/// <returns>The api version of the model</returns>
public static int GetVersion(this Model model)
{
return (int)model.GetTensorByName(TensorNames.VersionNumber)[0];
}

/// <summary>
/// Generates the Tensor inputs that are expected to be present in the Model.
/// </summary>
Expand Down Expand Up @@ -226,7 +239,7 @@ public static bool HasDiscreteOutputs(this Model model)
else
{
return model.outputs.Contains(TensorNames.DiscreteActionOutput) &&
(int)model.GetTensorByName(TensorNames.DiscreteActionOutputShape)[0] > 0;
(int)model.DiscreteOutputSize() > 0;
}
}

Expand All @@ -249,7 +262,19 @@ public static int DiscreteOutputSize(this Model model)
else
{
var discreteOutputShape = model.GetTensorByName(TensorNames.DiscreteActionOutputShape);
return discreteOutputShape == null ? 0 : (int)discreteOutputShape[0];
if (discreteOutputShape == null)
{
return 0;
}
else
{
int result = 0;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a behavior change for legacy models, right? Is it correct?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is not. (Unless I messed it up)
Because discreteOutputShape is currently a tensor of shape 1 ([sum_of_branches]) and will be changed to be ([branh0_size, branch1_size, ...]). So doing the sum does the same thing wether it is legacy or new.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, thanks. Could you leave that as a comment here for the future readers?

for (int i = 0; i < discreteOutputShape.length; i++)
{
result += (int)discreteOutputShape[i];
}
return result;
}
}
}

Expand Down Expand Up @@ -298,21 +323,25 @@ public static bool SupportsContinuousAndDiscrete(this Model model)
/// <param name="failedModelChecks">Output list of failure messages</param>
///
/// <returns>True if the model contains all the expected tensors.</returns>
public static bool CheckExpectedTensors(this Model model, List<string> failedModelChecks)
public static bool CheckExpectedTensors(this Model model, List<FailedCheck> failedModelChecks)
chriselion marked this conversation as resolved.
Show resolved Hide resolved
{
// Check the presence of model version
var modelApiVersionTensor = model.GetTensorByName(TensorNames.VersionNumber);
if (modelApiVersionTensor == null)
{
failedModelChecks.Add($"Required constant \"{TensorNames.VersionNumber}\" was not found in the model file.");
failedModelChecks.Add(
FailedCheck.Warning($"Required constant \"{TensorNames.VersionNumber}\" was not found in the model file.")
);
return false;
}

// Check the presence of memory size
var memorySizeTensor = model.GetTensorByName(TensorNames.MemorySize);
if (memorySizeTensor == null)
{
failedModelChecks.Add($"Required constant \"{TensorNames.MemorySize}\" was not found in the model file.");
failedModelChecks.Add(
FailedCheck.Warning($"Required constant \"{TensorNames.MemorySize}\" was not found in the model file.")
);
return false;
}

Expand All @@ -321,7 +350,9 @@ public static bool CheckExpectedTensors(this Model model, List<string> failedMod
!model.outputs.Contains(TensorNames.ContinuousActionOutput) &&
!model.outputs.Contains(TensorNames.DiscreteActionOutput))
{
failedModelChecks.Add("The model does not contain any Action Output Node.");
failedModelChecks.Add(
FailedCheck.Warning("The model does not contain any Action Output Node.")
);
return false;
}

Expand All @@ -330,13 +361,18 @@ public static bool CheckExpectedTensors(this Model model, List<string> failedMod
{
if (model.GetTensorByName(TensorNames.ActionOutputShapeDeprecated) == null)
{
failedModelChecks.Add("The model does not contain any Action Output Shape Node.");
failedModelChecks.Add(
FailedCheck.Warning("The model does not contain any Action Output Shape Node.")
);
return false;
}
if (model.GetTensorByName(TensorNames.IsContinuousControlDeprecated) == null)
{
failedModelChecks.Add($"Required constant \"{TensorNames.IsContinuousControlDeprecated}\" was not found in the model file. " +
"This is only required for model that uses a deprecated model format.");
failedModelChecks.Add(
FailedCheck.Warning($"Required constant \"{TensorNames.IsContinuousControlDeprecated}\" was " +
"not found in the model file. " +
"This is only required for model that uses a deprecated model format.")
);
return false;
}
}
Expand All @@ -345,13 +381,17 @@ public static bool CheckExpectedTensors(this Model model, List<string> failedMod
if (model.outputs.Contains(TensorNames.ContinuousActionOutput) &&
model.GetTensorByName(TensorNames.ContinuousActionOutputShape) == null)
{
failedModelChecks.Add("The model uses continuous action but does not contain Continuous Action Output Shape Node.");
failedModelChecks.Add(
FailedCheck.Warning("The model uses continuous action but does not contain Continuous Action Output Shape Node.")
);
return false;
}
if (model.outputs.Contains(TensorNames.DiscreteActionOutput) &&
model.GetTensorByName(TensorNames.DiscreteActionOutputShape) == null)
{
failedModelChecks.Add("The model uses discrete action but does not contain Discrete Action Output Shape Node.");
failedModelChecks.Add(
FailedCheck.Warning("The model uses discrete action but does not contain Discrete Action Output Shape Node.")
);
return false;
}
}
Expand Down
Loading