Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLA-1634] Add ObservationSpec and update ISensor interfaces #5127

Merged
merged 23 commits into from
Mar 18, 2021
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DevProject/Packages/manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"com.unity.package-manager-doctools": "1.7.0-preview",
"com.unity.package-validation-suite": "0.19.0-preview",
"com.unity.purchasing": "2.2.1",
"com.unity.test-framework": "1.1.20",
"com.unity.test-framework": "1.1.22",
"com.unity.test-framework.performance": "2.2.0-preview",
"com.unity.testtools.codecoverage": "1.0.0-pre.3",
"com.unity.textmeshpro": "2.0.1",
Expand Down
2 changes: 1 addition & 1 deletion DevProject/Packages/packages-lock.json
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@
"url": "https://artifactory.prd.cds.internal.unity3d.com/artifactory/api/npm/upm-candidates"
},
"com.unity.test-framework": {
"version": "1.1.20",
"version": "1.1.22",
"depth": 0,
"source": "registry",
"dependencies": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,22 @@
"m_Name": "Settings",
"m_Path": "ProjectSettings/Packages/com.unity.testtools.codecoverage/Settings.json",
"m_Dictionary": {
"m_DictionaryValues": []
"m_DictionaryValues": [
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure what to do with this. It got added in 7718dfe#diff-41dd07c8a0392e1f6c2a7ba1277f3426cdc92aa8258e563912c39db66800ade0 but wasn't computing coverage on the assemblies we care about.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, I think it's just for local runs of code coverage

{
"type": "System.String, mscorlib, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089",
"key": "Path",
"value": "{\"m_Value\":\"{ProjectPath}\"}"
},
{
"type": "System.String, mscorlib, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089",
"key": "HistoryPath",
"value": "{\"m_Value\":\"{ProjectPath}\"}"
},
{
"type": "System.String, mscorlib, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089",
"key": "IncludeAssemblies",
"value": "{\"m_Value\":\"Assembly-CSharp,Runtime,Unity.ML-Agents,Unity.ML-Agents.Extensions\"}"
}
]
}
}
4 changes: 2 additions & 2 deletions DevProject/ProjectSettings/ProjectVersion.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
m_EditorVersion: 2019.4.19f1
m_EditorVersionWithRevision: 2019.4.19f1 (ca5b14067cec)
m_EditorVersion: 2019.4.20f1
m_EditorVersionWithRevision: 2019.4.20f1 (6dd1c08eedfa)
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ public override void WriteObservation(float[] output)
}

/// <inheritdoc/>
public override int[] GetObservationShape()
public override ObservationSpec GetObservationSpec()
{
return new[] { BasicController.k_Extents };
return ObservationSpec.Vector(BasicController.k_Extents);
}

/// <inheritdoc/>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ public abstract class SensorBase : ISensor
public abstract void WriteObservation(float[] output);

/// <inheritdoc/>
public abstract int[] GetObservationShape();
public abstract ObservationSpec GetObservationSpec();

/// <inheritdoc/>
public abstract string GetName();
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ public class TestTextureSensor : ISensor
{
Texture2D m_Texture;
string m_Name;
int[] m_Shape;
private ObservationSpec m_ObservationSpec;
SensorCompressionType m_CompressionType;

/// <summary>
Expand All @@ -25,7 +25,7 @@ public TestTextureSensor(
var width = texture.width;
var height = texture.height;
m_Name = name;
m_Shape = new[] { height, width, 3 };
m_ObservationSpec = ObservationSpec.Visual(height, width, 3);
m_CompressionType = compressionType;
}

Expand All @@ -36,9 +36,9 @@ public string GetName()
}

/// <inheritdoc/>
public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
{
return m_Shape;
return m_ObservationSpec;
}

/// <inheritdoc/>
Expand Down
12 changes: 6 additions & 6 deletions com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public class Match3Sensor : ISparseChannelSensor, IBuiltInSensor
{
private Match3ObservationType m_ObservationType;
private AbstractBoard m_Board;
private int[] m_Shape;
private ObservationSpec m_ObservationSpec;
private int[] m_SparseChannelMapping;
private string m_Name;

Expand Down Expand Up @@ -70,9 +70,9 @@ public Match3Sensor(AbstractBoard board, Match3ObservationType obsType, string n
m_NumSpecialTypes = board.NumSpecialTypes;

m_ObservationType = obsType;
m_Shape = obsType == Match3ObservationType.Vector ?
new[] { m_Rows * m_Columns * (m_NumCellTypes + SpecialTypeSize) } :
new[] { m_Rows, m_Columns, m_NumCellTypes + SpecialTypeSize };
m_ObservationSpec = obsType == Match3ObservationType.Vector
? ObservationSpec.Vector(m_Rows * m_Columns * (m_NumCellTypes + SpecialTypeSize))
: ObservationSpec.Visual(m_Rows, m_Columns, m_NumCellTypes + SpecialTypeSize);

// See comment in GetCompressedObservation()
var cellTypePaddedSize = 3 * ((m_NumCellTypes + 2) / 3);
Expand All @@ -96,9 +96,9 @@ public Match3Sensor(AbstractBoard board, Match3ObservationType obsType, string n
}

/// <inheritdoc/>
public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
{
return m_Shape;
return m_ObservationSpec;
}

/// <inheritdoc/>
Expand Down
30 changes: 17 additions & 13 deletions com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -215,9 +215,9 @@ public enum GridDepthType { Channel, ChannelHot };
protected bool Initialized = false;

/// <summary>
/// Array holding the dimensions of the resulting tensor
/// Cached ObservationSpec
/// </summary>
private int[] m_Shape;
private ObservationSpec m_ObservationSpec;

//
// Debug Parameters
Expand Down Expand Up @@ -423,7 +423,7 @@ public virtual void Start()
// Default root reference to current game object
if (rootReference == null)
rootReference = gameObject;
m_Shape = new[] { GridNumSideX, GridNumSideZ, ObservationPerCell };
m_ObservationSpec = ObservationSpec.Visual(GridNumSideX, GridNumSideZ, ObservationPerCell);

compressedImgs = new List<byte[]>();
byteSizesBytesList = new List<byte[]>();
Expand Down Expand Up @@ -475,14 +475,6 @@ public void ClearPerceptionBuffer()
}
}

/// <summary>Gets the shape of the grid observation</summary>
/// <returns>integer array shape of the grid observation</returns>
public int[] GetFloatObservationShape()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this was unused.

{
m_Shape = new[] { GridNumSideX, GridNumSideZ, ObservationPerCell };
return m_Shape;
}

/// <inheritdoc/>
public string GetName()
{
Expand Down Expand Up @@ -914,10 +906,22 @@ void ISensor.Update()

/// <summary>Gets the observation shape</summary>
/// <returns>int[] of the observation shape</returns>
public ObservationSpec GetObservationSpec()
{
// Lazy update
var shape = m_ObservationSpec.Shape;
if (shape[0] != GridNumSideX || shape[1] != GridNumSideZ || shape[2] != ObservationPerCell)
{
m_ObservationSpec = ObservationSpec.Visual(GridNumSideX, GridNumSideZ, ObservationPerCell);
}
return m_ObservationSpec;
}

/// <inheritdoc/>
public override int[] GetObservationShape()
{
m_Shape = new[] { GridNumSideX, GridNumSideZ, ObservationPerCell };
return m_Shape;
var shape = m_ObservationSpec.Shape;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the SensorComponent override, not sure if we should change that to return a spec too.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would lean toward yes

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, I'm going to punt that to a separate PR.

return new int[] { shape[0], shape[1], shape[2] };
}

/// <inheritdoc/>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ namespace Unity.MLAgents.Extensions.Sensors
/// </summary>
public class PhysicsBodySensor : ISensor, IBuiltInSensor
{
int[] m_Shape;
ObservationSpec m_ObservationSpec;
string m_SensorName;

PoseExtractor m_PoseExtractor;
Expand Down Expand Up @@ -44,7 +44,7 @@ string sensorName
}

var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(settings);
m_Shape = new[] { numTransformObservations + numJointExtractorObservations };
m_ObservationSpec = ObservationSpec.Vector(numTransformObservations + numJointExtractorObservations);
}

#if UNITY_2020_1_OR_NEWER
Expand All @@ -65,14 +65,14 @@ public PhysicsBodySensor(ArticulationBody rootBody, PhysicsSensorSettings settin
}

var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(settings);
m_Shape = new[] { numTransformObservations + numJointExtractorObservations };
m_ObservationSpec = ObservationSpec.Vector(numTransformObservations + numJointExtractorObservations);
}
#endif

