Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLA-1634] Compression spec #5164

Merged
merged 11 commits into from
Mar 22, 2021
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ public virtual byte[] GetCompressedObservation()
}

/// <inheritdoc/>
public virtual SensorCompressionType GetCompressionType()
public virtual CompressionSpec GetCompressionSpec()
{
return SensorCompressionType.None;
return CompressionSpec.Default();
}
Comment on lines +54 to 57
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A potential improvement of this is to have a m_CompressionSpec and return that like we do with m_ObservationSpec.
Not sure if it's worthy though given this looks pretty light weight.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We cached the shapes previously, because they required allocating memory. I kept that pattern for the ObservationSpecs - it's not really necessary for performance (struct with InplaceArray, so no allocations), but it's not a bad idea since they shouldn't change at runtime. I don't think it's necessary for CompressionSpecs.

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ public void Update() { }
public void Reset() { }

/// <inheritdoc/>
public SensorCompressionType GetCompressionType()
public CompressionSpec GetCompressionSpec()
{
return m_CompressionType;
return CompressionSpec.Default();
}
}

14 changes: 6 additions & 8 deletions com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public enum Match3ObservationType
/// or uncompressed visual observations. Uses AbstractBoard.GetCellType()
/// and AbstractBoard.GetSpecialType() to determine the observation values.
/// </summary>
public class Match3Sensor : ISparseChannelSensor, IBuiltInSensor
public class Match3Sensor : ISensor, IBuiltInSensor
{
private Match3ObservationType m_ObservationType;
private AbstractBoard m_Board;
Expand All @@ -47,7 +47,6 @@ public class Match3Sensor : ISparseChannelSensor, IBuiltInSensor
private int m_Columns;
private int m_NumCellTypes;
private int m_NumSpecialTypes;
private ISparseChannelSensor sparseChannelSensorImplementation;

private int SpecialTypeSize
{
Expand Down Expand Up @@ -214,24 +213,23 @@ public void Reset()
{
}

/// <inheritdoc/>
public SensorCompressionType GetCompressionType()
internal SensorCompressionType GetCompressionType()
{
return m_ObservationType == Match3ObservationType.CompressedVisual ?
SensorCompressionType.PNG :
SensorCompressionType.None;
}

/// <inheritdoc/>
public string GetName()
public CompressionSpec GetCompressionSpec()
{
return m_Name;
return new CompressionSpec(GetCompressionType(), m_SparseChannelMapping);
}

/// <inheritdoc/>
public int[] GetCompressedChannelMapping()
public string GetName()
{
return m_SparseChannelMapping;
return m_Name;
}

/// <inheritdoc/>
Expand Down
4 changes: 2 additions & 2 deletions com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -482,9 +482,9 @@ public string GetName()
}

/// <inheritdoc/>
public virtual SensorCompressionType GetCompressionType()
public virtual CompressionSpec GetCompressionSpec()
{
return CompressionType;
return new CompressionSpec(CompressionType);
}

/// <inheritdoc/>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,9 @@ public void Update()
public void Reset() { }

/// <inheritdoc/>
public SensorCompressionType GetCompressionType()
public CompressionSpec GetCompressionSpec()
{
return SensorCompressionType.None;
return CompressionSpec.Default();
}

/// <inheritdoc/>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ public void TestVisualObservations()
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);

Assert.AreEqual(SensorCompressionType.None, sensor.GetCompressionType());
Assert.AreEqual(SensorCompressionType.None, sensor.GetCompressionSpec().SensorCompressionType);

var expectedObs = new float[]
{
Expand Down Expand Up @@ -140,7 +140,7 @@ public void TestVisualObservationsSpecial()
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);

Assert.AreEqual(SensorCompressionType.None, sensor.GetCompressionType());
Assert.AreEqual(SensorCompressionType.None, sensor.GetCompressionSpec().SensorCompressionType);

var expectedObs = new float[]
{
Expand Down Expand Up @@ -178,7 +178,7 @@ public void TestCompressedVisualObservations()
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);

Assert.AreEqual(SensorCompressionType.PNG, sensor.GetCompressionType());
Assert.AreEqual(SensorCompressionType.PNG, sensor.GetCompressionSpec().SensorCompressionType);

var pngData = sensor.GetCompressedObservation();
if (WritePNGDataToFile)
Expand Down Expand Up @@ -218,7 +218,7 @@ public void TestCompressedVisualObservationsSpecial()
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);

Assert.AreEqual(SensorCompressionType.PNG, sensor.GetCompressionType());
Assert.AreEqual(SensorCompressionType.PNG, sensor.GetCompressionSpec().SensorCompressionType);

