Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding the goal conditioning sensors with the new observation specs #5159

Merged
merged 12 commits into from
Mar 29, 2021
1 change: 1 addition & 0 deletions com.unity.ml-agents/Editor/CameraSensorComponentEditor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ public override void OnInspectorGUI()
EditorGUILayout.PropertyField(so.FindProperty("m_Height"), true);
EditorGUILayout.PropertyField(so.FindProperty("m_Grayscale"), true);
EditorGUILayout.PropertyField(so.FindProperty("m_ObservationStacks"), true);
EditorGUILayout.PropertyField(so.FindProperty("m_ObservationType"), true);
}
EditorGUI.EndDisabledGroup();
EditorGUILayout.PropertyField(so.FindProperty("m_Compression"), true);
Expand Down
30 changes: 30 additions & 0 deletions com.unity.ml-agents/Editor/VectorSensorComponentEditor.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
using UnityEditor;
using Unity.MLAgents.Sensors;

namespace Unity.MLAgents.Editor
{
[CustomEditor(typeof(VectorSensorComponent))]
[CanEditMultipleObjects]
internal class VectorSensorComponentEditor : UnityEditor.Editor
{
public override void OnInspectorGUI()
{
var so = serializedObject;
so.Update();

// Drawing the VectorSensorComponent

EditorGUI.BeginDisabledGroup(!EditorUtilities.CanUpdateModelProperties());
{
// These fields affect the sensor order or observation size,
// So can't be changed at runtime.
EditorGUILayout.PropertyField(so.FindProperty("m_SensorName"), true);
EditorGUILayout.PropertyField(so.FindProperty("m_observationSize"), true);
EditorGUILayout.PropertyField(so.FindProperty("m_ObservationType"), true);
}
EditorGUI.EndDisabledGroup();

so.ApplyModifiedProperties();
}
}
}
11 changes: 11 additions & 0 deletions com.unity.ml-agents/Editor/VectorSensorComponentEditor.cs.meta

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,17 @@ public SensorCompressionType CompressionType
/// <param name="grayscale">Whether to convert the generated image to grayscale or keep color.</param>
/// <param name="name">The name of the camera sensor.</param>
/// <param name="compression">The compression to apply to the generated image.</param>
/// <param name="observationType">The type of observation.</param>
public CameraSensor(
Camera camera, int width, int height, bool grayscale, string name, SensorCompressionType compression)
Camera camera, int width, int height, bool grayscale, string name, SensorCompressionType compression, ObservationType observationType = ObservationType.Default)
{
m_Camera = camera;
m_Width = width;
m_Height = height;
m_Grayscale = grayscale;
m_Name = name;
var channels = grayscale ? 1 : 3;
m_ObservationSpec = ObservationSpec.Visual(height, width, channels);
m_ObservationSpec = ObservationSpec.Visual(height, width, channels, observationType);
m_CompressionType = compression;
}

Expand Down
14 changes: 13 additions & 1 deletion com.unity.ml-agents/Runtime/Sensors/CameraSensorComponent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,18 @@ public bool Grayscale
set { m_Grayscale = value; }
}

[HideInInspector, SerializeField]
ObservationType m_ObservationType;

/// <summary>
/// The type of the observation.
/// </summary>
public ObservationType SensorObservationType
{
get { return m_ObservationType; }
set { m_ObservationType = value; UpdateSensor(); }
}

