Skip to content

Commit

Permalink
[MLA-1634] Add ObservationSpec and update ISensor interfaces (#5127)
Browse files Browse the repository at this point in the history
  • Loading branch information
Chris Elion authored and surfnerd committed Mar 18, 2021
1 parent 80a69a9 commit 540595a
Show file tree
Hide file tree
Showing 54 changed files with 1,009 additions and 374 deletions.
2 changes: 1 addition & 1 deletion DevProject/Packages/manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"com.unity.package-manager-doctools": "1.7.0-preview",
"com.unity.package-validation-suite": "0.19.0-preview",
"com.unity.purchasing": "2.2.1",
"com.unity.test-framework": "1.1.20",
"com.unity.test-framework": "1.1.22",
"com.unity.test-framework.performance": "2.2.0-preview",
"com.unity.testtools.codecoverage": "1.0.0-pre.3",
"com.unity.textmeshpro": "2.0.1",
Expand Down
2 changes: 1 addition & 1 deletion DevProject/Packages/packages-lock.json
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@
"url": "https://artifactory.prd.cds.internal.unity3d.com/artifactory/api/npm/upm-candidates"
},
"com.unity.test-framework": {
"version": "1.1.20",
"version": "1.1.22",
"depth": 0,
"source": "registry",
"dependencies": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,22 @@
"m_Name": "Settings",
"m_Path": "ProjectSettings/Packages/com.unity.testtools.codecoverage/Settings.json",
"m_Dictionary": {
"m_DictionaryValues": []
"m_DictionaryValues": [
{
"type": "System.String, mscorlib, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089",
"key": "Path",
"value": "{\"m_Value\":\"{ProjectPath}\"}"
},
{
"type": "System.String, mscorlib, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089",
"key": "HistoryPath",
"value": "{\"m_Value\":\"{ProjectPath}\"}"
},
{
"type": "System.String, mscorlib, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089",
"key": "IncludeAssemblies",
"value": "{\"m_Value\":\"Assembly-CSharp,Runtime,Unity.ML-Agents,Unity.ML-Agents.Extensions\"}"
}
]
}
}
4 changes: 2 additions & 2 deletions DevProject/ProjectSettings/ProjectVersion.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
m_EditorVersion: 2019.4.19f1
m_EditorVersionWithRevision: 2019.4.19f1 (ca5b14067cec)
m_EditorVersion: 2019.4.20f1
m_EditorVersionWithRevision: 2019.4.20f1 (6dd1c08eedfa)
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ public override void WriteObservation(float[] output)
}

/// <inheritdoc/>
public override int[] GetObservationShape()
public override ObservationSpec GetObservationSpec()
{
return new[] { BasicController.k_Extents };
return ObservationSpec.Vector(BasicController.k_Extents);
}

/// <inheritdoc/>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ public abstract class SensorBase : ISensor
public abstract void WriteObservation(float[] output);

/// <inheritdoc/>
public abstract int[] GetObservationShape();
public abstract ObservationSpec GetObservationSpec();

/// <inheritdoc/>
public abstract string GetName();
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ public class TestTextureSensor : ISensor
{
Texture2D m_Texture;
string m_Name;
int[] m_Shape;
private ObservationSpec m_ObservationSpec;
SensorCompressionType m_CompressionType;

/// <summary>
Expand All @@ -25,7 +25,7 @@ public TestTextureSensor(
var width = texture.width;
var height = texture.height;
m_Name = name;
m_Shape = new[] { height, width, 3 };
m_ObservationSpec = ObservationSpec.Visual(height, width, 3);
m_CompressionType = compressionType;
}

Expand All @@ -36,9 +36,9 @@ public string GetName()
}

/// <inheritdoc/>
public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
{
return m_Shape;
return m_ObservationSpec;
}