var concatenatedPngData = sensor.GetCompressedObservation();
var pathPrefix = "match3obs_special";
Expand Down
5 changes: 4 additions & 1 deletion com.unity.ml-agents/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ details.
`WriteMask(int branch, IEnumerable<int> actionIndices)` was replaced with
`SetActionEnabled(int branch, int actionIndex, bool isEnabled)`. (#5060)
- IActuator now implements IHeuristicProvider. (#5110)
- `ISensor.GetObservationShape()` was removed, and `GetObservationSpec()` was added. (#5127)
- `ISensor.GetObservationShape()` was removed, and `GetObservationSpec()` was added. The `ITypedSensor`
and `IDimensionPropertiesSensor` interfaces were removed. (#5127)
- `ISensor.GetCompressionType()` was removed, and `GetCompressionSpec()` was added. The `ISparseChannelSensor`
interface was removed. (#5164)

#### ml-agents / ml-agents-envs / gym-unity (Python)

Expand Down
2 changes: 1 addition & 1 deletion com.unity.ml-agents/Runtime/Analytics/Events.cs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ public static EventObservationSpec FromSensor(ISensor sensor)
return new EventObservationSpec
{
SensorName = sensor.GetName(),
CompressionType = sensor.GetCompressionType().ToString(),
CompressionType = sensor.GetCompressionSpec().SensorCompressionType.ToString(),
BuiltInSensorType = (int)builtInSensorType,
DimensionInfos = dimInfos,
};
Expand Down
42 changes: 7 additions & 35 deletions com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,8 @@ public static ObservationProto GetObservationProto(this ISensor sensor, Observat
var obsSpec = sensor.GetObservationSpec();
var shape = obsSpec.Shape;
ObservationProto observationProto = null;
var compressionType = sensor.GetCompressionType();
var compressionSpec = sensor.GetCompressionSpec();
var compressionType = compressionSpec.SensorCompressionType;
// Check capabilities if we need to concatenate PNGs
if (compressionType == SensorCompressionType.PNG && shape.Length == 3 && shape[2] > 3)
{
Expand All @@ -365,7 +366,7 @@ public static ObservationProto GetObservationProto(this ISensor sensor, Observat
if (compressionType != SensorCompressionType.None && shape.Length == 3 && shape[2] > 3)
{
var trainerCanHandleMapping = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.CompressedChannelMapping;
var isTrivialMapping = IsTrivialMapping(sensor);
var isTrivialMapping = compressionSpec.IsTrivialMapping();
if (!trainerCanHandleMapping && !isTrivialMapping)
{
if (!s_HaveWarnedTrainerCapabilitiesMapping)
Expand Down Expand Up @@ -411,18 +412,17 @@ public static ObservationProto GetObservationProto(this ISensor sensor, Observat
throw new UnityAgentsException(
$"GetCompressedObservation() returned null data for sensor named {sensor.GetName()}. " +
"You must return a byte[]. If you don't want to use compressed observations, " +
"return SensorCompressionType.None from GetCompressionType()."
"return CompressionSpec.Default() from GetCompressionSpec()."
);
}
observationProto = new ObservationProto
{
CompressedData = ByteString.CopyFrom(compressedObs),
CompressionType = (CompressionTypeProto)sensor.GetCompressionType(),
CompressionType = (CompressionTypeProto)sensor.GetCompressionSpec().SensorCompressionType,
};
var compressibleSensor = sensor as ISparseChannelSensor;
if (compressibleSensor != null)
if (compressionSpec.CompressedChannelMapping != null)
{
observationProto.CompressedChannelMapping.AddRange(compressibleSensor.GetCompressedChannelMapping());
observationProto.CompressedChannelMapping.AddRange(compressionSpec.CompressedChannelMapping);
}
}

Expand Down Expand Up @@ -488,34 +488,6 @@ public static UnityRLCapabilitiesProto ToProto(this UnityRLCapabilities rlCaps)
};
}

internal static bool IsTrivialMapping(ISensor sensor)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved this logic to CompressionSpec

{
var compressibleSensor = sensor as ISparseChannelSensor;
if (compressibleSensor is null)
{
return true;
}
var mapping = compressibleSensor.GetCompressedChannelMapping();
if (mapping == null)
{
return true;
}
// check if mapping equals zero mapping
if (mapping.Length == 3 && mapping.All(m => m == 0))
{
return true;
}
// check if mapping equals identity mapping
for (var i = 0; i < mapping.Length; i++)
{
if (mapping[i] != i)
{
return false;
}
}
return true;
}

#region Analytics
internal static TrainingEnvironmentInitializedEvent ToTrainingEnvironmentInitializedEvent(
this TrainingEnvironmentInitialized inputProto)
Expand Down
2 changes: 1 addition & 1 deletion com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ void StepSensors(List<ISensor> sensors)
{
foreach (var sensor in sensors)
{
if (sensor.GetCompressionType() == SensorCompressionType.None)
if (sensor.GetCompressionSpec().SensorCompressionType == SensorCompressionType.None)
{
m_ObservationWriter.SetTarget(m_NullList, sensor.GetObservationSpec(), 0);
sensor.Write(m_ObservationWriter);
Expand Down
4 changes: 2 additions & 2 deletions com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,9 @@ public void Reset()
}

/// <inheritdoc/>
public SensorCompressionType GetCompressionType()
public CompressionSpec GetCompressionSpec()
{
return SensorCompressionType.None;
return CompressionSpec.Default();
}

/// <inheritdoc/>
Expand Down
4 changes: 2 additions & 2 deletions com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,9 @@ public void Update() { }
public void Reset() { }

/// <inheritdoc/>
public SensorCompressionType GetCompressionType()
public CompressionSpec GetCompressionSpec()
{
return m_CompressionType;
return new CompressionSpec(m_CompressionType);
}

/// <summary>
Expand Down
109 changes: 109 additions & 0 deletions com.unity.ml-agents/Runtime/Sensors/CompressionSpec.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
using System.Linq;
namespace Unity.MLAgents.Sensors
{
/// <summary>
/// The compression setting for visual/camera observations.
/// </summary>
public enum SensorCompressionType
{
/// <summary>
/// No compression. Data is preserved as float arrays.
/// </summary>
None,

/// <summary>
/// PNG format. Data will be stored in binary format.
/// </summary>
PNG
}

/// <summary>
/// A description of the compression used for observations.
/// </summary>
/// <remarks>
/// Most ISensor implementations can't take advantage of compression,
/// and should return CompressionSpec.Default() from their ISensor.GetCompressionSpec() methods.
/// Visual observations, or mulitdimensional categorical observations (for example, image segmentation
/// or the piece types in a match-3 game board) can use PNG compression reduce the amount of
/// data transferred between Unity and the trainer.
/// </remarks>
public struct CompressionSpec
{
internal SensorCompressionType m_SensorCompressionType;

/// <summary>
/// The compression type that the sensor will use for its observations.
/// </summary>
public SensorCompressionType SensorCompressionType
{
get => m_SensorCompressionType;
}

internal int[] m_CompressedChannelMapping;

/// The mapping of the channels in compressed data to the actual channel after decompression.
/// The mapping is a list of integer index with the same length as
/// the number of output observation layers (channels), including padding if there's any.
/// Each index indicates the actual channel the layer will go into.
/// Layers with the same index will be averaged, and layers with negative index will be dropped.
/// For example, mapping for CameraSensor using grayscale and stacking of two: [0, 0, 0, 1, 1, 1]
/// Mapping for GridSensor of 4 channels and stacking of two: [0, 1, 2, 3, -1, -1, 4, 5, 6, 7, -1, -1]
public int[] CompressedChannelMapping
{
get => m_CompressedChannelMapping;
}

/// <summary>
/// Return a CompressionSpec indicating possible compression.
/// </summary>
/// <param name="sensorCompressionType">The compression type to use.</param>
/// <param name="compressedChannelMapping">Optional mapping mapping of the channels in compressed data to the
/// actual channel after decompression.</param>
public CompressionSpec(SensorCompressionType sensorCompressionType, int[] compressedChannelMapping = null)
{
m_SensorCompressionType = sensorCompressionType;
m_CompressedChannelMapping = compressedChannelMapping;
}

/// <summary>
/// Return a CompressionSpec indicating no compression. This is recommended for most sensors.
/// </summary>
/// <returns></returns>
public static CompressionSpec Default()
{
return new CompressionSpec
{
m_SensorCompressionType = SensorCompressionType.None,
m_CompressedChannelMapping = null
};
}

/// <summary>
/// Return whether the compressed channel mapping is "trivial"; if so it doesn't need to be sent to the
/// trainer.
/// </summary>
/// <returns></returns>
internal bool IsTrivialMapping()
{
var mapping = CompressedChannelMapping;
if (mapping == null)
{
return true;
}
// check if mapping equals zero mapping
if (mapping.Length == 3 && mapping.All(m => m == 0))
{
return true;
}
// check if mapping equals identity mapping
for (var i = 0; i < mapping.Length; i++)
{
if (mapping[i] != i)
{
return false;
}
}
return true;
}
}
}
3 changes: 3 additions & 0 deletions com.unity.ml-agents/Runtime/Sensors/CompressionSpec.cs.meta

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading