Skip to content

Commit 9160d85

Browse files
authored
Method to return stacked observations (#5547)
* Method to return stacked observations * Added testing and returning stacked observation flat. * Update the comment lines. * Remove brainstorm commits.
1 parent 05c0275 commit 9160d85

File tree

4 files changed

+49
-4
lines changed

4 files changed

+49
-4
lines changed

com.unity.ml-agents/CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ and this project adheres to
1818

1919
#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#)
2020
- Added the capacity to initialize behaviors from any checkpoint and not just the latest one (#5525)
21-
21+
- Added the ability to get a read-only view of the stacked observations (#5523)
2222
#### ml-agents / ml-agents-envs / gym-unity (Python)
2323
- Set gym version in gym-unity to gym release 0.20.0
2424
- Added support for having `beta`, `epsilon`, and `learning rate` on separate schedules (affects only PPO and POCA). (#5538)

com.unity.ml-agents/Runtime/Agent.cs

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,11 @@ internal struct AgentParameters
320320
/// </summary>
321321
internal VectorSensor collectObservationsSensor;
322322

323+
/// <summary>
324+
/// StackingSensor that wraps collectObservationsSensor; only created when NumStackedVectorObservations is greater than 1
325+
/// </summary>
326+
internal StackingSensor stackedCollectObservationsSensor;
327+
323328
private RecursionChecker m_CollectObservationsChecker = new RecursionChecker("CollectObservations");
324329
private RecursionChecker m_OnEpisodeBeginChecker = new RecursionChecker("OnEpisodeBegin");
325330

@@ -981,9 +986,9 @@ internal void InitializeSensors()
981986
collectObservationsSensor = new VectorSensor(param.VectorObservationSize);
982987
if (param.NumStackedVectorObservations > 1)
983988
{
984-
var stackingSensor = new StackingSensor(
989+
stackedCollectObservationsSensor = new StackingSensor(
985990
collectObservationsSensor, param.NumStackedVectorObservations);
986-
sensors.Add(stackingSensor);
991+
sensors.Add(stackedCollectObservationsSensor);
987992
}
988993
else
989994
{
@@ -1179,6 +1184,17 @@ public ReadOnlyCollection<float> GetObservations()
11791184
return collectObservationsSensor.GetObservations();
11801185
}
11811186

1187+
/// <summary>
/// Returns a read-only view of the stacked observations that were generated in
/// <see cref="CollectObservations(VectorSensor)"/>. This is mainly useful inside of a
/// <see cref="Heuristic(in ActionBuffers)"/> method to avoid recomputing the observations.
/// </summary>
/// <remarks>
/// A stacking sensor is only created when more than one vector observation is stacked.
/// When no stacking sensor exists, the unstacked observations are returned instead of
/// throwing a <see cref="System.NullReferenceException"/>.
/// </remarks>
/// <returns>A read-only view of the stacked observations list.</returns>
public ReadOnlyCollection<float> GetStackedObservations()
{
    // InitializeSensors() leaves this field unassigned unless NumStackedVectorObservations > 1,
    // so fall back to the plain vector sensor in that case.
    if (stackedCollectObservationsSensor == null)
    {
        return collectObservationsSensor.GetObservations();
    }

    return stackedCollectObservationsSensor.GetStackedObservations();
}
1197+
11821198
/// <summary>
11831199
/// Implement `WriteDiscreteActionMask()` to collect the masks for discrete
11841200
/// actions. When using discrete actions, the agent will not perform the masked

com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
using System;
2+
using System.Collections.Generic;
3+
using System.Collections.ObjectModel;
24
using System.Linq;
35
using UnityEngine;
46
using Unity.Barracuda;
@@ -279,5 +281,20 @@ public BuiltInSensorType GetBuiltInSensorType()
279281
IBuiltInSensor wrappedBuiltInSensor = m_WrappedSensor as IBuiltInSensor;
280282
return wrappedBuiltInSensor?.GetBuiltInSensorType() ?? BuiltInSensorType.Unknown;
281283
}
284+
285+
/// <summary>
/// Returns the stacked observations as a read-only collection.
/// </summary>
/// <remarks>
/// Frames are concatenated into a single flat list, starting at the slot that follows
/// <c>m_CurrentIndex</c> and wrapping around, so the frame at <c>m_CurrentIndex</c> comes last.
/// A new list is built on every call.
/// </remarks>
/// <returns>The stacked observations as a read-only collection.</returns>
internal ReadOnlyCollection<float> GetStackedObservations()
{
    var observations = new List<float>();
    for (var i = 0; i < m_NumStackedObservations; i++)
    {
        // Walk the ring buffer starting one slot past the current index.
        var obsIndex = (m_CurrentIndex + 1 + i) % m_NumStackedObservations;

        // AddRange copies the frame itself; an intermediate ToList() copy is unnecessary.
        observations.AddRange(m_StackedObservations[obsIndex]);
    }

    return observations.AsReadOnly();
}
282299
}
283300
}

com.unity.ml-agents/Tests/Runtime/Sensor/StackingSensorTests.cs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,29 +76,41 @@ public void AssertStackingReset()
7676
// Verifies that a StackingSensor wrapping a size-2 VectorSensor with 3 stacks exposes the
// same flattened data through the sensor write path and GetStackedObservations().
public void TestVectorStacking()
{
    var wrapped = new VectorSensor(2);
    var sensor = new StackingSensor(wrapped, 3);

    // Only one observation so far: the two earlier slots are expected to be zeros.
    wrapped.AddObservation(new[] { 1f, 2f });
    SensorTestHelper.CompareObservation(sensor, new[] { 0f, 0f, 0f, 0f, 1f, 2f });
    // CollectionAssert reports the mismatching index and values on failure,
    // unlike Assert.IsTrue over SequenceEqual.
    CollectionAssert.AreEqual(new[] { 0f, 0f, 0f, 0f, 1f, 2f }, sensor.GetStackedObservations());

    sensor.Update();
    wrapped.AddObservation(new[] { 3f, 4f });
    SensorTestHelper.CompareObservation(sensor, new[] { 0f, 0f, 1f, 2f, 3f, 4f });
    CollectionAssert.AreEqual(new[] { 0f, 0f, 1f, 2f, 3f, 4f }, sensor.GetStackedObservations());

    // Third observation fills the stack.
    sensor.Update();
    wrapped.AddObservation(new[] { 5f, 6f });
    SensorTestHelper.CompareObservation(sensor, new[] { 1f, 2f, 3f, 4f, 5f, 6f });
    CollectionAssert.AreEqual(new[] { 1f, 2f, 3f, 4f, 5f, 6f }, sensor.GetStackedObservations());

    // From here on the oldest observation is dropped on each step.
    sensor.Update();
    wrapped.AddObservation(new[] { 7f, 8f });
    SensorTestHelper.CompareObservation(sensor, new[] { 3f, 4f, 5f, 6f, 7f, 8f });
    CollectionAssert.AreEqual(new[] { 3f, 4f, 5f, 6f, 7f, 8f }, sensor.GetStackedObservations());

    sensor.Update();
    wrapped.AddObservation(new[] { 9f, 10f });
    SensorTestHelper.CompareObservation(sensor, new[] { 5f, 6f, 7f, 8f, 9f, 10f });
    CollectionAssert.AreEqual(new[] { 5f, 6f, 7f, 8f, 9f, 10f }, sensor.GetStackedObservations());

    // Check that if we don't call Update(), the same observations are produced
    SensorTestHelper.CompareObservation(sensor, new[] { 5f, 6f, 7f, 8f, 9f, 10f });
    CollectionAssert.AreEqual(new[] { 5f, 6f, 7f, 8f, 9f, 10f }, sensor.GetStackedObservations());
}
103115

104116
[Test]

0 commit comments

Comments
 (0)