/// <inheritdoc/>
Expand Down
12 changes: 6 additions & 6 deletions com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public class Match3Sensor : ISparseChannelSensor, IBuiltInSensor
{
private Match3ObservationType m_ObservationType;
private AbstractBoard m_Board;
private int[] m_Shape;
private ObservationSpec m_ObservationSpec;
private int[] m_SparseChannelMapping;
private string m_Name;

Expand Down Expand Up @@ -70,9 +70,9 @@ public Match3Sensor(AbstractBoard board, Match3ObservationType obsType, string n
m_NumSpecialTypes = board.NumSpecialTypes;

m_ObservationType = obsType;
m_Shape = obsType == Match3ObservationType.Vector ?
new[] { m_Rows * m_Columns * (m_NumCellTypes + SpecialTypeSize) } :
new[] { m_Rows, m_Columns, m_NumCellTypes + SpecialTypeSize };
m_ObservationSpec = obsType == Match3ObservationType.Vector
? ObservationSpec.Vector(m_Rows * m_Columns * (m_NumCellTypes + SpecialTypeSize))
: ObservationSpec.Visual(m_Rows, m_Columns, m_NumCellTypes + SpecialTypeSize);

// See comment in GetCompressedObservation()
var cellTypePaddedSize = 3 * ((m_NumCellTypes + 2) / 3);
Expand All @@ -96,9 +96,9 @@ public Match3Sensor(AbstractBoard board, Match3ObservationType obsType, string n
}

/// <inheritdoc/>
public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
{
return m_Shape;
return m_ObservationSpec;
}

/// <inheritdoc/>
Expand Down
30 changes: 17 additions & 13 deletions com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -215,9 +215,9 @@ public enum GridDepthType { Channel, ChannelHot };
protected bool Initialized = false;

/// <summary>
/// Array holding the dimensions of the resulting tensor
/// Cached ObservationSpec
/// </summary>
private int[] m_Shape;
private ObservationSpec m_ObservationSpec;

//
// Debug Parameters
Expand Down Expand Up @@ -423,7 +423,7 @@ public virtual void Start()
// Default root reference to current game object
if (rootReference == null)
rootReference = gameObject;
m_Shape = new[] { GridNumSideX, GridNumSideZ, ObservationPerCell };
m_ObservationSpec = ObservationSpec.Visual(GridNumSideX, GridNumSideZ, ObservationPerCell);

compressedImgs = new List<byte[]>();
byteSizesBytesList = new List<byte[]>();
Expand Down Expand Up @@ -475,14 +475,6 @@ public void ClearPerceptionBuffer()
}
}

/// <summary>Gets the shape of the grid observation</summary>
/// <returns>integer array shape of the grid observation</returns>
public int[] GetFloatObservationShape()
{
m_Shape = new[] { GridNumSideX, GridNumSideZ, ObservationPerCell };
return m_Shape;
}

/// <inheritdoc/>
public string GetName()
{
Expand Down Expand Up @@ -914,10 +906,22 @@ void ISensor.Update()

/// <summary>Gets the observation shape</summary>
/// <returns>int[] of the observation shape</returns>
public ObservationSpec GetObservationSpec()
{
// Lazy update
var shape = m_ObservationSpec.Shape;
if (shape[0] != GridNumSideX || shape[1] != GridNumSideZ || shape[2] != ObservationPerCell)
{
m_ObservationSpec = ObservationSpec.Visual(GridNumSideX, GridNumSideZ, ObservationPerCell);
}
return m_ObservationSpec;
}

/// <inheritdoc/>
public override int[] GetObservationShape()
{
m_Shape = new[] { GridNumSideX, GridNumSideZ, ObservationPerCell };
return m_Shape;
var shape = m_ObservationSpec.Shape;
return new int[] { shape[0], shape[1], shape[2] };
}

/// <inheritdoc/>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ namespace Unity.MLAgents.Extensions.Sensors
/// </summary>
public class PhysicsBodySensor : ISensor, IBuiltInSensor
{
int[] m_Shape;
ObservationSpec m_ObservationSpec;
string m_SensorName;

PoseExtractor m_PoseExtractor;
Expand Down Expand Up @@ -44,7 +44,7 @@ string sensorName
}

var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(settings);
m_Shape = new[] { numTransformObservations + numJointExtractorObservations };
m_ObservationSpec = ObservationSpec.Vector(numTransformObservations + numJointExtractorObservations);
}

#if UNITY_2020_1_OR_NEWER
Expand All @@ -65,14 +65,14 @@ public PhysicsBodySensor(ArticulationBody rootBody, PhysicsSensorSettings settin
}

var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(settings);
m_Shape = new[] { numTransformObservations + numJointExtractorObservations };
m_ObservationSpec = ObservationSpec.Vector(numTransformObservations + numJointExtractorObservations);
}
#endif

