Skip to content

Commit

Permalink
Adding the goal conditioning sensors with the new observation specs (#…
Browse files Browse the repository at this point in the history
…5159)

* Fixing networks.py for the merge

* fix compile error

* Adding the goal conditioning sensors with the new observation specs

* addressing feedback

* I forgot to change the m_observationType

* Renaming Goal to GoalSignal (#5190)

* Renaming GOAL to GOAL_SIGNAL

* VectorSensorComponent to use new API

* Adding docstrings

* verbose pytest on github action

Co-authored-by: Chris Elion <chris.elion@unity3d.com>
  • Loading branch information
vincentpierre and Chris Elion committed Mar 29, 2021
1 parent 2072dd2 commit ca0fca7
Show file tree
Hide file tree
Showing 19 changed files with 194 additions and 55 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ jobs:
pip freeze > pip_versions-${{ matrix.python-version }}.txt
cat pip_versions-${{ matrix.python-version }}.txt
- name: Run pytest
run: pytest --cov=ml-agents --cov=ml-agents-envs --cov=gym-unity --cov-report html --junitxml=junit/test-results-${{ matrix.python-version }}.xml -p no:warnings
run: pytest --cov=ml-agents --cov=ml-agents-envs --cov=gym-unity --cov-report html --junitxml=junit/test-results-${{ matrix.python-version }}.xml -p no:warnings -v
- name: Upload pytest test results
uses: actions/upload-artifact@v2
with:
Expand Down
1 change: 1 addition & 0 deletions com.unity.ml-agents/Editor/CameraSensorComponentEditor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ public override void OnInspectorGUI()
EditorGUILayout.PropertyField(so.FindProperty("m_Height"), true);
EditorGUILayout.PropertyField(so.FindProperty("m_Grayscale"), true);
EditorGUILayout.PropertyField(so.FindProperty("m_ObservationStacks"), true);
EditorGUILayout.PropertyField(so.FindProperty("m_ObservationType"), true);
}
EditorGUI.EndDisabledGroup();
EditorGUILayout.PropertyField(so.FindProperty("m_Compression"), true);
Expand Down
30 changes: 30 additions & 0 deletions com.unity.ml-agents/Editor/VectorSensorComponentEditor.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
using UnityEditor;
using Unity.MLAgents.Sensors;