/// <inheritdoc/>
public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
{
return m_Shape;
return m_ObservationSpec;
}

/// <inheritdoc/>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public void TestVectorObservations()

var expectedShape = new[] { 3 * 3 * 2 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);

var expectedObs = new float[]
{
Expand Down Expand Up @@ -65,7 +65,7 @@ public void TestVectorObservationsSpecial()

var expectedShape = new[] { 3 * 3 * (2 + 3) };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);

var expectedObs = new float[]
{
Expand Down Expand Up @@ -94,7 +94,7 @@ public void TestVisualObservations()

var expectedShape = new[] { 3, 3, 2 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);

Assert.AreEqual(SensorCompressionType.None, sensor.GetCompressionType());

Expand Down Expand Up @@ -138,7 +138,7 @@ public void TestVisualObservationsSpecial()

var expectedShape = new[] { 3, 3, 2 + 3 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);

Assert.AreEqual(SensorCompressionType.None, sensor.GetCompressionType());

Expand Down Expand Up @@ -176,7 +176,7 @@ public void TestCompressedVisualObservations()

var expectedShape = new[] { 3, 3, 2 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);

Assert.AreEqual(SensorCompressionType.PNG, sensor.GetCompressionType());

Expand Down Expand Up @@ -216,7 +216,7 @@ public void TestCompressedVisualObservationsSpecial()

var expectedShape = new[] { 3, 3, 2 + 3 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);

Assert.AreEqual(SensorCompressionType.PNG, sensor.GetCompressionType());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public void OneChannelDepthOne()
gridSensor.Start();

int[] expectedShape = { 10, 10, 1 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape());
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());

}

Expand All @@ -52,7 +52,7 @@ public void OneChannelDepthTwo()
gridSensor.Start();

int[] expectedShape = { 10, 10, 2 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape());
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());

}

Expand All @@ -67,7 +67,7 @@ public void TwoChannelsDepthTwoOne()
gridSensor.Start();

int[] expectedShape = { 10, 10, 3 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape());
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());

}

Expand All @@ -82,7 +82,7 @@ public void TwoChannelsDepthThreeThree()
gridSensor.Start();

int[] expectedShape = { 10, 10, 6 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape());
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());

}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public void OneChannel()
gridSensor.Start();

int[] expectedShape = { 10, 10, 1 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape());
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
}

[Test]
Expand All @@ -49,7 +49,7 @@ public void TwoChannel()
gridSensor.Start();

int[] expectedShape = { 10, 10, 2 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape());
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
}

[Test]
Expand All @@ -63,7 +63,7 @@ public void SevenChannel()
gridSensor.Start();

int[] expectedShape = { 10, 10, 7 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape());
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
}
}
}
9 changes: 6 additions & 3 deletions com.unity.ml-agents/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,16 @@ and this project adheres to
## [Unreleased]
### Major Changes
#### com.unity.ml-agents (C#)
- Several breaking interface changes were made. See the
[Migration Guide](https://github.com/Unity-Technologies/ml-agents/blob/release_14_docs/docs/Migrating.md) for more
details.
- Some methods previously marked as `Obsolete` have been removed. If you were using these methods, you need to replace them with their supported counterpart.
- The interface for disabling discrete actions in `IDiscreteActionMask` has changed.
`WriteMask(int branch, IEnumerable<int> actionIndices)` was replaced with
`SetActionEnabled(int branch, int actionIndex, bool isEnabled)`. See the
`SetActionEnabled(int branch, int actionIndex, bool isEnabled)`. (#5060)
- IActuator now implements IHeuristicProvider. (#5110)
[Migration Guide](https://github.com/Unity-Technologies/ml-agents/blob/release_14_docs/docs/Migrating.md) for more
details. (#5060)
- `ISensor.GetObservationShape()` was removed, and `GetObservationSpec()` was added. (#5127)

#### ml-agents / ml-agents-envs / gym-unity (Python)

### Minor Changes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ public interface IDiscreteActionMask
/// <summary>
/// Set whether or not the action index for the given branch is allowed.
/// </summary>
/// <remarks>
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not from here, but was broken in staging.

/// By default, all discrete actions are allowed.
/// If isEnabled is false, the agent will not be able to perform the actions passed as argument
/// at the next decision for the specified action branch. The actionIndex correspond
Expand Down
Loading