diff --git a/Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/SensorBase.cs b/Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/SensorBase.cs
index d04438053f..7bca30e7c9 100644
--- a/Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/SensorBase.cs
+++ b/Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/SensorBase.cs
@@ -51,9 +51,9 @@ public virtual byte[] GetCompressedObservation()
}
///
- public virtual SensorCompressionType GetCompressionType()
+ public virtual CompressionSpec GetCompressionSpec()
{
- return SensorCompressionType.None;
+ return CompressionSpec.Default();
}
}
}
diff --git a/Project/Assets/ML-Agents/TestScenes/TestCompressedTexture/TestTextureSensor.cs b/Project/Assets/ML-Agents/TestScenes/TestCompressedTexture/TestTextureSensor.cs
index 24cee1303e..b4702b15d1 100644
--- a/Project/Assets/ML-Agents/TestScenes/TestCompressedTexture/TestTextureSensor.cs
+++ b/Project/Assets/ML-Agents/TestScenes/TestCompressedTexture/TestTextureSensor.cs
@@ -62,9 +62,9 @@ public void Update() { }
public void Reset() { }
///
- public SensorCompressionType GetCompressionType()
+ public CompressionSpec GetCompressionSpec()
{
- return m_CompressionType;
+ return CompressionSpec.Default();
}
}
diff --git a/com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs b/com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs
index 99b6bcd121..647e56fddb 100644
--- a/com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs
+++ b/com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs
@@ -35,7 +35,7 @@ public enum Match3ObservationType
/// or uncompressed visual observations. Uses AbstractBoard.GetCellType()
/// and AbstractBoard.GetSpecialType() to determine the observation values.
///
- public class Match3Sensor : ISparseChannelSensor, IBuiltInSensor
+ public class Match3Sensor : ISensor, IBuiltInSensor
{
private Match3ObservationType m_ObservationType;
private AbstractBoard m_Board;
@@ -47,7 +47,6 @@ public class Match3Sensor : ISparseChannelSensor, IBuiltInSensor
private int m_Columns;
private int m_NumCellTypes;
private int m_NumSpecialTypes;
- private ISparseChannelSensor sparseChannelSensorImplementation;
private int SpecialTypeSize
{
@@ -214,8 +213,7 @@ public void Reset()
{
}
- ///
- public SensorCompressionType GetCompressionType()
+ internal SensorCompressionType GetCompressionType()
{
return m_ObservationType == Match3ObservationType.CompressedVisual ?
SensorCompressionType.PNG :
@@ -223,15 +221,15 @@ public SensorCompressionType GetCompressionType()
}
///
- public string GetName()
+ public CompressionSpec GetCompressionSpec()
{
- return m_Name;
+ return new CompressionSpec(GetCompressionType(), m_SparseChannelMapping);
}
///
- public int[] GetCompressedChannelMapping()
+ public string GetName()
{
- return m_SparseChannelMapping;
+ return m_Name;
}
///
diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs
index f130069221..ad3886ace7 100644
--- a/com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs
+++ b/com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs
@@ -482,9 +482,9 @@ public string GetName()
}
///
- public virtual SensorCompressionType GetCompressionType()
+ public virtual CompressionSpec GetCompressionSpec()
{
- return CompressionType;
+ return new CompressionSpec(CompressionType);
}
///
diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs
index 3bf83b07fd..32d35603bf 100644
--- a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs
+++ b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs
@@ -110,9 +110,9 @@ public void Update()
public void Reset() { }
///
- public SensorCompressionType GetCompressionType()
+ public CompressionSpec GetCompressionSpec()
{
- return SensorCompressionType.None;
+ return CompressionSpec.Default();
}
///
diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs b/com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs
index 8e975f007c..9013d3d73f 100644
--- a/com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs
+++ b/com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs
@@ -96,7 +96,7 @@ public void TestVisualObservations()
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(InplaceArray.FromList(expectedShape), sensor.GetObservationSpec().Shape);
- Assert.AreEqual(SensorCompressionType.None, sensor.GetCompressionType());
+ Assert.AreEqual(SensorCompressionType.None, sensor.GetCompressionSpec().SensorCompressionType);
var expectedObs = new float[]
{
@@ -140,7 +140,7 @@ public void TestVisualObservationsSpecial()
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(InplaceArray.FromList(expectedShape), sensor.GetObservationSpec().Shape);
- Assert.AreEqual(SensorCompressionType.None, sensor.GetCompressionType());
+ Assert.AreEqual(SensorCompressionType.None, sensor.GetCompressionSpec().SensorCompressionType);
var expectedObs = new float[]
{
@@ -178,7 +178,7 @@ public void TestCompressedVisualObservations()
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(InplaceArray.FromList(expectedShape), sensor.GetObservationSpec().Shape);
- Assert.AreEqual(SensorCompressionType.PNG, sensor.GetCompressionType());
+ Assert.AreEqual(SensorCompressionType.PNG, sensor.GetCompressionSpec().SensorCompressionType);
var pngData = sensor.GetCompressedObservation();
if (WritePNGDataToFile)
@@ -218,7 +218,7 @@ public void TestCompressedVisualObservationsSpecial()
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(InplaceArray.FromList(expectedShape), sensor.GetObservationSpec().Shape);
- Assert.AreEqual(SensorCompressionType.PNG, sensor.GetCompressionType());
+ Assert.AreEqual(SensorCompressionType.PNG, sensor.GetCompressionSpec().SensorCompressionType);
var concatenatedPngData = sensor.GetCompressedObservation();
var pathPrefix = "match3obs_special";
diff --git a/com.unity.ml-agents/CHANGELOG.md b/com.unity.ml-agents/CHANGELOG.md
index b3e2ba27bc..a1e2d3829e 100755
--- a/com.unity.ml-agents/CHANGELOG.md
+++ b/com.unity.ml-agents/CHANGELOG.md
@@ -18,7 +18,10 @@ details.
`WriteMask(int branch, IEnumerable actionIndices)` was replaced with
`SetActionEnabled(int branch, int actionIndex, bool isEnabled)`. (#5060)
- IActuator now implements IHeuristicProvider. (#5110)
-- `ISensor.GetObservationShape()` was removed, and `GetObservationSpec()` was added. (#5127)
+- `ISensor.GetObservationShape()` was removed, and `GetObservationSpec()` was added. The `ITypedSensor`
+and `IDimensionPropertiesSensor` interfaces were removed. (#5127)
+- `ISensor.GetCompressionType()` was removed, and `GetCompressionSpec()` was added. The `ISparseChannelSensor`
+interface was removed. (#5164)
#### ml-agents / ml-agents-envs / gym-unity (Python)
diff --git a/com.unity.ml-agents/Runtime/Analytics/Events.cs b/com.unity.ml-agents/Runtime/Analytics/Events.cs
index fb8f8b901b..fe0552d962 100644
--- a/com.unity.ml-agents/Runtime/Analytics/Events.cs
+++ b/com.unity.ml-agents/Runtime/Analytics/Events.cs
@@ -117,7 +117,7 @@ public static EventObservationSpec FromSensor(ISensor sensor)
return new EventObservationSpec
{
SensorName = sensor.GetName(),
- CompressionType = sensor.GetCompressionType().ToString(),
+ CompressionType = sensor.GetCompressionSpec().SensorCompressionType.ToString(),
BuiltInSensorType = (int)builtInSensorType,
DimensionInfos = dimInfos,
};
diff --git a/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs b/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
index 357caf9772..8005f3c416 100644
--- a/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
+++ b/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
@@ -342,7 +342,8 @@ public static ObservationProto GetObservationProto(this ISensor sensor, Observat
var obsSpec = sensor.GetObservationSpec();
var shape = obsSpec.Shape;
ObservationProto observationProto = null;
- var compressionType = sensor.GetCompressionType();
+ var compressionSpec = sensor.GetCompressionSpec();
+ var compressionType = compressionSpec.SensorCompressionType;
// Check capabilities if we need to concatenate PNGs
if (compressionType == SensorCompressionType.PNG && shape.Length == 3 && shape[2] > 3)
{
@@ -365,7 +366,7 @@ public static ObservationProto GetObservationProto(this ISensor sensor, Observat
if (compressionType != SensorCompressionType.None && shape.Length == 3 && shape[2] > 3)
{
var trainerCanHandleMapping = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.CompressedChannelMapping;
- var isTrivialMapping = IsTrivialMapping(sensor);
+ var isTrivialMapping = compressionSpec.IsTrivialMapping();
if (!trainerCanHandleMapping && !isTrivialMapping)
{
if (!s_HaveWarnedTrainerCapabilitiesMapping)
@@ -411,18 +412,17 @@ public static ObservationProto GetObservationProto(this ISensor sensor, Observat
throw new UnityAgentsException(
$"GetCompressedObservation() returned null data for sensor named {sensor.GetName()}. " +
"You must return a byte[]. If you don't want to use compressed observations, " +
- "return SensorCompressionType.None from GetCompressionType()."
+ "return CompressionSpec.Default() from GetCompressionSpec()."
);
}
observationProto = new ObservationProto
{
CompressedData = ByteString.CopyFrom(compressedObs),
- CompressionType = (CompressionTypeProto)sensor.GetCompressionType(),
+ CompressionType = (CompressionTypeProto)sensor.GetCompressionSpec().SensorCompressionType,
};
- var compressibleSensor = sensor as ISparseChannelSensor;
- if (compressibleSensor != null)
+ if (compressionSpec.CompressedChannelMapping != null)
{
- observationProto.CompressedChannelMapping.AddRange(compressibleSensor.GetCompressedChannelMapping());
+ observationProto.CompressedChannelMapping.AddRange(compressionSpec.CompressedChannelMapping);
}
}
@@ -488,34 +488,6 @@ public static UnityRLCapabilitiesProto ToProto(this UnityRLCapabilities rlCaps)
};
}
- internal static bool IsTrivialMapping(ISensor sensor)
- {
- var compressibleSensor = sensor as ISparseChannelSensor;
- if (compressibleSensor is null)
- {
- return true;
- }
- var mapping = compressibleSensor.GetCompressedChannelMapping();
- if (mapping == null)
- {
- return true;
- }
- // check if mapping equals zero mapping
- if (mapping.Length == 3 && mapping.All(m => m == 0))
- {
- return true;
- }
- // check if mapping equals identity mapping
- for (var i = 0; i < mapping.Length; i++)
- {
- if (mapping[i] != i)
- {
- return false;
- }
- }
- return true;
- }
-
#region Analytics
internal static TrainingEnvironmentInitializedEvent ToTrainingEnvironmentInitializedEvent(
this TrainingEnvironmentInitialized inputProto)
diff --git a/com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs b/com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs
index 3a0556f701..aed52c8e59 100644
--- a/com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs
+++ b/com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs
@@ -127,7 +127,7 @@ void StepSensors(List sensors)
{
foreach (var sensor in sensors)
{
- if (sensor.GetCompressionType() == SensorCompressionType.None)
+ if (sensor.GetCompressionSpec().SensorCompressionType == SensorCompressionType.None)
{
m_ObservationWriter.SetTarget(m_NullList, sensor.GetObservationSpec(), 0);
sensor.Write(m_ObservationWriter);
diff --git a/com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs b/com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs
index e721141eda..5c08546da9 100644
--- a/com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs
+++ b/com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs
@@ -94,9 +94,9 @@ public void Reset()
}
///
- public SensorCompressionType GetCompressionType()
+ public CompressionSpec GetCompressionSpec()
{
- return SensorCompressionType.None;
+ return CompressionSpec.Default();
}
///
diff --git a/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs b/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs
index 5c3167f2dc..6dae738c5d 100644
--- a/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs
+++ b/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs
@@ -117,9 +117,9 @@ public void Update() { }
public void Reset() { }
///
- public SensorCompressionType GetCompressionType()
+ public CompressionSpec GetCompressionSpec()
{
- return m_CompressionType;
+ return new CompressionSpec(m_CompressionType);
}
///
diff --git a/com.unity.ml-agents/Runtime/Sensors/CompressionSpec.cs b/com.unity.ml-agents/Runtime/Sensors/CompressionSpec.cs
new file mode 100644
index 0000000000..fe53839f53
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Sensors/CompressionSpec.cs
@@ -0,0 +1,109 @@
+using System.Linq;
+namespace Unity.MLAgents.Sensors
+{
+ ///
+ /// The compression setting for visual/camera observations.
+ ///
+ public enum SensorCompressionType
+ {
+ ///
+ /// No compression. Data is preserved as float arrays.
+ ///
+ None,
+
+ ///
+ /// PNG format. Data will be stored in binary format.
+ ///
+ PNG
+ }
+
+ ///
+ /// A description of the compression used for observations.
+ ///
+ ///
+ /// Most ISensor implementations can't take advantage of compression,
+ /// and should return CompressionSpec.Default() from their ISensor.GetCompressionSpec() methods.
+ /// Visual observations, or mulitdimensional categorical observations (for example, image segmentation
+ /// or the piece types in a match-3 game board) can use PNG compression reduce the amount of
+ /// data transferred between Unity and the trainer.
+ ///
+ public struct CompressionSpec
+ {
+ internal SensorCompressionType m_SensorCompressionType;
+
+ ///
+ /// The compression type that the sensor will use for its observations.
+ ///
+ public SensorCompressionType SensorCompressionType
+ {
+ get => m_SensorCompressionType;
+ }
+
+ internal int[] m_CompressedChannelMapping;
+
+ /// The mapping of the channels in compressed data to the actual channel after decompression.
+ /// The mapping is a list of integer index with the same length as
+ /// the number of output observation layers (channels), including padding if there's any.
+ /// Each index indicates the actual channel the layer will go into.
+ /// Layers with the same index will be averaged, and layers with negative index will be dropped.
+ /// For example, mapping for CameraSensor using grayscale and stacking of two: [0, 0, 0, 1, 1, 1]
+ /// Mapping for GridSensor of 4 channels and stacking of two: [0, 1, 2, 3, -1, -1, 4, 5, 6, 7, -1, -1]
+ public int[] CompressedChannelMapping
+ {
+ get => m_CompressedChannelMapping;
+ }
+
+ ///
+ /// Return a CompressionSpec indicating possible compression.
+ ///
+ /// The compression type to use.
+ /// Optional mapping mapping of the channels in compressed data to the
+ /// actual channel after decompression.
+ public CompressionSpec(SensorCompressionType sensorCompressionType, int[] compressedChannelMapping = null)
+ {
+ m_SensorCompressionType = sensorCompressionType;
+ m_CompressedChannelMapping = compressedChannelMapping;
+ }
+
+ ///
+ /// Return a CompressionSpec indicating no compression. This is recommended for most sensors.
+ ///
+ ///
+ public static CompressionSpec Default()
+ {
+ return new CompressionSpec
+ {
+ m_SensorCompressionType = SensorCompressionType.None,
+ m_CompressedChannelMapping = null
+ };
+ }
+
+ ///
+ /// Return whether the compressed channel mapping is "trivial"; if so it doesn't need to be sent to the
+ /// trainer.
+ ///
+ ///
+ internal bool IsTrivialMapping()
+ {
+ var mapping = CompressedChannelMapping;
+ if (mapping == null)
+ {
+ return true;
+ }
+ // check if mapping equals zero mapping
+ if (mapping.Length == 3 && mapping.All(m => m == 0))
+ {
+ return true;
+ }
+ // check if mapping equals identity mapping
+ for (var i = 0; i < mapping.Length; i++)
+ {
+ if (mapping[i] != i)
+ {
+ return false;
+ }
+ }
+ return true;
+ }
+ }
+}
diff --git a/com.unity.ml-agents/Runtime/Sensors/CompressionSpec.cs.meta b/com.unity.ml-agents/Runtime/Sensors/CompressionSpec.cs.meta
new file mode 100644
index 0000000000..3bbac496d7
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Sensors/CompressionSpec.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: 0ddff1d1b7ad4170acb1a10272d4a8c2
+timeCreated: 1616006929
\ No newline at end of file
diff --git a/com.unity.ml-agents/Runtime/Sensors/ISensor.cs b/com.unity.ml-agents/Runtime/Sensors/ISensor.cs
index 62f63b3f19..62f6f78a08 100644
--- a/com.unity.ml-agents/Runtime/Sensors/ISensor.cs
+++ b/com.unity.ml-agents/Runtime/Sensors/ISensor.cs
@@ -1,21 +1,5 @@
namespace Unity.MLAgents.Sensors
{
- ///
- /// The compression setting for visual/camera observations.
- ///
- public enum SensorCompressionType
- {
- ///
- /// No compression. Data is preserved as float arrays.
- ///
- None,
-
- ///
- /// PNG format. Data will be stored in binary format.
- ///
- PNG
- }
-
///
/// The Dimension property flags of the observations
///
@@ -112,11 +96,11 @@ public interface ISensor
void Reset();
///
- /// Return the compression type being used. If no compression is used, return
- /// .
+ /// Return information on the compression type being used. If no compression is used, return
+ /// .
///
- /// Compression type used by the sensor.
- SensorCompressionType GetCompressionType();
+ /// CompressionSpec used by the sensor.
+ CompressionSpec GetCompressionSpec();
///
/// Get the name of the sensor. This is used to ensure deterministic sorting of the sensors
diff --git a/com.unity.ml-agents/Runtime/Sensors/ISparseChannelSensor.cs b/com.unity.ml-agents/Runtime/Sensors/ISparseChannelSensor.cs
deleted file mode 100644
index 06b517eb42..0000000000
--- a/com.unity.ml-agents/Runtime/Sensors/ISparseChannelSensor.cs
+++ /dev/null
@@ -1,20 +0,0 @@
-namespace Unity.MLAgents.Sensors
-{
- ///
- /// Sensor interface for sparse channel sensor which requires a compressed channel mapping.
- ///
- public interface ISparseChannelSensor : ISensor
- {
- ///
- /// Returns the mapping of the channels in compressed data to the actual channel after decompression.
- /// The mapping is a list of interger index with the same length as
- /// the number of output observation layers (channels), including padding if there's any.
- /// Each index indicates the actual channel the layer will go into.
- /// Layers with the same index will be averaged, and layers with negative index will be dropped.
- /// For example, mapping for CameraSensor using grayscale and stacking of two: [0, 0, 0, 1, 1, 1]
- /// Mapping for GridSensor of 4 channels and stacking of two: [0, 1, 2, 3, -1, -1, 4, 5, 6, 7, -1, -1]
- ///
- /// Mapping of the compressed data
- int[] GetCompressedChannelMapping();
- }
-}
diff --git a/com.unity.ml-agents/Runtime/Sensors/ISparseChannelSensor.cs.meta b/com.unity.ml-agents/Runtime/Sensors/ISparseChannelSensor.cs.meta
deleted file mode 100644
index bebec4f1cf..0000000000
--- a/com.unity.ml-agents/Runtime/Sensors/ISparseChannelSensor.cs.meta
+++ /dev/null
@@ -1,11 +0,0 @@
-fileFormatVersion: 2
-guid: 63bb76c1e31c24fa5b4a384ea0edbfb0
-MonoImporter:
- externalObjects: {}
- serializedVersion: 2
- defaultReferences: []
- executionOrder: 0
- icon: {instanceID: 0}
- userData:
- assetBundleName:
- assetBundleVariant:
diff --git a/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensor.cs b/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensor.cs
index 21ee70d52e..892cd821d4 100644
--- a/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensor.cs
+++ b/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensor.cs
@@ -361,9 +361,9 @@ public virtual byte[] GetCompressedObservation()
}
///
- public virtual SensorCompressionType GetCompressionType()
+ public CompressionSpec GetCompressionSpec()
{
- return SensorCompressionType.None;
+ return CompressionSpec.Default();
}
///
diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs
index 737ea9f563..9d77f7794e 100644
--- a/com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs
+++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs
@@ -91,9 +91,9 @@ public void Update() { }
public void Reset() { }
///
- public SensorCompressionType GetCompressionType()
+ public CompressionSpec GetCompressionSpec()
{
- return SensorCompressionType.None;
+ return CompressionSpec.Default();
}
///
diff --git a/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensor.cs b/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensor.cs
index 3a4b2027ca..745ad023ae 100644
--- a/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensor.cs
+++ b/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensor.cs
@@ -88,9 +88,9 @@ public void Update() { }
public void Reset() { }
///
- public SensorCompressionType GetCompressionType()
+ public CompressionSpec GetCompressionSpec()
{
- return m_CompressionType;
+ return new CompressionSpec(m_CompressionType);
}
///
diff --git a/com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs b/com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs
index d4003ed7b4..645fe120e6 100644
--- a/com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs
+++ b/com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs
@@ -14,7 +14,7 @@ namespace Unity.MLAgents.Sensors
/// Internally, a circular buffer of arrays is used. The m_CurrentIndex represents the most recent observation.
/// Currently, observations are stacked on the last dimension.
///
- public class StackingSensor : ISparseChannelSensor, IBuiltInSensor
+ public class StackingSensor : ISensor, IBuiltInSensor
{
///
/// The wrapped sensor.
@@ -78,7 +78,7 @@ public StackingSensor(ISensor wrapped, int numStackedObservations)
m_StackedObservations[i] = new float[m_UnstackedObservationSize];
}
- if (m_WrappedSensor.GetCompressionType() != SensorCompressionType.None)
+ if (m_WrappedSensor.GetCompressionSpec().SensorCompressionType != SensorCompressionType.None)
{
m_StackedCompressedObservations = new byte[numStackedObservations][];
m_EmptyCompressedObservation = CreateEmptyPNG();
@@ -154,7 +154,7 @@ public void Reset()
{
Array.Clear(m_StackedObservations[i], 0, m_StackedObservations[i].Length);
}
- if (m_WrappedSensor.GetCompressionType() != SensorCompressionType.None)
+ if (m_WrappedSensor.GetCompressionSpec().SensorCompressionType != SensorCompressionType.None)
{
for (var i = 0; i < m_NumStackedObservations; i++)
{
@@ -200,16 +200,10 @@ public byte[] GetCompressedObservation()
return outputBytes;
}
- ///
- public int[] GetCompressedChannelMapping()
- {
- return m_CompressionMapping;
- }
-
- ///
- public SensorCompressionType GetCompressionType()
+ public CompressionSpec GetCompressionSpec()
{
- return m_WrappedSensor.GetCompressionType();
+ var wrappedSpec = m_WrappedSensor.GetCompressionSpec();
+ return new CompressionSpec(wrappedSpec.SensorCompressionType, m_CompressionMapping);
}
///
@@ -233,7 +227,7 @@ internal byte[] CreateEmptyPNG()
}
///
- /// Constrct stacked CompressedChannelMapping.
+ /// Construct stacked CompressedChannelMapping.
///
internal int[] ConstructStackedCompressedChannelMapping(ISensor wrappedSenesor)
{
@@ -242,11 +236,8 @@ internal int[] ConstructStackedCompressedChannelMapping(ISensor wrappedSenesor)
// Default mapping: {0, 0, 0} for grayscale, identity mapping {1, 2, ..., n} otherwise.
int[] wrappedMapping = null;
int wrappedNumChannel = m_WrappedSpec.Shape[2];
- var sparseChannelSensor = m_WrappedSensor as ISparseChannelSensor;
- if (sparseChannelSensor != null)
- {
- wrappedMapping = sparseChannelSensor.GetCompressedChannelMapping();
- }
+
+ wrappedMapping = wrappedSenesor.GetCompressionSpec().CompressedChannelMapping;
if (wrappedMapping == null)
{
if (wrappedNumChannel == 1)
diff --git a/com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs b/com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs
index 4f151efb94..4a583a179f 100644
--- a/com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs
+++ b/com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs
@@ -102,9 +102,9 @@ public virtual byte[] GetCompressedObservation()
}
///
- public virtual SensorCompressionType GetCompressionType()
+ public CompressionSpec GetCompressionSpec()
{
- return SensorCompressionType.None;
+ return CompressionSpec.Default();
}
///
diff --git a/com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs b/com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs
index 0450b43681..6fa249ad77 100644
--- a/com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs
+++ b/com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs
@@ -82,9 +82,9 @@ public void Update() { }
public void Reset() { }
- public SensorCompressionType GetCompressionType()
+ public CompressionSpec GetCompressionSpec()
{
- return CompressionType;
+ return new CompressionSpec(CompressionType);
}
public string GetName()
@@ -93,19 +93,6 @@ public string GetName()
}
}
- class DummySparseChannelSensor : DummySensor, ISparseChannelSensor
- {
- public int[] Mapping;
- internal DummySparseChannelSensor()
- {
- }
-
- public int[] GetCompressedChannelMapping()
- {
- return Mapping;
- }
- }
-
[Test]
public void TestGetObservationProtoCapabilities()
{
@@ -168,23 +155,6 @@ public void TestGetObservationProtoCapabilities()
}
- [Test]
- public void TestIsTrivialMapping()
- {
- Assert.AreEqual(GrpcExtensions.IsTrivialMapping(new DummySensor()), true);
-
- var sparseChannelSensor = new DummySparseChannelSensor();
- sparseChannelSensor.Mapping = null;
- Assert.AreEqual(GrpcExtensions.IsTrivialMapping(sparseChannelSensor), true);
- sparseChannelSensor.Mapping = new[] { 0, 0, 0 };
- Assert.AreEqual(GrpcExtensions.IsTrivialMapping(sparseChannelSensor), true);
- sparseChannelSensor.Mapping = new[] { 0, 1, 2, 3, 4 };
- Assert.AreEqual(GrpcExtensions.IsTrivialMapping(sparseChannelSensor), true);
- sparseChannelSensor.Mapping = new[] { 1, 2, 3, 4, -1, -1 };
- Assert.AreEqual(GrpcExtensions.IsTrivialMapping(sparseChannelSensor), false);
- sparseChannelSensor.Mapping = new[] { 0, 0, 0, 1, 1, 1 };
- Assert.AreEqual(GrpcExtensions.IsTrivialMapping(sparseChannelSensor), false);
- }
[Test]
public void TestDefaultTrainingEvents()
{
diff --git a/com.unity.ml-agents/Tests/Editor/Inference/ParameterLoaderTest.cs b/com.unity.ml-agents/Tests/Editor/Inference/ParameterLoaderTest.cs
index 7a0fc087d8..540671a345 100644
--- a/com.unity.ml-agents/Tests/Editor/Inference/ParameterLoaderTest.cs
+++ b/com.unity.ml-agents/Tests/Editor/Inference/ParameterLoaderTest.cs
@@ -64,9 +64,9 @@ public byte[] GetCompressedObservation()
public void Update() { }
public void Reset() { }
- public SensorCompressionType GetCompressionType()
+ public CompressionSpec GetCompressionSpec()
{
- return SensorCompressionType.None;
+ return CompressionSpec.Default();
}
public string GetName()
diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/CompressionSpecTests.cs b/com.unity.ml-agents/Tests/Runtime/Sensor/CompressionSpecTests.cs
new file mode 100644
index 0000000000..95740a5c7c
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Runtime/Sensor/CompressionSpecTests.cs
@@ -0,0 +1,30 @@
+using NUnit.Framework;
+using Unity.MLAgents.Sensors;
+
+namespace Unity.MLAgents.Tests
+{
+ [TestFixture]
+ public class CompressionSpecTests
+ {
+ [Test]
+ public void TestIsTrivialMapping()
+ {
+ Assert.IsTrue(CompressionSpec.Default().IsTrivialMapping());
+
+ var spec = new CompressionSpec(SensorCompressionType.PNG, null);
+ Assert.AreEqual(spec.IsTrivialMapping(), true);
+
+ spec = new CompressionSpec(SensorCompressionType.PNG, new[] { 0, 0, 0 });
+ Assert.AreEqual(spec.IsTrivialMapping(), true);
+
+ spec = new CompressionSpec(SensorCompressionType.PNG, new[] { 0, 1, 2, 3, 4 });
+ Assert.AreEqual(spec.IsTrivialMapping(), true);
+
+ spec = new CompressionSpec(SensorCompressionType.PNG, new[] { 1, 2, 3, 4, -1, -1 });
+ Assert.AreEqual(spec.IsTrivialMapping(), false);
+
+ spec = new CompressionSpec(SensorCompressionType.PNG, new[] { 0, 0, 0, 1, 1, 1 });
+ Assert.AreEqual(spec.IsTrivialMapping(), false);
+ }
+ }
+}
diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/CompressionSpecTests.cs.meta b/com.unity.ml-agents/Tests/Runtime/Sensor/CompressionSpecTests.cs.meta
new file mode 100644
index 0000000000..d9df7ce5b2
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Runtime/Sensor/CompressionSpecTests.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: cd0990de0eb646b0b0531b91c840c9da
+timeCreated: 1616030728
\ No newline at end of file
diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/FloatVisualSensorTests.cs b/com.unity.ml-agents/Tests/Runtime/Sensor/FloatVisualSensorTests.cs
index fd99d71be7..0f4a7ae1de 100644
--- a/com.unity.ml-agents/Tests/Runtime/Sensor/FloatVisualSensorTests.cs
+++ b/com.unity.ml-agents/Tests/Runtime/Sensor/FloatVisualSensorTests.cs
@@ -64,9 +64,9 @@ public int Write(ObservationWriter writer)
public void Update() { }
public void Reset() { }
- public SensorCompressionType GetCompressionType()
+ public CompressionSpec GetCompressionSpec()
{
- return SensorCompressionType.None;
+ return CompressionSpec.Default();
}
}
diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/SensorShapeValidatorTests.cs b/com.unity.ml-agents/Tests/Runtime/Sensor/SensorShapeValidatorTests.cs
index 62542306cf..d03b2a4a4d 100644
--- a/com.unity.ml-agents/Tests/Runtime/Sensor/SensorShapeValidatorTests.cs
+++ b/com.unity.ml-agents/Tests/Runtime/Sensor/SensorShapeValidatorTests.cs
@@ -50,9 +50,9 @@ public int Write(ObservationWriter writer)
public void Update() { }
public void Reset() { }
- public SensorCompressionType GetCompressionType()
+ public CompressionSpec GetCompressionSpec()
{
- return SensorCompressionType.None;
+ return CompressionSpec.Default();
}
}
diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/StackingSensorTests.cs b/com.unity.ml-agents/Tests/Runtime/Sensor/StackingSensorTests.cs
index 06c48489c7..c959192720 100644
--- a/com.unity.ml-agents/Tests/Runtime/Sensor/StackingSensorTests.cs
+++ b/com.unity.ml-agents/Tests/Runtime/Sensor/StackingSensorTests.cs
@@ -112,7 +112,7 @@ public void TestVectorStackingReset()
SensorTestHelper.CompareObservation(sensor, new[] { 0f, 0f, 0f, 0f, 5f, 6f });
}
- class Dummy3DSensor : ISparseChannelSensor
+ class Dummy3DSensor : ISensor
{
public SensorCompressionType CompressionType = SensorCompressionType.PNG;
public int[] Mapping;
@@ -157,9 +157,9 @@ public void Update() { }
public void Reset() { }
- public SensorCompressionType GetCompressionType()
+ public CompressionSpec GetCompressionSpec()
{
- return CompressionType;
+ return new CompressionSpec(CompressionType, Mapping);
}
public string GetName()
@@ -167,11 +167,6 @@ public string GetName()
return "Dummy";
}
- public int[] GetCompressedChannelMapping()
- {
- return Mapping;
- }
-
}
[Test]
@@ -181,27 +176,27 @@ public void TestStackingMapping()
var cameraSensor = new CameraSensor(new Camera(), 64, 64,
true, "grayscaleCamera", SensorCompressionType.PNG);
var stackedCameraSensor = new StackingSensor(cameraSensor, 2);
- Assert.AreEqual(stackedCameraSensor.GetCompressedChannelMapping(), new[] { 0, 0, 0, 1, 1, 1 });
+ Assert.AreEqual(stackedCameraSensor.GetCompressionSpec().CompressedChannelMapping, new[] { 0, 0, 0, 1, 1, 1 });
// Test RGB stacked mapping with RenderTextureSensor
var renderTextureSensor = new RenderTextureSensor(new RenderTexture(24, 16, 0),
false, "renderTexture", SensorCompressionType.PNG);
var stackedRenderTextureSensor = new StackingSensor(renderTextureSensor, 2);
- Assert.AreEqual(stackedRenderTextureSensor.GetCompressedChannelMapping(), new[] { 0, 1, 2, 3, 4, 5 });
+ Assert.AreEqual(stackedRenderTextureSensor.GetCompressionSpec().CompressedChannelMapping, new[] { 0, 1, 2, 3, 4, 5 });
// Test mapping with number of layers not being multiple of 3
var dummySensor = new Dummy3DSensor();
dummySensor.ObservationSpec = ObservationSpec.Visual(2, 2, 4);
dummySensor.Mapping = new[] { 0, 1, 2, 3 };
var stackedDummySensor = new StackingSensor(dummySensor, 2);
- Assert.AreEqual(stackedDummySensor.GetCompressedChannelMapping(), new[] { 0, 1, 2, 3, -1, -1, 4, 5, 6, 7, -1, -1 });
+ Assert.AreEqual(stackedDummySensor.GetCompressionSpec().CompressedChannelMapping, new[] { 0, 1, 2, 3, -1, -1, 4, 5, 6, 7, -1, -1 });
// Test mapping with dummy layers that should be dropped
var paddedDummySensor = new Dummy3DSensor();
paddedDummySensor.ObservationSpec = ObservationSpec.Visual(2, 2, 4);
paddedDummySensor.Mapping = new[] { 0, 1, 2, 3, -1, -1 };
var stackedPaddedDummySensor = new StackingSensor(paddedDummySensor, 2);
- Assert.AreEqual(stackedPaddedDummySensor.GetCompressedChannelMapping(), new[] { 0, 1, 2, 3, -1, -1, 4, 5, 6, 7, -1, -1 });
+ Assert.AreEqual(stackedPaddedDummySensor.GetCompressionSpec().CompressedChannelMapping, new[] { 0, 1, 2, 3, -1, -1, 4, 5, 6, 7, -1, -1 });
}
[Test]
diff --git a/com.unity.ml-agents/Tests/Runtime/Utils/TestClasses.cs b/com.unity.ml-agents/Tests/Runtime/Utils/TestClasses.cs
index c230657546..ba5013900a 100644
--- a/com.unity.ml-agents/Tests/Runtime/Utils/TestClasses.cs
+++ b/com.unity.ml-agents/Tests/Runtime/Utils/TestClasses.cs
@@ -145,9 +145,9 @@ public byte[] GetCompressedObservation()
return new byte[] { 0 };
}
- public SensorCompressionType GetCompressionType()
+ public CompressionSpec GetCompressionSpec()
{
- return compressionType;
+ return new CompressionSpec(compressionType);
}
public string GetName()
diff --git a/docs/Migrating.md b/docs/Migrating.md
index b5cdf380c6..3c98b05498 100644
--- a/docs/Migrating.md
+++ b/docs/Migrating.md
@@ -47,7 +47,8 @@ public override void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
- The `IActuator` interface now implements `IHeuristicProvider`. Please add the corresponding `Heuristic(in ActionBuffers)`
method to your custom Actuator classes.
-- The `ISensor.GetObservationShape()` method was removed, and `GetObservationSpec()` was added. You can use
+- The `ISensor.GetObservationShape()` method and `ITypedSensor`
+and `IDimensionPropertiesSensor` interfaces were removed, and `GetObservationSpec()` was added. You can use
`ObservationSpec.Vector()` or `ObservationSpec.Visual()` to generate `ObservationSpec`s that are equivalent to
the previous shape. For example, if your old ISensor looked like:
@@ -67,6 +68,26 @@ public override ObservationSpec GetObservationSpec()
}
```
+- The `ISensor.GetCompressionType()` method and `ISparseChannelSensor` interface was removed,
+and `GetCompressionSpec()` was added. You can use `CompressionSpec.Default()` or
+`CompressionSpec.Compressed()` to generate `CompressionSpec`s that are equivalent to
+ the previous values. For example, if your old ISensor looked like:
+ ```csharp
+public virtual SensorCompressionType GetCompressionType()
+{
+ return SensorCompressionType.None;
+}
+```
+
+the equivalent code would now be
+
+```csharp
+public CompressionSpec GetCompressionSpec()
+{
+ return CompressionSpec.Default();
+}
+```
+
## Migrating to Release 13
### Implementing IHeuristic in your IActuator implementations
- If you have any custom actuators, you can now implement the `IHeuristicProvider` interface to have your actuator