Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Develop var len obs feature refactor model loader checks #4912

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project/ProjectSettings/ProjectVersion.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
m_EditorVersion: 2018.4.20f1
m_EditorVersion: 2018.4.24f1
16 changes: 5 additions & 11 deletions com.unity.ml-agents/Editor/BehaviorParametersEditor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -104,15 +104,10 @@ void DisplayFailedModelChecks()

// Grab the sensor components, since we need them to determine the observation sizes.
// TODO make these methods of BehaviorParameters
SensorComponent[] sensorComponents;
if (behaviorParameters.UseChildSensors)
{
sensorComponents = behaviorParameters.GetComponentsInChildren<SensorComponent>();
}
else
{
sensorComponents = behaviorParameters.GetComponents<SensorComponent>();
}
var agent = behaviorParameters.gameObject.GetComponent<Agent>();
agent.sensors = new List<ISensor>();
agent.InitializeSensors();
var sensors = agent.sensors.ToArray();

ActuatorComponent[] actuatorComponents;
if (behaviorParameters.UseChildActuators)
Expand All @@ -127,7 +122,6 @@ void DisplayFailedModelChecks()
// Get the total size of the sensors generated by ObservableAttributes.
// If there are any errors (e.g. unsupported type, write-only properties), display them too.
int observableAttributeSensorTotalSize = 0;
var agent = behaviorParameters.GetComponent<Agent>();
if (agent != null && behaviorParameters.ObservableAttributeHandling != ObservableAttributeOptions.Ignore)
{
List<string> observableErrors = new List<string>();
Expand All @@ -146,7 +140,7 @@ void DisplayFailedModelChecks()
if (brainParameters != null)
{
var failedChecks = Inference.BarracudaModelParamLoader.CheckModel(
barracudaModel, brainParameters, sensorComponents, actuatorComponents,
barracudaModel, brainParameters, sensors, actuatorComponents,
observableAttributeSensorTotalSize, behaviorParameters.BehaviorType
);
foreach (var check in failedChecks)
Expand Down
4 changes: 4 additions & 0 deletions com.unity.ml-agents/Runtime/Agent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -922,6 +922,10 @@ public virtual void Heuristic(in ActionBuffers actionsOut)
/// </summary>
internal void InitializeSensors()
{
if (m_PolicyFactory == null)
{
m_PolicyFactory = GetComponent<BehaviorParameters>();
surfnerd marked this conversation as resolved.
Show resolved Hide resolved
}
if (m_PolicyFactory.ObservableAttributeHandling != ObservableAttributeOptions.Ignore)
{
var excludeInherited =
Expand Down
80 changes: 37 additions & 43 deletions com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,13 @@ internal class BarracudaModelParamLoader
/// <param name="brainParameters">
/// The BrainParameters that are used verify the compatibility with the InferenceEngine
/// </param>
/// <param name="sensorComponents">Attached sensor components</param>
/// <param name="sensors">Attached sensor components</param>
/// <param name="actuatorComponents">Attached actuator components</param>
/// <param name="observableAttributeTotalSize">Sum of the sizes of all ObservableAttributes.</param>
/// <param name="behaviorType">BehaviorType or the Agent to check.</param>
/// <returns>The list the error messages of the checks that failed</returns>
public static IEnumerable<string> CheckModel(Model model, BrainParameters brainParameters,
SensorComponent[] sensorComponents, ActuatorComponent[] actuatorComponents,
ISensor[] sensors, ActuatorComponent[] actuatorComponents,
int observableAttributeTotalSize = 0,
BehaviorType behaviorType = BehaviorType.Default)
{
Expand Down Expand Up @@ -82,13 +82,13 @@ public static IEnumerable<string> CheckModel(Model model, BrainParameters brainP
}

failedModelChecks.AddRange(
CheckInputTensorPresence(model, brainParameters, memorySize, sensorComponents)
CheckInputTensorPresence(model, brainParameters, memorySize, sensors)
);
failedModelChecks.AddRange(
CheckOutputTensorPresence(model, memorySize)
);
failedModelChecks.AddRange(
CheckInputTensorShape(model, brainParameters, sensorComponents, observableAttributeTotalSize)
CheckInputTensorShape(model, brainParameters, sensors, observableAttributeTotalSize)
);
failedModelChecks.AddRange(
CheckOutputTensorShape(model, brainParameters, actuatorComponents)
Expand All @@ -109,15 +109,15 @@ public static IEnumerable<string> CheckModel(Model model, BrainParameters brainP
/// <param name="memory">
/// The memory size that the model is expecting.
/// </param>
/// <param name="sensorComponents">Array of attached sensor components</param>
/// <param name="sensors">Array of attached sensor components</param>
/// <returns>
/// A IEnumerable of string corresponding to the failed input presence checks.
/// </returns>
static IEnumerable<string> CheckInputTensorPresence(
Model model,
BrainParameters brainParameters,
int memory,
SensorComponent[] sensorComponents
ISensor[] sensors
)
{
var failedModelChecks = new List<string>();
Expand All @@ -135,10 +135,9 @@ SensorComponent[] sensorComponents
// If there are not enough Visual Observation Input compared to what the
// sensors expect.
var visObsIndex = 0;
var varLenIndex = 0;
for (var sensorIndex = 0; sensorIndex < sensorComponents.Length; sensorIndex++)
for (var sensorIndex = 0; sensorIndex < sensors.Length; sensorIndex++)
{
var sensor = sensorComponents[sensorIndex];
var sensor = sensors[sensorIndex];
if (sensor.GetObservationShape().Length == 3)
{
if (!tensorsNames.Contains(
Expand All @@ -153,12 +152,11 @@ SensorComponent[] sensorComponents
if (sensor.GetObservationShape().Length == 2)
{
if (!tensorsNames.Contains(
TensorNames.ObservationPlaceholderPrefix + varLenIndex))
TensorNames.ObservationPlaceholderPrefix + sensorIndex))
{
failedModelChecks.Add(
"The model does not contain an Observation Placeholder Input " +
$"for sensor component {varLenIndex} ({sensor.GetType().Name}).");
varLenIndex++;
$"for sensor component {sensorIndex} ({sensor.GetType().Name}).");
}
}

Expand Down Expand Up @@ -231,15 +229,15 @@ static IEnumerable<string> CheckOutputTensorPresence(Model model, int memory)
/// Checks that the shape of the visual observation input placeholder is the same as the corresponding sensor.
/// </summary>
/// <param name="tensorProxy">The tensor that is expected by the model</param>
/// <param name="sensorComponent">The sensor that produces the visual observation.</param>
/// <param name="sensor">The sensor that produces the visual observation.</param>
/// <returns>
/// If the Check failed, returns a string containing information about why the
/// check failed. If the check passed, returns null.
/// </returns>
static string CheckVisualObsShape(
TensorProxy tensorProxy, SensorComponent sensorComponent)
TensorProxy tensorProxy, ISensor sensor)
{
var shape = sensorComponent.GetObservationShape();
var shape = sensor.GetObservationShape();
var heightBp = shape[0];
var widthBp = shape[1];
var pixelBp = shape[2];
Expand All @@ -259,15 +257,15 @@ static string CheckVisualObsShape(
/// Checks that the shape of the rank 2 observation input placeholder is the same as the corresponding sensor.
/// </summary>
/// <param name="tensorProxy">The tensor that is expected by the model</param>
/// <param name="sensorComponent">The sensor that produces the visual observation.</param>
/// <param name="sensor">The sensor that produces the visual observation.</param>
/// <returns>
/// If the Check failed, returns a string containing information about why the
/// check failed. If the check passed, returns null.
/// </returns>
static string CheckRankTwoObsShape(
TensorProxy tensorProxy, SensorComponent sensorComponent)
TensorProxy tensorProxy, ISensor sensor)
{
var shape = sensorComponent.GetObservationShape();
var shape = sensor.GetObservationShape();
var dim1Bp = shape[0];
var dim2Bp = shape[1];
var dim1T = tensorProxy.Channels;
Expand All @@ -291,16 +289,16 @@ static string CheckRankTwoObsShape(
/// <param name="brainParameters">
/// The BrainParameters that are used verify the compatibility with the InferenceEngine
/// </param>
/// <param name="sensorComponents">Attached sensors</param>
/// <param name="sensors">Attached sensors</param>
/// <param name="observableAttributeTotalSize">Sum of the sizes of all ObservableAttributes.</param>
/// <returns>The list the error messages of the checks that failed</returns>
static IEnumerable<string> CheckInputTensorShape(
Model model, BrainParameters brainParameters, SensorComponent[] sensorComponents,
Model model, BrainParameters brainParameters, ISensor[] sensors,
int observableAttributeTotalSize)
{
var failedModelChecks = new List<string>();
var tensorTester =
new Dictionary<string, Func<BrainParameters, TensorProxy, SensorComponent[], int, string>>()
new Dictionary<string, Func<BrainParameters, TensorProxy, ISensor[], int, string>>()
{
{TensorNames.VectorObservationPlaceholder, CheckVectorObsShape},
{TensorNames.PreviousActionPlaceholder, CheckPreviousActionShape},
Expand All @@ -316,22 +314,20 @@ static IEnumerable<string> CheckInputTensorShape(
}

var visObsIndex = 0;
var varLenIndex = 0;
for (var sensorIndex = 0; sensorIndex < sensorComponents.Length; sensorIndex++)
for (var sensorIndex = 0; sensorIndex < sensors.Length; sensorIndex++)
{
var sensorComponent = sensorComponents[sensorIndex];
if (sensorComponent.GetObservationShape().Length == 3)
var sens = sensors[sensorIndex];
if (sens.GetObservationShape().Length == 3)
{

tensorTester[TensorNames.VisualObservationPlaceholderPrefix + visObsIndex] =
(bp, tensor, scs, i) => CheckVisualObsShape(tensor, sensorComponent);
(bp, tensor, scs, i) => CheckVisualObsShape(tensor, sens);
visObsIndex++;
}
if (sensorComponent.GetObservationShape().Length == 2)
if (sens.GetObservationShape().Length == 2)
{
tensorTester[TensorNames.ObservationPlaceholderPrefix + varLenIndex] =
(bp, tensor, scs, i) => CheckRankTwoObsShape(tensor, sensorComponent);
varLenIndex++;
tensorTester[TensorNames.ObservationPlaceholderPrefix + sensorIndex] =
(bp, tensor, scs, i) => CheckRankTwoObsShape(tensor, sens);
}
}

Expand All @@ -349,7 +345,7 @@ static IEnumerable<string> CheckInputTensorShape(
else
{
var tester = tensorTester[tensor.name];
var error = tester.Invoke(brainParameters, tensor, sensorComponents, observableAttributeTotalSize);
var error = tester.Invoke(brainParameters, tensor, sensors, observableAttributeTotalSize);
if (error != null)
{
failedModelChecks.Add(error);
Expand All @@ -367,35 +363,33 @@ static IEnumerable<string> CheckInputTensorShape(
/// The BrainParameters that are used verify the compatibility with the InferenceEngine
/// </param>
/// <param name="tensorProxy">The tensor that is expected by the model</param>
/// <param name="sensorComponents">Array of attached sensor components</param>
/// <param name="sensors">Array of attached sensor components</param>
/// <param name="observableAttributeTotalSize">Sum of the sizes of all ObservableAttributes.</param>
/// <returns>
/// If the Check failed, returns a string containing information about why the
/// check failed. If the check passed, returns null.
/// </returns>
static string CheckVectorObsShape(
BrainParameters brainParameters, TensorProxy tensorProxy, SensorComponent[] sensorComponents,
BrainParameters brainParameters, TensorProxy tensorProxy, ISensor[] sensors,
int observableAttributeTotalSize)
{
var vecObsSizeBp = brainParameters.VectorObservationSize;
var numStackedVector = brainParameters.NumStackedVectorObservations;
var totalVecObsSizeT = tensorProxy.shape[tensorProxy.shape.Length - 1];

var totalVectorSensorSize = 0;
foreach (var sensorComp in sensorComponents)
foreach (var sens in sensors)
{
if (sensorComp.GetObservationShape().Length == 1)
if ((sens.GetObservationShape().Length == 1))
{
totalVectorSensorSize += sensorComp.GetObservationShape()[0];
totalVectorSensorSize += sens.GetObservationShape()[0];
}
}

totalVectorSensorSize += observableAttributeTotalSize;

if (vecObsSizeBp * numStackedVector + totalVectorSensorSize != totalVecObsSizeT)
if (totalVectorSensorSize != totalVecObsSizeT)
{
var sensorSizes = "";
foreach (var sensorComp in sensorComponents)
foreach (var sensorComp in sensors)
{
if (sensorComp.GetObservationShape().Length == 1)
{
Expand All @@ -416,7 +410,7 @@ static string CheckVectorObsShape(
$"but received: \n" +
$"Vector observations: {vecObsSizeBp} x {numStackedVector}\n" +
$"Total [Observable] attributes: {observableAttributeTotalSize}\n" +
$"SensorComponent sizes: {sensorSizes}.";
$"Sensor sizes: {sensorSizes}.";
}
return null;
}
Expand All @@ -429,13 +423,13 @@ static string CheckVectorObsShape(
/// The BrainParameters that are used verify the compatibility with the InferenceEngine
/// </param>
/// <param name="tensorProxy"> The tensor that is expected by the model</param>
/// <param name="sensorComponents">Array of attached sensor components (unused).</param>
/// <param name="sensors">Array of attached sensor components (unused).</param>
/// <param name="observableAttributeTotalSize">Sum of the sizes of all ObservableAttributes (unused).</param>
/// <returns>If the Check failed, returns a string containing information about why the
/// check failed. If the check passed, returns null.</returns>
static string CheckPreviousActionShape(
BrainParameters brainParameters, TensorProxy tensorProxy,
SensorComponent[] sensorComponents, int observableAttributeTotalSize)
ISensor[] sensors, int observableAttributeTotalSize)
{
var numberActionsBp = brainParameters.ActionSpec.NumDiscreteActions;
var numberActionsT = tensorProxy.shape[tensorProxy.shape.Length - 1];
Expand Down
6 changes: 3 additions & 3 deletions com.unity.ml-agents/Runtime/Inference/ModelRunner.cs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ internal class ModelRunner

SensorShapeValidator m_SensorShapeValidator = new SensorShapeValidator();

bool m_VisualObservationsInitialized;
bool m_ObservationsInitialized;

/// <summary>
/// Initializes the Brain with the Model that it will use when selecting actions for
Expand Down Expand Up @@ -161,13 +161,13 @@ public void DecideBatch()
{
return;
}
if (!m_VisualObservationsInitialized)
if (!m_ObservationsInitialized)
{
// Just grab the first agent in the collection (any will suffice, really).
// We check for an empty Collection above, so this will always return successfully.
var firstInfo = m_Infos[0];
m_TensorGenerator.InitializeObservations(firstInfo.sensors, m_TensorAllocator);
m_VisualObservationsInitialized = true;
m_ObservationsInitialized = true;
}

Profiler.BeginSample("ModelRunner.DecideAction");
Expand Down
19 changes: 15 additions & 4 deletions com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,19 @@

namespace Unity.MLAgents.Sensors
{
/// <summary>
/// A Sensor that allows to observe a variable number of entities.
/// </summary>
public class BufferSensor : ISensor, IDimensionPropertiesSensor, IBuiltInSensor
{
private int m_MaxNumObs;
private int m_ObsSize;
float[] m_ObservationBuffer;
int m_CurrentNumObservables;
static DimensionProperty[] m_DimensionProperties = new DimensionProperty[]{
DimensionProperty.VariableSize,
DimensionProperty.None
};
public BufferSensor(int maxNumberObs, int obsSize)
{
m_MaxNumObs = maxNumberObs;
Expand All @@ -25,10 +32,7 @@ public int[] GetObservationShape()
/// <inheritdoc/>
public DimensionProperty[] GetDimensionProperties()
{
return new DimensionProperty[]{
DimensionProperty.VariableSize,
DimensionProperty.None
};
return m_DimensionProperties;
}

/// <summary>
Expand All @@ -40,6 +44,13 @@ public DimensionProperty[] GetDimensionProperties()
/// <param name="obs"> The float array observation</param>
public void AppendObservation(float[] obs)
{
if (obs.Length != m_ObsSize)
{
throw new UnityAgentsException(
"The BufferSensor was expecting an observation of size " +
$"{m_ObsSize} but received {obs.Length} observations instead."
);
}
if (m_CurrentNumObservables >= m_MaxNumObs)
{
return;
Expand Down
12 changes: 11 additions & 1 deletion com.unity.ml-agents/Runtime/Sensors/BufferSensorComponent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,23 @@ namespace Unity.MLAgents.Sensors
{

/// <summary>
/// A component for BufferSensor.
/// A SensorComponent that creates a <see cref="BufferSensor"/>.
/// </summary>
[AddComponentMenu("ML Agents/Buffer Sensor", (int)MenuGroup.Sensors)]
public class BufferSensorComponent : SensorComponent
{
/// <summary>
/// This is how many floats each entities will be represented with. This number
/// is fixed and all entities must have the same representation.
/// </summary>
public int ObservableSize;

/// <summary>
/// This is the maximum number of entities the `BufferSensor` will be able to
/// collect.
/// </summary>
public int MaxNumObservables;

private BufferSensor m_Sensor;

/// <inheritdoc/>
Expand Down
Loading