diff --git a/com.unity.ml-agents/Runtime/Sensors/GridSensorComponent.cs b/com.unity.ml-agents/Runtime/Sensors/GridSensorComponent.cs index 6efe32c66e..62146a6d82 100644 --- a/com.unity.ml-agents/Runtime/Sensors/GridSensorComponent.cs +++ b/com.unity.ml-agents/Runtime/Sensors/GridSensorComponent.cs @@ -1,4 +1,5 @@ using System.Collections.Generic; +using System.Linq; using UnityEngine; namespace Unity.MLAgents.Sensors @@ -11,7 +12,7 @@ public class GridSensorComponent : SensorComponent { // dummy sensor only used for debug gizmo GridSensorBase m_DebugSensor; - List m_Sensors; + List m_Sensors; internal BoxOverlapChecker m_BoxOverlapChecker; [HideInInspector, SerializeField] @@ -196,7 +197,6 @@ public int ObservationStacks /// public override ISensor[] CreateSensors() { - m_Sensors = new List(); m_BoxOverlapChecker = new BoxOverlapChecker( m_CellScale, m_GridSize, @@ -213,29 +213,33 @@ public override ISensor[] CreateSensors() m_DebugSensor = new GridSensorBase("DebugGridSensor", m_CellScale, m_GridSize, m_DetectableTags, SensorCompressionType.None); m_BoxOverlapChecker.RegisterDebugSensor(m_DebugSensor); - var gridSensors = GetGridSensors(); - if (gridSensors == null || gridSensors.Length < 1) + m_Sensors = GetGridSensors().ToList(); + if (m_Sensors == null || m_Sensors.Count < 1) { throw new UnityAgentsException("GridSensorComponent received no sensors. Specify at least one observation type (OneHot/Counting) to use grid sensors." + "If you're overriding GridSensorComponent.GetGridSensors(), return at least one grid sensor."); } - foreach (var sensor in gridSensors) + // Only one sensor needs to reference the boxOverlapChecker, so that it gets updated exactly once + m_Sensors[0].m_BoxOverlapChecker = m_BoxOverlapChecker; + foreach (var sensor in m_Sensors) { - if (ObservationStacks != 1) - { - m_Sensors.Add(new StackingSensor(sensor, ObservationStacks)); - } - else - { - m_Sensors.Add(sensor); - } m_BoxOverlapChecker.RegisterSensor(sensor); } - // Only one sensor needs to reference the boxOverlapChecker, so that it gets updated exactly once - ((GridSensorBase)m_Sensors[0]).m_BoxOverlapChecker = m_BoxOverlapChecker; - return m_Sensors.ToArray(); + if (ObservationStacks != 1) + { + var sensors = new ISensor[m_Sensors.Count]; + for (var i = 0; i < m_Sensors.Count; i++) + { + sensors[i] = new StackingSensor(m_Sensors[i], ObservationStacks); + } + return sensors; + } + else + { + return m_Sensors.ToArray(); + } } /// @@ -262,7 +266,7 @@ internal void UpdateSensor() m_BoxOverlapChecker.ColliderMask = m_ColliderMask; foreach (var sensor in m_Sensors) { - ((GridSensorBase)sensor).CompressionType = m_CompressionType; + sensor.CompressionType = m_CompressionType; } } } diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/GridSensorTests.cs b/com.unity.ml-agents/Tests/Runtime/Sensor/GridSensorTests.cs index 8fb7912535..35e0ada86e 100644 --- a/com.unity.ml-agents/Tests/Runtime/Sensor/GridSensorTests.cs +++ b/com.unity.ml-agents/Tests/Runtime/Sensor/GridSensorTests.cs @@ -89,7 +89,7 @@ public void TestCreateSensor() gridSensorComponent.SetComponentParameters(tags, useGridSensorBase: true); gridSensorComponent.CreateSensors(); - var componentSensor = (List)typeof(GridSensorComponent).GetField("m_Sensors", + var componentSensor = (List)typeof(GridSensorComponent).GetField("m_Sensors", BindingFlags.Instance | BindingFlags.NonPublic).GetValue(gridSensorComponent); Assert.AreEqual(componentSensor.Count, 1); } @@ -191,6 +191,17 @@ public void TestNoSensors() gridSensorComponent.CreateSensors(); }); } + + [Test] + public void TestStackedSensors() + { + testGo.tag = k_Tag2; + string[] tags = { k_Tag1, k_Tag2 }; + gridSensorComponent.SetComponentParameters(tags, useGridSensorBase: true); + gridSensorComponent.ObservationStacks = 3; + var sensors = gridSensorComponent.CreateSensors(); + Assert.IsInstanceOf(typeof(StackingSensor), sensors[0]); + } } } #endif