Skip to content

Commit

Permalink
Added testing and returning staked observation flat.
Browse files Browse the repository at this point in the history
  • Loading branch information
cmard committed Oct 19, 2021
1 parent f8b17c7 commit bf0919d
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 8 deletions.
2 changes: 1 addition & 1 deletion com.unity.ml-agents/Runtime/Agent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1190,7 +1190,7 @@ public ReadOnlyCollection<float> GetObservations()
/// <see cref="Heuristic(in ActionBuffers)"/> method to avoid recomputing the observations.
/// </summary>
/// <returns>A read-only view of the stacked observations list.</returns>
public ReadOnlyCollection<ReadOnlyCollection<float>> GetStackedObservations()
public ReadOnlyCollection<float> GetStackedObservations()
{
return stackedCollectObservationsSensor.GetStackedObservations();
}
Expand Down
17 changes: 11 additions & 6 deletions com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ public class StackingSensor : ISensor, IBuiltInSensor
/// Buffer of previous observations
/// </summary>
float[][] m_StackedObservations;
//[
//[1,2]
//[3,4]
//[5,6]
//]

byte[][] m_StackedCompressedObservations;

Expand Down Expand Up @@ -286,15 +291,15 @@ public BuiltInSensorType GetBuiltInSensorType()
/// Returns a read-only view of the observations that added.
/// </summary>
/// <returns>A read-only view of the observations list.</returns>
internal ReadOnlyCollection<ReadOnlyCollection<float>> GetStackedObservations()
internal ReadOnlyCollection<float> GetStackedObservations()
{
List<ReadOnlyCollection<float>> layer = new List<ReadOnlyCollection<float>>();
foreach (float[] l in m_StackedObservations)
List<float> observations = new List<float>();
for (var i = 0; i < m_NumStackedObservations; i++)
{
layer.Add(l.ToList().AsReadOnly());
var obsIndex = (m_CurrentIndex + 1 + i) % m_NumStackedObservations;
observations.AddRange(m_StackedObservations[obsIndex].ToList());
}

return layer.AsReadOnly();
return observations.AsReadOnly();
}
}
}
14 changes: 13 additions & 1 deletion com.unity.ml-agents/Tests/Runtime/Sensor/StackingSensorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -76,29 +76,41 @@ public void AssertStackingReset()
public void TestVectorStacking()
{
VectorSensor wrapped = new VectorSensor(2);
ISensor sensor = new StackingSensor(wrapped, 3);
StackingSensor sensor = new StackingSensor(wrapped, 3);

wrapped.AddObservation(new[] { 1f, 2f });
SensorTestHelper.CompareObservation(sensor, new[] { 0f, 0f, 0f, 0f, 1f, 2f });
var data = sensor.GetStackedObservations();
Assert.IsTrue(data.ToArray().SequenceEqual(new[] { 0f, 0f, 0f, 0f, 1f, 2f }));

sensor.Update();
wrapped.AddObservation(new[] { 3f, 4f });
SensorTestHelper.CompareObservation(sensor, new[] { 0f, 0f, 1f, 2f, 3f, 4f });
data = sensor.GetStackedObservations();
Assert.IsTrue(data.ToArray().SequenceEqual(new[] { 0f, 0f, 1f, 2f, 3f, 4f }));

sensor.Update();
wrapped.AddObservation(new[] { 5f, 6f });
SensorTestHelper.CompareObservation(sensor, new[] { 1f, 2f, 3f, 4f, 5f, 6f });
data = sensor.GetStackedObservations();
Assert.IsTrue(data.ToArray().SequenceEqual(new[] { 1f, 2f, 3f, 4f, 5f, 6f }));

sensor.Update();
wrapped.AddObservation(new[] { 7f, 8f });
SensorTestHelper.CompareObservation(sensor, new[] { 3f, 4f, 5f, 6f, 7f, 8f });
data = sensor.GetStackedObservations();
Assert.IsTrue(data.ToArray().SequenceEqual(new[] { 3f, 4f, 5f, 6f, 7f, 8f }));

sensor.Update();
wrapped.AddObservation(new[] { 9f, 10f });
SensorTestHelper.CompareObservation(sensor, new[] { 5f, 6f, 7f, 8f, 9f, 10f });
data = sensor.GetStackedObservations();
Assert.IsTrue(data.ToArray().SequenceEqual(new[] { 5f, 6f, 7f, 8f, 9f, 10f }));

// Check that if we don't call Update(), the same observations are produced
SensorTestHelper.CompareObservation(sensor, new[] { 5f, 6f, 7f, 8f, 9f, 10f });
data = sensor.GetStackedObservations();
Assert.IsTrue(data.ToArray().SequenceEqual(new[] { 5f, 6f, 7f, 8f, 9f, 10f }));
}

[Test]
Expand Down

0 comments on commit bf0919d

Please sign in to comment.