[HideInInspector, SerializeField]
[Range(1, 50)]
[Tooltip("Number of camera frames that will be stacked before being fed to the neural network.")]
Expand Down Expand Up @@ -108,7 +120,7 @@ public int ObservationStacks
/// <returns>The created <see cref="CameraSensor"/> object for this component.</returns>
public override ISensor CreateSensor()
{
m_Sensor = new CameraSensor(m_Camera, m_Width, m_Height, Grayscale, m_SensorName, m_Compression);
m_Sensor = new CameraSensor(m_Camera, m_Width, m_Height, Grayscale, m_SensorName, m_Compression, m_ObservationType);

if (ObservationStacks != 1)
{
Expand Down
10 changes: 0 additions & 10 deletions com.unity.ml-agents/Runtime/Sensors/ISensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,6 @@ public enum ObservationType
/// Collected observations contain goal information.
/// </summary>
Goal = 1,

chriselion marked this conversation as resolved.
Show resolved Hide resolved
/// <summary>
/// Collected observations contain reward information.
/// </summary>
Reward = 2,

/// <summary>
/// Collected observations are messages from other agents.
/// </summary>
Message = 3,
}

/// <summary>
Expand Down
10 changes: 7 additions & 3 deletions com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,20 @@ public class VectorSensor : ISensor, IBuiltInSensor
/// </summary>
/// <param name="observationSize">Number of vector observations.</param>
/// <param name="name">Name of the sensor.</param>
public VectorSensor(int observationSize, string name = null)
public VectorSensor(int observationSize, string name = null, ObservationType observationType = ObservationType.Default)
{
if (name == null)
if (name == null || name == "")
vincentpierre marked this conversation as resolved.
Show resolved Hide resolved
{
name = $"VectorSensor_size{observationSize}";
if (observationType != ObservationType.Default)
{
name += "_goal";
chriselion marked this conversation as resolved.
Show resolved Hide resolved
}
}

m_Observations = new List<float>(observationSize);
m_Name = name;
m_ObservationSpec = ObservationSpec.Vector(observationSize);
m_ObservationSpec = ObservationSpec.Vector(observationSize, observationType);
}

/// <inheritdoc/>
Expand Down
62 changes: 62 additions & 0 deletions com.unity.ml-agents/Runtime/Sensors/VectorSensorComponent.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
using UnityEngine;
using UnityEngine.Serialization;

namespace Unity.MLAgents.Sensors
{
[AddComponentMenu("ML Agents/Vector Sensor", (int)MenuGroup.Sensors)]
public class VectorSensorComponent : SensorComponent
vincentpierre marked this conversation as resolved.
Show resolved Hide resolved
{
/// <summary>
/// Name of the generated <see cref="VectorSensor"/> object.
/// Note that changing this at runtime does not affect how the Agent sorts the sensors.
/// </summary>
public string SensorName
{
get { return m_SensorName; }
set { m_SensorName = value; }
}
[HideInInspector, SerializeField]
private string m_SensorName = "VectorSensor";

public int ObservationSize
{
get { return m_observationSize; }
set { m_observationSize = value; }
}

[HideInInspector, SerializeField]
int m_observationSize;
chriselion marked this conversation as resolved.
Show resolved Hide resolved

[HideInInspector, SerializeField]
ObservationType m_ObservationType;

VectorSensor m_sensor;
chriselion marked this conversation as resolved.
Show resolved Hide resolved

public ObservationType ObservationType
{
get { return m_ObservationType; }
set { m_ObservationType = value; }
}

/// <summary>
/// Creates a VectorSensor.
/// </summary>
/// <returns></returns>
public override ISensor CreateSensor()
{
m_sensor = new VectorSensor(m_observationSize, m_SensorName, m_ObservationType);
return m_sensor;
}

/// <inheritdoc/>
public override int[] GetObservationShape()
{
return new[] { m_observationSize };
}

public VectorSensor GetSensor()
{
return m_sensor;
}
}
}
11 changes: 11 additions & 0 deletions com.unity.ml-agents/Runtime/Sensors/VectorSensorComponent.cs.meta

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 17 additions & 0 deletions com.unity.ml-agents/Tests/Runtime/Sensor/CameraSensorTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,22 @@ public void TestCameraSensor()
}
}
}

[Test]
public void TestObservationType()
{
var width = 24;
var height = 16;
var camera = Camera.main;
var sensor = new CameraSensor(camera, width, height, true, "TestCameraSensor", SensorCompressionType.None);
var spec = sensor.GetObservationSpec();
Assert.AreEqual((int)spec.ObservationType, (int)ObservationType.Default);
sensor = new CameraSensor(camera, width, height, true, "TestCameraSensor", SensorCompressionType.None, ObservationType.Default);
spec = sensor.GetObservationSpec();
Assert.AreEqual((int)spec.ObservationType, (int)ObservationType.Default);
sensor = new CameraSensor(camera, width, height, true, "TestCameraSensor", SensorCompressionType.None, ObservationType.Goal);
spec = sensor.GetObservationSpec();
Assert.AreEqual((int)spec.ObservationType, (int)ObservationType.Goal);
}
}
}
14 changes: 14 additions & 0 deletions com.unity.ml-agents/Tests/Runtime/Sensor/VectorSensorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,20 @@ public void TestAddObservationFloat()
SensorTestHelper.CompareObservation(sensor, new[] { 1.2f });
}

[Test]
public void TestObservationType()
{
var sensor = new VectorSensor(1);
var spec = sensor.GetObservationSpec();
Assert.AreEqual((int)spec.ObservationType, (int)ObservationType.Default);
sensor = new VectorSensor(1, observationType: ObservationType.Default);
spec = sensor.GetObservationSpec();
Assert.AreEqual((int)spec.ObservationType, (int)ObservationType.Default);
sensor = new VectorSensor(1, observationType: ObservationType.Goal);
spec = sensor.GetObservationSpec();
Assert.AreEqual((int)spec.ObservationType, (int)ObservationType.Goal);
}

[Test]
public void TestAddObservationInt()
{
Expand Down
4 changes: 0 additions & 4 deletions ml-agents-envs/mlagents_envs/base_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,10 +487,6 @@ class ObservationType(Enum):
DEFAULT = 0
# Observation contains goal information for current task.
GOAL = 1
# Observation contains reward information for current task.
REWARD = 2
# Observation contains a message from another agent.
MESSAGE = 3


class ObservationSpec(NamedTuple):
Expand Down