/// <inheritdoc/>
public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
{
return m_Shape;
return m_ObservationSpec;
}

/// <inheritdoc/>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public void TestVectorObservations()

var expectedShape = new[] { 3 * 3 * 2 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);

var expectedObs = new float[]
{
Expand Down Expand Up @@ -65,7 +65,7 @@ public void TestVectorObservationsSpecial()

var expectedShape = new[] { 3 * 3 * (2 + 3) };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);

var expectedObs = new float[]
{
Expand Down Expand Up @@ -94,7 +94,7 @@ public void TestVisualObservations()

var expectedShape = new[] { 3, 3, 2 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);

Assert.AreEqual(SensorCompressionType.None, sensor.GetCompressionType());

Expand Down Expand Up @@ -138,7 +138,7 @@ public void TestVisualObservationsSpecial()

var expectedShape = new[] { 3, 3, 2 + 3 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);

Assert.AreEqual(SensorCompressionType.None, sensor.GetCompressionType());

Expand Down Expand Up @@ -176,7 +176,7 @@ public void TestCompressedVisualObservations()

var expectedShape = new[] { 3, 3, 2 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);

Assert.AreEqual(SensorCompressionType.PNG, sensor.GetCompressionType());

Expand Down Expand Up @@ -216,7 +216,7 @@ public void TestCompressedVisualObservationsSpecial()

var expectedShape = new[] { 3, 3, 2 + 3 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);

Assert.AreEqual(SensorCompressionType.PNG, sensor.GetCompressionType());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public void OneChannelDepthOne()
gridSensor.Start();

int[] expectedShape = { 10, 10, 1 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape());
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());

}

Expand All @@ -52,7 +52,7 @@ public void OneChannelDepthTwo()
gridSensor.Start();

int[] expectedShape = { 10, 10, 2 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape());
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());

}

Expand All @@ -67,7 +67,7 @@ public void TwoChannelsDepthTwoOne()
gridSensor.Start();

int[] expectedShape = { 10, 10, 3 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape());
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());

}

Expand All @@ -82,7 +82,7 @@ public void TwoChannelsDepthThreeThree()
gridSensor.Start();

int[] expectedShape = { 10, 10, 6 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape());
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());

}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public void OneChannel()
gridSensor.Start();

int[] expectedShape = { 10, 10, 1 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape());
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
}

[Test]
Expand All @@ -49,7 +49,7 @@ public void TwoChannel()
gridSensor.Start();

int[] expectedShape = { 10, 10, 2 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape());
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
}

[Test]
Expand All @@ -63,7 +63,7 @@ public void SevenChannel()
gridSensor.Start();

int[] expectedShape = { 10, 10, 7 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape());
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
}
}
}
10 changes: 6 additions & 4 deletions com.unity.ml-agents/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,16 @@ and this project adheres to
## [Unreleased]
### Major Changes
#### com.unity.ml-agents (C#)
======
- Several breaking interface changes were made. See the
[Migration Guide](https://github.com/Unity-Technologies/ml-agents/blob/release_14_docs/docs/Migrating.md) for more
details.
- Some methods previously marked as `Obsolete` have been removed. If you were using these methods, you need to replace them with their supported counterpart.
- The interface for disabling discrete actions in `IDiscreteActionMask` has changed.
`WriteMask(int branch, IEnumerable<int> actionIndices)` was replaced with
`SetActionEnabled(int branch, int actionIndex, bool isEnabled)`. See the
`SetActionEnabled(int branch, int actionIndex, bool isEnabled)`. (#5060)
- IActuator now implements IHeuristicProvider. (#5110)
[Migration Guide](https://github.com/Unity-Technologies/ml-agents/blob/release_14_docs/docs/Migrating.md) for more
details. (#5060)
- `ISensor.GetObservationShape()` was removed, and `GetObservationSpec()` was added. (#5127)

#### ml-agents / ml-agents-envs / gym-unity (Python)

### Minor Changes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ public interface IDiscreteActionMask
/// <summary>
/// Set whether or not the action index for the given branch is allowed.
/// </summary>
/// <remarks>
/// By default, all discrete actions are allowed.
/// If isEnabled is false, the agent will not be able to perform the actions passed as argument
/// at the next decision for the specified action branch. The actionIndex correspond
Expand Down
Loading

0 comments on commit 540595a

Please sign in to comment.