namespace Unity.MLAgents.Editor
{
[CustomEditor(typeof(VectorSensorComponent))]
[CanEditMultipleObjects]
internal class VectorSensorComponentEditor : UnityEditor.Editor
{
public override void OnInspectorGUI()
{
var so = serializedObject;
so.Update();

// Drawing the VectorSensorComponent

EditorGUI.BeginDisabledGroup(!EditorUtilities.CanUpdateModelProperties());
{
// These fields affect the sensor order or observation size,
// So can't be changed at runtime.
EditorGUILayout.PropertyField(so.FindProperty("m_SensorName"), true);
EditorGUILayout.PropertyField(so.FindProperty("m_ObservationSize"), true);
EditorGUILayout.PropertyField(so.FindProperty("m_ObservationType"), true);
}
EditorGUI.EndDisabledGroup();

so.ApplyModifiedProperties();
}
}
}
11 changes: 11 additions & 0 deletions com.unity.ml-agents/Editor/VectorSensorComponentEditor.cs.meta

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ static ObservationReflection() {
"b25fdHlwZRgHIAEoDjIqLmNvbW11bmljYXRvcl9vYmplY3RzLk9ic2VydmF0",
"aW9uVHlwZVByb3RvEgwKBG5hbWUYCCABKAkaGQoJRmxvYXREYXRhEgwKBGRh",
"dGEYASADKAJCEgoQb2JzZXJ2YXRpb25fZGF0YSopChRDb21wcmVzc2lvblR5",
"cGVQcm90bxIICgROT05FEAASBwoDUE5HEAEqRgoUT2JzZXJ2YXRpb25UeXBl",
"UHJvdG8SCwoHREVGQVVMVBAAEggKBEdPQUwQARIKCgZSRVdBUkQQAhILCgdN",
"RVNTQUdFEANCJaoCIlVuaXR5Lk1MQWdlbnRzLkNvbW11bmljYXRvck9iamVj",
"dHNiBnByb3RvMw=="));
"cGVQcm90bxIICgROT05FEAASBwoDUE5HEAEqQAoUT2JzZXJ2YXRpb25UeXBl",
"UHJvdG8SCwoHREVGQVVMVBAAEg8KC0dPQUxfU0lHTkFMEAEiBAgCEAIiBAgD",
"EANCJaoCIlVuaXR5Lk1MQWdlbnRzLkNvbW11bmljYXRvck9iamVjdHNiBnBy",
"b3RvMw=="));
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
new pbr::FileDescriptor[] { },
new pbr::GeneratedClrTypeInfo(new[] {typeof(global::Unity.MLAgents.CommunicatorObjects.CompressionTypeProto), typeof(global::Unity.MLAgents.CommunicatorObjects.ObservationTypeProto), }, new pbr::GeneratedClrTypeInfo[] {
Expand All @@ -56,9 +56,7 @@ internal enum CompressionTypeProto {

internal enum ObservationTypeProto {
[pbr::OriginalName("DEFAULT")] Default = 0,
[pbr::OriginalName("GOAL")] Goal = 1,
[pbr::OriginalName("REWARD")] Reward = 2,
[pbr::OriginalName("MESSAGE")] Message = 3,
[pbr::OriginalName("GOAL_SIGNAL")] GoalSignal = 1,
}

#endregion
Expand Down
5 changes: 3 additions & 2 deletions com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,17 @@ public SensorCompressionType CompressionType
/// <param name="grayscale">Whether to convert the generated image to grayscale or keep color.</param>
/// <param name="name">The name of the camera sensor.</param>
/// <param name="compression">The compression to apply to the generated image.</param>
/// <param name="observationType">The type of observation.</param>
public CameraSensor(
Camera camera, int width, int height, bool grayscale, string name, SensorCompressionType compression)
Camera camera, int width, int height, bool grayscale, string name, SensorCompressionType compression, ObservationType observationType = ObservationType.Default)
{
m_Camera = camera;
m_Width = width;
m_Height = height;
m_Grayscale = grayscale;
m_Name = name;
var channels = grayscale ? 1 : 3;
m_ObservationSpec = ObservationSpec.Visual(height, width, channels);
m_ObservationSpec = ObservationSpec.Visual(height, width, channels, observationType);
m_CompressionType = compression;
}

Expand Down
14 changes: 13 additions & 1 deletion com.unity.ml-agents/Runtime/Sensors/CameraSensorComponent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,18 @@ public bool Grayscale
set { m_Grayscale = value; }
}

[HideInInspector, SerializeField]
ObservationType m_ObservationType;

/// <summary>
/// The type of the observation.
/// </summary>
public ObservationType ObservationType
{
get { return m_ObservationType; }
set { m_ObservationType = value; UpdateSensor(); }
}

[HideInInspector, SerializeField]
[Range(1, 50)]
[Tooltip("Number of camera frames that will be stacked before being fed to the neural network.")]
Expand Down Expand Up @@ -108,7 +120,7 @@ public int ObservationStacks
/// <returns>The created <see cref="CameraSensor"/> object for this component.</returns>
public override ISensor[] CreateSensors()
{
m_Sensor = new CameraSensor(m_Camera, m_Width, m_Height, Grayscale, m_SensorName, m_Compression);
m_Sensor = new CameraSensor(m_Camera, m_Width, m_Height, Grayscale, m_SensorName, m_Compression, m_ObservationType);

if (ObservationStacks != 1)
{
Expand Down
12 changes: 1 addition & 11 deletions com.unity.ml-agents/Runtime/Sensors/ISensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,17 +42,7 @@ public enum ObservationType
/// <summary>
/// Collected observations contain goal information.
/// </summary>
Goal = 1,

/// <summary>
/// Collected observations contain reward information.
/// </summary>
Reward = 2,

/// <summary>
/// Collected observations are messages from other agents.
/// </summary>
Message = 3,
GoalSignal = 1,
}

/// <summary>
Expand Down
10 changes: 7 additions & 3 deletions com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,20 @@ public class VectorSensor : ISensor, IBuiltInSensor
/// </summary>
/// <param name="observationSize">Number of vector observations.</param>
/// <param name="name">Name of the sensor.</param>
public VectorSensor(int observationSize, string name = null)
public VectorSensor(int observationSize, string name = null, ObservationType observationType = ObservationType.Default)
{
if (name == null)
if (string.IsNullOrEmpty(name))
{
name = $"VectorSensor_size{observationSize}";
if (observationType != ObservationType.Default)
{
name += $"_{observationType.ToString()}";
}
}

m_Observations = new List<float>(observationSize);
m_Name = name;
m_ObservationSpec = ObservationSpec.Vector(observationSize);
m_ObservationSpec = ObservationSpec.Vector(observationSize, observationType);
}

/// <inheritdoc/>
Expand Down
68 changes: 68 additions & 0 deletions com.unity.ml-agents/Runtime/Sensors/VectorSensorComponent.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
using UnityEngine;
using UnityEngine.Serialization;

namespace Unity.MLAgents.Sensors
{
/// <summary>
/// A SensorComponent that creates a <see cref="VectorSensor"/>.
/// </summary>
[AddComponentMenu("ML Agents/Vector Sensor", (int)MenuGroup.Sensors)]
public class VectorSensorComponent : SensorComponent
{
/// <summary>
/// Name of the generated <see cref="VectorSensor"/> object.
/// Note that changing this at runtime does not affect how the Agent sorts the sensors.
/// </summary>
public string SensorName
{
get { return m_SensorName; }
set { m_SensorName = value; }
}
[HideInInspector, SerializeField]
private string m_SensorName = "VectorSensor";

/// <summary>
/// The number of float observations in the VectorSensor
/// </summary>
public int ObservationSize
{
get { return m_ObservationSize; }
set { m_ObservationSize = value; }
}

[HideInInspector, SerializeField]
int m_ObservationSize;

[HideInInspector, SerializeField]
ObservationType m_ObservationType;

VectorSensor m_Sensor;

/// <summary>
/// The type of the observation.
/// </summary>
public ObservationType ObservationType
{
get { return m_ObservationType; }
set { m_ObservationType = value; }
}

/// <summary>
/// Creates a VectorSensor.
/// </summary>
/// <returns></returns>
public override ISensor[] CreateSensors()
{
m_Sensor = new VectorSensor(m_ObservationSize, m_SensorName, m_ObservationType);
return new ISensor[] { m_Sensor };
}

/// <summary>
/// Returns the underlying VectorSensor
/// </summary>
public VectorSensor GetSensor()
{
return m_Sensor;
}
}
}
11 changes: 11 additions & 0 deletions com.unity.ml-agents/Runtime/Sensors/VectorSensorComponent.cs.meta

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 17 additions & 0 deletions com.unity.ml-agents/Tests/Runtime/Sensor/CameraSensorTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,22 @@ public void TestCameraSensor()
}
}
}

[Test]
public void TestObservationType()
{
var width = 24;
var height = 16;
var camera = Camera.main;
var sensor = new CameraSensor(camera, width, height, true, "TestCameraSensor", SensorCompressionType.None);
var spec = sensor.GetObservationSpec();
Assert.AreEqual((int)spec.ObservationType, (int)ObservationType.Default);
sensor = new CameraSensor(camera, width, height, true, "TestCameraSensor", SensorCompressionType.None, ObservationType.Default);
spec = sensor.GetObservationSpec();
Assert.AreEqual((int)spec.ObservationType, (int)ObservationType.Default);
sensor = new CameraSensor(camera, width, height, true, "TestCameraSensor", SensorCompressionType.None, ObservationType.GoalSignal);
spec = sensor.GetObservationSpec();
Assert.AreEqual((int)spec.ObservationType, (int)ObservationType.GoalSignal);
}
}
}
14 changes: 14 additions & 0 deletions com.unity.ml-agents/Tests/Runtime/Sensor/VectorSensorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,20 @@ public void TestAddObservationFloat()
SensorTestHelper.CompareObservation(sensor, new[] { 1.2f });
}

[Test]
public void TestObservationType()
{
var sensor = new VectorSensor(1);
var spec = sensor.GetObservationSpec();
Assert.AreEqual((int)spec.ObservationType, (int)ObservationType.Default);
sensor = new VectorSensor(1, observationType: ObservationType.Default);
spec = sensor.GetObservationSpec();
Assert.AreEqual((int)spec.ObservationType, (int)ObservationType.Default);
sensor = new VectorSensor(1, observationType: ObservationType.GoalSignal);
spec = sensor.GetObservationSpec();
Assert.AreEqual((int)spec.ObservationType, (int)ObservationType.GoalSignal);
}

