diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 1aa9df2545..ec43218bec 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -61,7 +61,7 @@ jobs: pip freeze > pip_versions-${{ matrix.python-version }}.txt cat pip_versions-${{ matrix.python-version }}.txt - name: Run pytest - run: pytest --cov=ml-agents --cov=ml-agents-envs --cov=gym-unity --cov-report html --junitxml=junit/test-results-${{ matrix.python-version }}.xml -p no:warnings + run: pytest --cov=ml-agents --cov=ml-agents-envs --cov=gym-unity --cov-report html --junitxml=junit/test-results-${{ matrix.python-version }}.xml -p no:warnings -v - name: Upload pytest test results uses: actions/upload-artifact@v2 with: diff --git a/com.unity.ml-agents/Editor/CameraSensorComponentEditor.cs b/com.unity.ml-agents/Editor/CameraSensorComponentEditor.cs index c6389536e8..c77cc86409 100644 --- a/com.unity.ml-agents/Editor/CameraSensorComponentEditor.cs +++ b/com.unity.ml-agents/Editor/CameraSensorComponentEditor.cs @@ -25,6 +25,7 @@ public override void OnInspectorGUI() EditorGUILayout.PropertyField(so.FindProperty("m_Height"), true); EditorGUILayout.PropertyField(so.FindProperty("m_Grayscale"), true); EditorGUILayout.PropertyField(so.FindProperty("m_ObservationStacks"), true); + EditorGUILayout.PropertyField(so.FindProperty("m_ObservationType"), true); } EditorGUI.EndDisabledGroup(); EditorGUILayout.PropertyField(so.FindProperty("m_Compression"), true); diff --git a/com.unity.ml-agents/Editor/VectorSensorComponentEditor.cs b/com.unity.ml-agents/Editor/VectorSensorComponentEditor.cs new file mode 100644 index 0000000000..ff1cfd2e78 --- /dev/null +++ b/com.unity.ml-agents/Editor/VectorSensorComponentEditor.cs @@ -0,0 +1,30 @@ +using UnityEditor; +using Unity.MLAgents.Sensors; + +namespace Unity.MLAgents.Editor +{ + [CustomEditor(typeof(VectorSensorComponent))] + [CanEditMultipleObjects] + internal class VectorSensorComponentEditor : UnityEditor.Editor + { + public override void OnInspectorGUI() + { + var so = serializedObject; + so.Update(); + + // Drawing the VectorSensorComponent + + EditorGUI.BeginDisabledGroup(!EditorUtilities.CanUpdateModelProperties()); + { + // These fields affect the sensor order or observation size, + // So can't be changed at runtime. + EditorGUILayout.PropertyField(so.FindProperty("m_SensorName"), true); + EditorGUILayout.PropertyField(so.FindProperty("m_ObservationSize"), true); + EditorGUILayout.PropertyField(so.FindProperty("m_ObservationType"), true); + } + EditorGUI.EndDisabledGroup(); + + so.ApplyModifiedProperties(); + } + } +} diff --git a/com.unity.ml-agents/Editor/VectorSensorComponentEditor.cs.meta b/com.unity.ml-agents/Editor/VectorSensorComponentEditor.cs.meta new file mode 100644 index 0000000000..9862a23944 --- /dev/null +++ b/com.unity.ml-agents/Editor/VectorSensorComponentEditor.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: aa0230c3402f04921acdbbdb61f6ff00 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Observation.cs b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Observation.cs index 55904935c8..3e23c8d991 100644 --- a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Observation.cs +++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Observation.cs @@ -35,10 +35,10 @@ static ObservationReflection() { "b25fdHlwZRgHIAEoDjIqLmNvbW11bmljYXRvcl9vYmplY3RzLk9ic2VydmF0", "aW9uVHlwZVByb3RvEgwKBG5hbWUYCCABKAkaGQoJRmxvYXREYXRhEgwKBGRh", "dGEYASADKAJCEgoQb2JzZXJ2YXRpb25fZGF0YSopChRDb21wcmVzc2lvblR5", - "cGVQcm90bxIICgROT05FEAASBwoDUE5HEAEqRgoUT2JzZXJ2YXRpb25UeXBl", - "UHJvdG8SCwoHREVGQVVMVBAAEggKBEdPQUwQARIKCgZSRVdBUkQQAhILCgdN", - "RVNTQUdFEANCJaoCIlVuaXR5Lk1MQWdlbnRzLkNvbW11bmljYXRvck9iamVj", - "dHNiBnByb3RvMw==")); + "cGVQcm90bxIICgROT05FEAASBwoDUE5HEAEqQAoUT2JzZXJ2YXRpb25UeXBl", + "UHJvdG8SCwoHREVGQVVMVBAAEg8KC0dPQUxfU0lHTkFMEAEiBAgCEAIiBAgD", + "EANCJaoCIlVuaXR5Lk1MQWdlbnRzLkNvbW11bmljYXRvck9iamVjdHNiBnBy", + "b3RvMw==")); descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, new pbr::FileDescriptor[] { }, new pbr::GeneratedClrTypeInfo(new[] {typeof(global::Unity.MLAgents.CommunicatorObjects.CompressionTypeProto), typeof(global::Unity.MLAgents.CommunicatorObjects.ObservationTypeProto), }, new pbr::GeneratedClrTypeInfo[] { @@ -56,9 +56,7 @@ internal enum CompressionTypeProto { internal enum ObservationTypeProto { [pbr::OriginalName("DEFAULT")] Default = 0, - [pbr::OriginalName("GOAL")] Goal = 1, - [pbr::OriginalName("REWARD")] Reward = 2, - [pbr::OriginalName("MESSAGE")] Message = 3, + [pbr::OriginalName("GOAL_SIGNAL")] GoalSignal = 1, } #endregion diff --git a/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs b/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs index 3cede2408d..805e7e302b 100644 --- a/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs +++ b/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs @@ -44,8 +44,9 @@ public SensorCompressionType CompressionType /// Whether to convert the generated image to grayscale or keep color. /// The name of the camera sensor. /// The compression to apply to the generated image. + /// The type of observation. public CameraSensor( - Camera camera, int width, int height, bool grayscale, string name, SensorCompressionType compression) + Camera camera, int width, int height, bool grayscale, string name, SensorCompressionType compression, ObservationType observationType = ObservationType.Default) { m_Camera = camera; m_Width = width; @@ -53,7 +54,7 @@ public CameraSensor( m_Grayscale = grayscale; m_Name = name; var channels = grayscale ? 1 : 3; - m_ObservationSpec = ObservationSpec.Visual(height, width, channels); + m_ObservationSpec = ObservationSpec.Visual(height, width, channels, observationType); m_CompressionType = compression; } diff --git a/com.unity.ml-agents/Runtime/Sensors/CameraSensorComponent.cs b/com.unity.ml-agents/Runtime/Sensors/CameraSensorComponent.cs index 41582d35c6..3df6c72764 100644 --- a/com.unity.ml-agents/Runtime/Sensors/CameraSensorComponent.cs +++ b/com.unity.ml-agents/Runtime/Sensors/CameraSensorComponent.cs @@ -75,6 +75,18 @@ public bool Grayscale set { m_Grayscale = value; } } + [HideInInspector, SerializeField] + ObservationType m_ObservationType; + + /// + /// The type of the observation. + /// + public ObservationType ObservationType + { + get { return m_ObservationType; } + set { m_ObservationType = value; UpdateSensor(); } + } + [HideInInspector, SerializeField] [Range(1, 50)] [Tooltip("Number of camera frames that will be stacked before being fed to the neural network.")] @@ -108,7 +120,7 @@ public int ObservationStacks /// The created object for this component. public override ISensor[] CreateSensors() { - m_Sensor = new CameraSensor(m_Camera, m_Width, m_Height, Grayscale, m_SensorName, m_Compression); + m_Sensor = new CameraSensor(m_Camera, m_Width, m_Height, Grayscale, m_SensorName, m_Compression, m_ObservationType); if (ObservationStacks != 1) { diff --git a/com.unity.ml-agents/Runtime/Sensors/ISensor.cs b/com.unity.ml-agents/Runtime/Sensors/ISensor.cs index 62f6f78a08..8ca07ad830 100644 --- a/com.unity.ml-agents/Runtime/Sensors/ISensor.cs +++ b/com.unity.ml-agents/Runtime/Sensors/ISensor.cs @@ -42,17 +42,7 @@ public enum ObservationType /// /// Collected observations contain goal information. /// - Goal = 1, - - /// - /// Collected observations contain reward information. - /// - Reward = 2, - - /// - /// Collected observations are messages from other agents. - /// - Message = 3, + GoalSignal = 1, } /// diff --git a/com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs b/com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs index a09d2d58a2..37eb052289 100644 --- a/com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs +++ b/com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs @@ -21,16 +21,20 @@ public class VectorSensor : ISensor, IBuiltInSensor /// /// Number of vector observations. /// Name of the sensor. - public VectorSensor(int observationSize, string name = null) + public VectorSensor(int observationSize, string name = null, ObservationType observationType = ObservationType.Default) { - if (name == null) + if (string.IsNullOrEmpty(name)) { name = $"VectorSensor_size{observationSize}"; + if (observationType != ObservationType.Default) + { + name += $"_{observationType.ToString()}"; + } } m_Observations = new List(observationSize); m_Name = name; - m_ObservationSpec = ObservationSpec.Vector(observationSize); + m_ObservationSpec = ObservationSpec.Vector(observationSize, observationType); } /// diff --git a/com.unity.ml-agents/Runtime/Sensors/VectorSensorComponent.cs b/com.unity.ml-agents/Runtime/Sensors/VectorSensorComponent.cs new file mode 100644 index 0000000000..abd8bb09f5 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/VectorSensorComponent.cs @@ -0,0 +1,68 @@ +using UnityEngine; +using UnityEngine.Serialization; + +namespace Unity.MLAgents.Sensors +{ + /// + /// A SensorComponent that creates a . + /// + [AddComponentMenu("ML Agents/Vector Sensor", (int)MenuGroup.Sensors)] + public class VectorSensorComponent : SensorComponent + { + /// + /// Name of the generated object. + /// Note that changing this at runtime does not affect how the Agent sorts the sensors. + /// + public string SensorName + { + get { return m_SensorName; } + set { m_SensorName = value; } + } + [HideInInspector, SerializeField] + private string m_SensorName = "VectorSensor"; + + /// + /// The number of float observations in the VectorSensor + /// + public int ObservationSize + { + get { return m_ObservationSize; } + set { m_ObservationSize = value; } + } + + [HideInInspector, SerializeField] + int m_ObservationSize; + + [HideInInspector, SerializeField] + ObservationType m_ObservationType; + + VectorSensor m_Sensor; + + /// + /// The type of the observation. + /// + public ObservationType ObservationType + { + get { return m_ObservationType; } + set { m_ObservationType = value; } + } + + /// + /// Creates a VectorSensor. + /// + /// + public override ISensor[] CreateSensors() + { + m_Sensor = new VectorSensor(m_ObservationSize, m_SensorName, m_ObservationType); + return new ISensor[] { m_Sensor }; + } + + /// + /// Returns the underlying VectorSensor + /// + public VectorSensor GetSensor() + { + return m_Sensor; + } + } +} diff --git a/com.unity.ml-agents/Runtime/Sensors/VectorSensorComponent.cs.meta b/com.unity.ml-agents/Runtime/Sensors/VectorSensorComponent.cs.meta new file mode 100644 index 0000000000..c867a60f2b --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/VectorSensorComponent.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 38b7cc1f5819445aa85e9a9b054552dc +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/CameraSensorTest.cs b/com.unity.ml-agents/Tests/Runtime/Sensor/CameraSensorTest.cs index 925544def3..7dc9c42fde 100644 --- a/com.unity.ml-agents/Tests/Runtime/Sensor/CameraSensorTest.cs +++ b/com.unity.ml-agents/Tests/Runtime/Sensor/CameraSensorTest.cs @@ -37,5 +37,22 @@ public void TestCameraSensor() } } } + + [Test] + public void TestObservationType() + { + var width = 24; + var height = 16; + var camera = Camera.main; + var sensor = new CameraSensor(camera, width, height, true, "TestCameraSensor", SensorCompressionType.None); + var spec = sensor.GetObservationSpec(); + Assert.AreEqual((int)spec.ObservationType, (int)ObservationType.Default); + sensor = new CameraSensor(camera, width, height, true, "TestCameraSensor", SensorCompressionType.None, ObservationType.Default); + spec = sensor.GetObservationSpec(); + Assert.AreEqual((int)spec.ObservationType, (int)ObservationType.Default); + sensor = new CameraSensor(camera, width, height, true, "TestCameraSensor", SensorCompressionType.None, ObservationType.GoalSignal); + spec = sensor.GetObservationSpec(); + Assert.AreEqual((int)spec.ObservationType, (int)ObservationType.GoalSignal); + } } } diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/VectorSensorTests.cs b/com.unity.ml-agents/Tests/Runtime/Sensor/VectorSensorTests.cs index 5326bca868..f58606e99c 100644 --- a/com.unity.ml-agents/Tests/Runtime/Sensor/VectorSensorTests.cs +++ b/com.unity.ml-agents/Tests/Runtime/Sensor/VectorSensorTests.cs @@ -42,6 +42,20 @@ public void TestAddObservationFloat() SensorTestHelper.CompareObservation(sensor, new[] { 1.2f }); } + [Test] + public void TestObservationType() + { + var sensor = new VectorSensor(1); + var spec = sensor.GetObservationSpec(); + Assert.AreEqual((int)spec.ObservationType, (int)ObservationType.Default); + sensor = new VectorSensor(1, observationType: ObservationType.Default); + spec = sensor.GetObservationSpec(); + Assert.AreEqual((int)spec.ObservationType, (int)ObservationType.Default); + sensor = new VectorSensor(1, observationType: ObservationType.GoalSignal); + spec = sensor.GetObservationSpec(); + Assert.AreEqual((int)spec.ObservationType, (int)ObservationType.GoalSignal); + } + [Test] public void TestAddObservationInt() { diff --git a/ml-agents-envs/mlagents_envs/base_env.py b/ml-agents-envs/mlagents_envs/base_env.py index d80f752b71..000cef5709 100644 --- a/ml-agents-envs/mlagents_envs/base_env.py +++ b/ml-agents-envs/mlagents_envs/base_env.py @@ -486,11 +486,7 @@ class ObservationType(Enum): # Observation information is generic. DEFAULT = 0 # Observation contains goal information for current task. - GOAL = 1 - # Observation contains reward information for current task. - REWARD = 2 - # Observation contains a message from another agent. - MESSAGE = 3 + GOAL_SIGNAL = 1 class ObservationSpec(NamedTuple): diff --git a/ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.py b/ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.py index 2acc6959db..838ca1d87d 100644 --- a/ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.py +++ b/ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.py @@ -20,7 +20,7 @@ name='mlagents_envs/communicator_objects/observation.proto', package='communicator_objects', syntax='proto3', - serialized_pb=_b('\n4mlagents_envs/communicator_objects/observation.proto\x12\x14\x63ommunicator_objects\"\x8f\x03\n\x10ObservationProto\x12\r\n\x05shape\x18\x01 \x03(\x05\x12\x44\n\x10\x63ompression_type\x18\x02 \x01(\x0e\x32*.communicator_objects.CompressionTypeProto\x12\x19\n\x0f\x63ompressed_data\x18\x03 \x01(\x0cH\x00\x12\x46\n\nfloat_data\x18\x04 \x01(\x0b\x32\x30.communicator_objects.ObservationProto.FloatDataH\x00\x12\"\n\x1a\x63ompressed_channel_mapping\x18\x05 \x03(\x05\x12\x1c\n\x14\x64imension_properties\x18\x06 \x03(\x05\x12\x44\n\x10observation_type\x18\x07 \x01(\x0e\x32*.communicator_objects.ObservationTypeProto\x12\x0c\n\x04name\x18\x08 \x01(\t\x1a\x19\n\tFloatData\x12\x0c\n\x04\x64\x61ta\x18\x01 \x03(\x02\x42\x12\n\x10observation_data*)\n\x14\x43ompressionTypeProto\x12\x08\n\x04NONE\x10\x00\x12\x07\n\x03PNG\x10\x01*F\n\x14ObservationTypeProto\x12\x0b\n\x07\x44\x45\x46\x41ULT\x10\x00\x12\x08\n\x04GOAL\x10\x01\x12\n\n\x06REWARD\x10\x02\x12\x0b\n\x07MESSAGE\x10\x03\x42%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3') + serialized_pb=_b('\n4mlagents_envs/communicator_objects/observation.proto\x12\x14\x63ommunicator_objects\"\x8f\x03\n\x10ObservationProto\x12\r\n\x05shape\x18\x01 \x03(\x05\x12\x44\n\x10\x63ompression_type\x18\x02 \x01(\x0e\x32*.communicator_objects.CompressionTypeProto\x12\x19\n\x0f\x63ompressed_data\x18\x03 \x01(\x0cH\x00\x12\x46\n\nfloat_data\x18\x04 \x01(\x0b\x32\x30.communicator_objects.ObservationProto.FloatDataH\x00\x12\"\n\x1a\x63ompressed_channel_mapping\x18\x05 \x03(\x05\x12\x1c\n\x14\x64imension_properties\x18\x06 \x03(\x05\x12\x44\n\x10observation_type\x18\x07 \x01(\x0e\x32*.communicator_objects.ObservationTypeProto\x12\x0c\n\x04name\x18\x08 \x01(\t\x1a\x19\n\tFloatData\x12\x0c\n\x04\x64\x61ta\x18\x01 \x03(\x02\x42\x12\n\x10observation_data*)\n\x14\x43ompressionTypeProto\x12\x08\n\x04NONE\x10\x00\x12\x07\n\x03PNG\x10\x01*@\n\x14ObservationTypeProto\x12\x0b\n\x07\x44\x45\x46\x41ULT\x10\x00\x12\x0f\n\x0bGOAL_SIGNAL\x10\x01\"\x04\x08\x02\x10\x02\"\x04\x08\x03\x10\x03\x42%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3') ) _COMPRESSIONTYPEPROTO = _descriptor.EnumDescriptor( @@ -57,22 +57,14 @@ options=None, type=None), _descriptor.EnumValueDescriptor( - name='GOAL', index=1, number=1, - options=None, - type=None), - _descriptor.EnumValueDescriptor( - name='REWARD', index=2, number=2, - options=None, - type=None), - _descriptor.EnumValueDescriptor( - name='MESSAGE', index=3, number=3, + name='GOAL_SIGNAL', index=1, number=1, options=None, type=None), ], containing_type=None, options=None, serialized_start=523, - serialized_end=593, + serialized_end=587, ) _sym_db.RegisterEnumDescriptor(_OBSERVATIONTYPEPROTO) @@ -80,9 +72,7 @@ NONE = 0 PNG = 1 DEFAULT = 0 -GOAL = 1 -REWARD = 2 -MESSAGE = 3 +GOAL_SIGNAL = 1 diff --git a/ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.pyi b/ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.pyi index 0afb9cd458..6427c50851 100644 --- a/ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.pyi +++ b/ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.pyi @@ -64,13 +64,9 @@ class ObservationTypeProto(builtin___int): @classmethod def items(cls) -> typing___List[typing___Tuple[builtin___str, 'ObservationTypeProto']]: ... DEFAULT = typing___cast('ObservationTypeProto', 0) - GOAL = typing___cast('ObservationTypeProto', 1) - REWARD = typing___cast('ObservationTypeProto', 2) - MESSAGE = typing___cast('ObservationTypeProto', 3) + GOAL_SIGNAL = typing___cast('ObservationTypeProto', 1) DEFAULT = typing___cast('ObservationTypeProto', 0) -GOAL = typing___cast('ObservationTypeProto', 1) -REWARD = typing___cast('ObservationTypeProto', 2) -MESSAGE = typing___cast('ObservationTypeProto', 3) +GOAL_SIGNAL = typing___cast('ObservationTypeProto', 1) class ObservationProto(google___protobuf___message___Message): DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... diff --git a/ml-agents/mlagents/trainers/tests/simple_test_envs.py b/ml-agents/mlagents/trainers/tests/simple_test_envs.py index d4c25def4d..489fecfa46 100644 --- a/ml-agents/mlagents/trainers/tests/simple_test_envs.py +++ b/ml-agents/mlagents/trainers/tests/simple_test_envs.py @@ -121,7 +121,7 @@ def _make_observation_specs(self) -> List[ObservationSpec]: obs_spec[i] = ObservationSpec( shape=obs_spec[i].shape, dimension_property=obs_spec[i].dimension_property, - observation_type=ObservationType.GOAL, + observation_type=ObservationType.GOAL_SIGNAL, name=obs_spec[i].name, ) return obs_spec diff --git a/ml-agents/mlagents/trainers/torch/networks.py b/ml-agents/mlagents/trainers/torch/networks.py index d2e9131cca..ca6c8c9481 100644 --- a/ml-agents/mlagents/trainers/torch/networks.py +++ b/ml-agents/mlagents/trainers/torch/networks.py @@ -60,7 +60,7 @@ def __init__( self._total_goal_enc_size = 0 self._goal_processor_indices: List[int] = [] for i in range(len(observation_specs)): - if observation_specs[i].observation_type == ObservationType.GOAL: + if observation_specs[i].observation_type == ObservationType.GOAL_SIGNAL: self._total_goal_enc_size += self.embedding_sizes[i] self._goal_processor_indices.append(i) diff --git a/protobuf-definitions/proto/mlagents_envs/communicator_objects/observation.proto b/protobuf-definitions/proto/mlagents_envs/communicator_objects/observation.proto index 2d9e59b9af..94b65eaa8e 100644 --- a/protobuf-definitions/proto/mlagents_envs/communicator_objects/observation.proto +++ b/protobuf-definitions/proto/mlagents_envs/communicator_objects/observation.proto @@ -10,9 +10,9 @@ enum CompressionTypeProto { enum ObservationTypeProto { DEFAULT = 0; - GOAL = 1; - REWARD = 2; - MESSAGE = 3; + GOAL_SIGNAL = 1; + reserved 2; // Reserved for potential "reward" type + reserved 3; // Reserved for potential "message" type } message ObservationProto {