[Test]
public void TestAddObservationInt()
{
Expand Down
6 changes: 1 addition & 5 deletions ml-agents-envs/mlagents_envs/base_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,11 +486,7 @@ class ObservationType(Enum):
# Observation information is generic.
DEFAULT = 0
# Observation contains goal information for current task.
GOAL = 1
# Observation contains reward information for current task.
REWARD = 2
# Observation contains a message from another agent.
MESSAGE = 3
GOAL_SIGNAL = 1


class ObservationSpec(NamedTuple):
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,9 @@ class ObservationTypeProto(builtin___int):
@classmethod
def items(cls) -> typing___List[typing___Tuple[builtin___str, 'ObservationTypeProto']]: ...
DEFAULT = typing___cast('ObservationTypeProto', 0)
GOAL = typing___cast('ObservationTypeProto', 1)
REWARD = typing___cast('ObservationTypeProto', 2)
MESSAGE = typing___cast('ObservationTypeProto', 3)
GOAL_SIGNAL = typing___cast('ObservationTypeProto', 1)
DEFAULT = typing___cast('ObservationTypeProto', 0)
GOAL = typing___cast('ObservationTypeProto', 1)
REWARD = typing___cast('ObservationTypeProto', 2)
MESSAGE = typing___cast('ObservationTypeProto', 3)
GOAL_SIGNAL = typing___cast('ObservationTypeProto', 1)

class ObservationProto(google___protobuf___message___Message):
DESCRIPTOR: google___protobuf___descriptor___Descriptor = ...
Expand Down
2 changes: 1 addition & 1 deletion ml-agents/mlagents/trainers/tests/simple_test_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def _make_observation_specs(self) -> List[ObservationSpec]:
obs_spec[i] = ObservationSpec(
shape=obs_spec[i].shape,
dimension_property=obs_spec[i].dimension_property,
observation_type=ObservationType.GOAL,
observation_type=ObservationType.GOAL_SIGNAL,
name=obs_spec[i].name,
)
return obs_spec
Expand Down
2 changes: 1 addition & 1 deletion ml-agents/mlagents/trainers/torch/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __init__(
self._total_goal_enc_size = 0
self._goal_processor_indices: List[int] = []
for i in range(len(observation_specs)):
if observation_specs[i].observation_type == ObservationType.GOAL:
if observation_specs[i].observation_type == ObservationType.GOAL_SIGNAL:
self._total_goal_enc_size += self.embedding_sizes[i]
self._goal_processor_indices.append(i)

Expand Down
Loading

0 comments on commit ca0fca7

Please sign in to comment.