diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml
index 1aa9df2545..ec43218bec 100644
--- a/.github/workflows/pytest.yml
+++ b/.github/workflows/pytest.yml
@@ -61,7 +61,7 @@ jobs:
pip freeze > pip_versions-${{ matrix.python-version }}.txt
cat pip_versions-${{ matrix.python-version }}.txt
- name: Run pytest
- run: pytest --cov=ml-agents --cov=ml-agents-envs --cov=gym-unity --cov-report html --junitxml=junit/test-results-${{ matrix.python-version }}.xml -p no:warnings
+ run: pytest --cov=ml-agents --cov=ml-agents-envs --cov=gym-unity --cov-report html --junitxml=junit/test-results-${{ matrix.python-version }}.xml -p no:warnings -v
- name: Upload pytest test results
uses: actions/upload-artifact@v2
with:
diff --git a/com.unity.ml-agents/Editor/CameraSensorComponentEditor.cs b/com.unity.ml-agents/Editor/CameraSensorComponentEditor.cs
index c6389536e8..c77cc86409 100644
--- a/com.unity.ml-agents/Editor/CameraSensorComponentEditor.cs
+++ b/com.unity.ml-agents/Editor/CameraSensorComponentEditor.cs
@@ -25,6 +25,7 @@ public override void OnInspectorGUI()
EditorGUILayout.PropertyField(so.FindProperty("m_Height"), true);
EditorGUILayout.PropertyField(so.FindProperty("m_Grayscale"), true);
EditorGUILayout.PropertyField(so.FindProperty("m_ObservationStacks"), true);
+ EditorGUILayout.PropertyField(so.FindProperty("m_ObservationType"), true);
}
EditorGUI.EndDisabledGroup();
EditorGUILayout.PropertyField(so.FindProperty("m_Compression"), true);
diff --git a/com.unity.ml-agents/Editor/VectorSensorComponentEditor.cs b/com.unity.ml-agents/Editor/VectorSensorComponentEditor.cs
new file mode 100644
index 0000000000..ff1cfd2e78
--- /dev/null
+++ b/com.unity.ml-agents/Editor/VectorSensorComponentEditor.cs
@@ -0,0 +1,30 @@
+using UnityEditor;
+using Unity.MLAgents.Sensors;
+
+namespace Unity.MLAgents.Editor
+{
+ [CustomEditor(typeof(VectorSensorComponent))]
+ [CanEditMultipleObjects]
+ internal class VectorSensorComponentEditor : UnityEditor.Editor
+ {
+ public override void OnInspectorGUI()
+ {
+ var so = serializedObject;
+ so.Update();
+
+ // Drawing the VectorSensorComponent
+
+ EditorGUI.BeginDisabledGroup(!EditorUtilities.CanUpdateModelProperties());
+ {
+ // These fields affect the sensor order or observation size,
+ // So can't be changed at runtime.
+ EditorGUILayout.PropertyField(so.FindProperty("m_SensorName"), true);
+ EditorGUILayout.PropertyField(so.FindProperty("m_ObservationSize"), true);
+ EditorGUILayout.PropertyField(so.FindProperty("m_ObservationType"), true);
+ }
+ EditorGUI.EndDisabledGroup();
+
+ so.ApplyModifiedProperties();
+ }
+ }
+}
diff --git a/com.unity.ml-agents/Editor/VectorSensorComponentEditor.cs.meta b/com.unity.ml-agents/Editor/VectorSensorComponentEditor.cs.meta
new file mode 100644
index 0000000000..9862a23944
--- /dev/null
+++ b/com.unity.ml-agents/Editor/VectorSensorComponentEditor.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: aa0230c3402f04921acdbbdb61f6ff00
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Observation.cs b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Observation.cs
index 55904935c8..3e23c8d991 100644
--- a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Observation.cs
+++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Observation.cs
@@ -35,10 +35,10 @@ static ObservationReflection() {
"b25fdHlwZRgHIAEoDjIqLmNvbW11bmljYXRvcl9vYmplY3RzLk9ic2VydmF0",
"aW9uVHlwZVByb3RvEgwKBG5hbWUYCCABKAkaGQoJRmxvYXREYXRhEgwKBGRh",
"dGEYASADKAJCEgoQb2JzZXJ2YXRpb25fZGF0YSopChRDb21wcmVzc2lvblR5",
- "cGVQcm90bxIICgROT05FEAASBwoDUE5HEAEqRgoUT2JzZXJ2YXRpb25UeXBl",
- "UHJvdG8SCwoHREVGQVVMVBAAEggKBEdPQUwQARIKCgZSRVdBUkQQAhILCgdN",
- "RVNTQUdFEANCJaoCIlVuaXR5Lk1MQWdlbnRzLkNvbW11bmljYXRvck9iamVj",
- "dHNiBnByb3RvMw=="));
+ "cGVQcm90bxIICgROT05FEAASBwoDUE5HEAEqQAoUT2JzZXJ2YXRpb25UeXBl",
+ "UHJvdG8SCwoHREVGQVVMVBAAEg8KC0dPQUxfU0lHTkFMEAEiBAgCEAIiBAgD",
+ "EANCJaoCIlVuaXR5Lk1MQWdlbnRzLkNvbW11bmljYXRvck9iamVjdHNiBnBy",
+ "b3RvMw=="));
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
new pbr::FileDescriptor[] { },
new pbr::GeneratedClrTypeInfo(new[] {typeof(global::Unity.MLAgents.CommunicatorObjects.CompressionTypeProto), typeof(global::Unity.MLAgents.CommunicatorObjects.ObservationTypeProto), }, new pbr::GeneratedClrTypeInfo[] {
@@ -56,9 +56,7 @@ internal enum CompressionTypeProto {
internal enum ObservationTypeProto {
[pbr::OriginalName("DEFAULT")] Default = 0,
- [pbr::OriginalName("GOAL")] Goal = 1,
- [pbr::OriginalName("REWARD")] Reward = 2,
- [pbr::OriginalName("MESSAGE")] Message = 3,
+ [pbr::OriginalName("GOAL_SIGNAL")] GoalSignal = 1,
}
#endregion
diff --git a/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs b/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs
index 3cede2408d..805e7e302b 100644
--- a/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs
+++ b/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs
@@ -44,8 +44,9 @@ public SensorCompressionType CompressionType
/// Whether to convert the generated image to grayscale or keep color.
/// The name of the camera sensor.
/// The compression to apply to the generated image.
+ /// The type of observation.
public CameraSensor(
- Camera camera, int width, int height, bool grayscale, string name, SensorCompressionType compression)
+ Camera camera, int width, int height, bool grayscale, string name, SensorCompressionType compression, ObservationType observationType = ObservationType.Default)
{
m_Camera = camera;
m_Width = width;
@@ -53,7 +54,7 @@ public CameraSensor(
m_Grayscale = grayscale;
m_Name = name;
var channels = grayscale ? 1 : 3;
- m_ObservationSpec = ObservationSpec.Visual(height, width, channels);
+ m_ObservationSpec = ObservationSpec.Visual(height, width, channels, observationType);
m_CompressionType = compression;
}
diff --git a/com.unity.ml-agents/Runtime/Sensors/CameraSensorComponent.cs b/com.unity.ml-agents/Runtime/Sensors/CameraSensorComponent.cs
index 41582d35c6..3df6c72764 100644
--- a/com.unity.ml-agents/Runtime/Sensors/CameraSensorComponent.cs
+++ b/com.unity.ml-agents/Runtime/Sensors/CameraSensorComponent.cs
@@ -75,6 +75,18 @@ public bool Grayscale
set { m_Grayscale = value; }
}
+ [HideInInspector, SerializeField]
+ ObservationType m_ObservationType;
+
+ ///
+ /// The type of the observation.
+ ///
+ public ObservationType ObservationType
+ {
+ get { return m_ObservationType; }
+ set { m_ObservationType = value; UpdateSensor(); }
+ }
+
[HideInInspector, SerializeField]
[Range(1, 50)]
[Tooltip("Number of camera frames that will be stacked before being fed to the neural network.")]
@@ -108,7 +120,7 @@ public int ObservationStacks
/// The created object for this component.
public override ISensor[] CreateSensors()
{
- m_Sensor = new CameraSensor(m_Camera, m_Width, m_Height, Grayscale, m_SensorName, m_Compression);
+ m_Sensor = new CameraSensor(m_Camera, m_Width, m_Height, Grayscale, m_SensorName, m_Compression, m_ObservationType);
if (ObservationStacks != 1)
{
diff --git a/com.unity.ml-agents/Runtime/Sensors/ISensor.cs b/com.unity.ml-agents/Runtime/Sensors/ISensor.cs
index 62f6f78a08..8ca07ad830 100644
--- a/com.unity.ml-agents/Runtime/Sensors/ISensor.cs
+++ b/com.unity.ml-agents/Runtime/Sensors/ISensor.cs
@@ -42,17 +42,7 @@ public enum ObservationType
///
/// Collected observations contain goal information.
///
- Goal = 1,
-
- ///
- /// Collected observations contain reward information.
- ///
- Reward = 2,
-
- ///
- /// Collected observations are messages from other agents.
- ///
- Message = 3,
+ GoalSignal = 1,
}
///
diff --git a/com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs b/com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs
index a09d2d58a2..37eb052289 100644
--- a/com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs
+++ b/com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs
@@ -21,16 +21,20 @@ public class VectorSensor : ISensor, IBuiltInSensor
///
/// Number of vector observations.
/// Name of the sensor.
- public VectorSensor(int observationSize, string name = null)
+ public VectorSensor(int observationSize, string name = null, ObservationType observationType = ObservationType.Default)
{
- if (name == null)
+ if (string.IsNullOrEmpty(name))
{
name = $"VectorSensor_size{observationSize}";
+ if (observationType != ObservationType.Default)
+ {
+ name += $"_{observationType.ToString()}";
+ }
}
m_Observations = new List(observationSize);
m_Name = name;
- m_ObservationSpec = ObservationSpec.Vector(observationSize);
+ m_ObservationSpec = ObservationSpec.Vector(observationSize, observationType);
}
///
diff --git a/com.unity.ml-agents/Runtime/Sensors/VectorSensorComponent.cs b/com.unity.ml-agents/Runtime/Sensors/VectorSensorComponent.cs
new file mode 100644
index 0000000000..abd8bb09f5
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Sensors/VectorSensorComponent.cs
@@ -0,0 +1,68 @@
+using UnityEngine;
+using UnityEngine.Serialization;
+
+namespace Unity.MLAgents.Sensors
+{
+ ///
+ /// A SensorComponent that creates a .
+ ///
+ [AddComponentMenu("ML Agents/Vector Sensor", (int)MenuGroup.Sensors)]
+ public class VectorSensorComponent : SensorComponent
+ {
+ ///
+ /// Name of the generated object.
+ /// Note that changing this at runtime does not affect how the Agent sorts the sensors.
+ ///
+ public string SensorName
+ {
+ get { return m_SensorName; }
+ set { m_SensorName = value; }
+ }
+ [HideInInspector, SerializeField]
+ private string m_SensorName = "VectorSensor";
+
+ ///
+ /// The number of float observations in the VectorSensor
+ ///
+ public int ObservationSize
+ {
+ get { return m_ObservationSize; }
+ set { m_ObservationSize = value; }
+ }
+
+ [HideInInspector, SerializeField]
+ int m_ObservationSize;
+
+ [HideInInspector, SerializeField]
+ ObservationType m_ObservationType;
+
+ VectorSensor m_Sensor;
+
+ ///
+ /// The type of the observation.
+ ///
+ public ObservationType ObservationType
+ {
+ get { return m_ObservationType; }
+ set { m_ObservationType = value; }
+ }
+
+ ///
+ /// Creates a VectorSensor.
+ ///
+ ///
+ public override ISensor[] CreateSensors()
+ {
+ m_Sensor = new VectorSensor(m_ObservationSize, m_SensorName, m_ObservationType);
+ return new ISensor[] { m_Sensor };
+ }
+
+ ///
+ /// Returns the underlying VectorSensor
+ ///
+ public VectorSensor GetSensor()
+ {
+ return m_Sensor;
+ }
+ }
+}
diff --git a/com.unity.ml-agents/Runtime/Sensors/VectorSensorComponent.cs.meta b/com.unity.ml-agents/Runtime/Sensors/VectorSensorComponent.cs.meta
new file mode 100644
index 0000000000..c867a60f2b
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Sensors/VectorSensorComponent.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 38b7cc1f5819445aa85e9a9b054552dc
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/CameraSensorTest.cs b/com.unity.ml-agents/Tests/Runtime/Sensor/CameraSensorTest.cs
index 925544def3..7dc9c42fde 100644
--- a/com.unity.ml-agents/Tests/Runtime/Sensor/CameraSensorTest.cs
+++ b/com.unity.ml-agents/Tests/Runtime/Sensor/CameraSensorTest.cs
@@ -37,5 +37,22 @@ public void TestCameraSensor()
}
}
}
+
+ [Test]
+ public void TestObservationType()
+ {
+ var width = 24;
+ var height = 16;
+ var camera = Camera.main;
+ var sensor = new CameraSensor(camera, width, height, true, "TestCameraSensor", SensorCompressionType.None);
+ var spec = sensor.GetObservationSpec();
+ Assert.AreEqual((int)spec.ObservationType, (int)ObservationType.Default);
+ sensor = new CameraSensor(camera, width, height, true, "TestCameraSensor", SensorCompressionType.None, ObservationType.Default);
+ spec = sensor.GetObservationSpec();
+ Assert.AreEqual((int)spec.ObservationType, (int)ObservationType.Default);
+ sensor = new CameraSensor(camera, width, height, true, "TestCameraSensor", SensorCompressionType.None, ObservationType.GoalSignal);
+ spec = sensor.GetObservationSpec();
+ Assert.AreEqual((int)spec.ObservationType, (int)ObservationType.GoalSignal);
+ }
}
}
diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/VectorSensorTests.cs b/com.unity.ml-agents/Tests/Runtime/Sensor/VectorSensorTests.cs
index 5326bca868..f58606e99c 100644
--- a/com.unity.ml-agents/Tests/Runtime/Sensor/VectorSensorTests.cs
+++ b/com.unity.ml-agents/Tests/Runtime/Sensor/VectorSensorTests.cs
@@ -42,6 +42,20 @@ public void TestAddObservationFloat()
SensorTestHelper.CompareObservation(sensor, new[] { 1.2f });
}
+ [Test]
+ public void TestObservationType()
+ {
+ var sensor = new VectorSensor(1);
+ var spec = sensor.GetObservationSpec();
+ Assert.AreEqual((int)spec.ObservationType, (int)ObservationType.Default);
+ sensor = new VectorSensor(1, observationType: ObservationType.Default);
+ spec = sensor.GetObservationSpec();
+ Assert.AreEqual((int)spec.ObservationType, (int)ObservationType.Default);
+ sensor = new VectorSensor(1, observationType: ObservationType.GoalSignal);
+ spec = sensor.GetObservationSpec();
+ Assert.AreEqual((int)spec.ObservationType, (int)ObservationType.GoalSignal);
+ }
+
[Test]
public void TestAddObservationInt()
{
diff --git a/ml-agents-envs/mlagents_envs/base_env.py b/ml-agents-envs/mlagents_envs/base_env.py
index d80f752b71..000cef5709 100644
--- a/ml-agents-envs/mlagents_envs/base_env.py
+++ b/ml-agents-envs/mlagents_envs/base_env.py
@@ -486,11 +486,7 @@ class ObservationType(Enum):
# Observation information is generic.
DEFAULT = 0
# Observation contains goal information for current task.
- GOAL = 1
- # Observation contains reward information for current task.
- REWARD = 2
- # Observation contains a message from another agent.
- MESSAGE = 3
+ GOAL_SIGNAL = 1
class ObservationSpec(NamedTuple):
diff --git a/ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.py b/ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.py
index 2acc6959db..838ca1d87d 100644
--- a/ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.py
+++ b/ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.py
@@ -20,7 +20,7 @@
name='mlagents_envs/communicator_objects/observation.proto',
package='communicator_objects',
syntax='proto3',
- serialized_pb=_b('\n4mlagents_envs/communicator_objects/observation.proto\x12\x14\x63ommunicator_objects\"\x8f\x03\n\x10ObservationProto\x12\r\n\x05shape\x18\x01 \x03(\x05\x12\x44\n\x10\x63ompression_type\x18\x02 \x01(\x0e\x32*.communicator_objects.CompressionTypeProto\x12\x19\n\x0f\x63ompressed_data\x18\x03 \x01(\x0cH\x00\x12\x46\n\nfloat_data\x18\x04 \x01(\x0b\x32\x30.communicator_objects.ObservationProto.FloatDataH\x00\x12\"\n\x1a\x63ompressed_channel_mapping\x18\x05 \x03(\x05\x12\x1c\n\x14\x64imension_properties\x18\x06 \x03(\x05\x12\x44\n\x10observation_type\x18\x07 \x01(\x0e\x32*.communicator_objects.ObservationTypeProto\x12\x0c\n\x04name\x18\x08 \x01(\t\x1a\x19\n\tFloatData\x12\x0c\n\x04\x64\x61ta\x18\x01 \x03(\x02\x42\x12\n\x10observation_data*)\n\x14\x43ompressionTypeProto\x12\x08\n\x04NONE\x10\x00\x12\x07\n\x03PNG\x10\x01*F\n\x14ObservationTypeProto\x12\x0b\n\x07\x44\x45\x46\x41ULT\x10\x00\x12\x08\n\x04GOAL\x10\x01\x12\n\n\x06REWARD\x10\x02\x12\x0b\n\x07MESSAGE\x10\x03\x42%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3')
+ serialized_pb=_b('\n4mlagents_envs/communicator_objects/observation.proto\x12\x14\x63ommunicator_objects\"\x8f\x03\n\x10ObservationProto\x12\r\n\x05shape\x18\x01 \x03(\x05\x12\x44\n\x10\x63ompression_type\x18\x02 \x01(\x0e\x32*.communicator_objects.CompressionTypeProto\x12\x19\n\x0f\x63ompressed_data\x18\x03 \x01(\x0cH\x00\x12\x46\n\nfloat_data\x18\x04 \x01(\x0b\x32\x30.communicator_objects.ObservationProto.FloatDataH\x00\x12\"\n\x1a\x63ompressed_channel_mapping\x18\x05 \x03(\x05\x12\x1c\n\x14\x64imension_properties\x18\x06 \x03(\x05\x12\x44\n\x10observation_type\x18\x07 \x01(\x0e\x32*.communicator_objects.ObservationTypeProto\x12\x0c\n\x04name\x18\x08 \x01(\t\x1a\x19\n\tFloatData\x12\x0c\n\x04\x64\x61ta\x18\x01 \x03(\x02\x42\x12\n\x10observation_data*)\n\x14\x43ompressionTypeProto\x12\x08\n\x04NONE\x10\x00\x12\x07\n\x03PNG\x10\x01*@\n\x14ObservationTypeProto\x12\x0b\n\x07\x44\x45\x46\x41ULT\x10\x00\x12\x0f\n\x0bGOAL_SIGNAL\x10\x01\"\x04\x08\x02\x10\x02\"\x04\x08\x03\x10\x03\x42%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3')
)
_COMPRESSIONTYPEPROTO = _descriptor.EnumDescriptor(
@@ -57,22 +57,14 @@
options=None,
type=None),
_descriptor.EnumValueDescriptor(
- name='GOAL', index=1, number=1,
- options=None,
- type=None),
- _descriptor.EnumValueDescriptor(
- name='REWARD', index=2, number=2,
- options=None,
- type=None),
- _descriptor.EnumValueDescriptor(
- name='MESSAGE', index=3, number=3,
+ name='GOAL_SIGNAL', index=1, number=1,
options=None,
type=None),
],
containing_type=None,
options=None,
serialized_start=523,
- serialized_end=593,
+ serialized_end=587,
)
_sym_db.RegisterEnumDescriptor(_OBSERVATIONTYPEPROTO)
@@ -80,9 +72,7 @@
NONE = 0
PNG = 1
DEFAULT = 0
-GOAL = 1
-REWARD = 2
-MESSAGE = 3
+GOAL_SIGNAL = 1
diff --git a/ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.pyi b/ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.pyi
index 0afb9cd458..6427c50851 100644
--- a/ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.pyi
+++ b/ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.pyi
@@ -64,13 +64,9 @@ class ObservationTypeProto(builtin___int):
@classmethod
def items(cls) -> typing___List[typing___Tuple[builtin___str, 'ObservationTypeProto']]: ...
DEFAULT = typing___cast('ObservationTypeProto', 0)
- GOAL = typing___cast('ObservationTypeProto', 1)
- REWARD = typing___cast('ObservationTypeProto', 2)
- MESSAGE = typing___cast('ObservationTypeProto', 3)
+ GOAL_SIGNAL = typing___cast('ObservationTypeProto', 1)
DEFAULT = typing___cast('ObservationTypeProto', 0)
-GOAL = typing___cast('ObservationTypeProto', 1)
-REWARD = typing___cast('ObservationTypeProto', 2)
-MESSAGE = typing___cast('ObservationTypeProto', 3)
+GOAL_SIGNAL = typing___cast('ObservationTypeProto', 1)
class ObservationProto(google___protobuf___message___Message):
DESCRIPTOR: google___protobuf___descriptor___Descriptor = ...
diff --git a/ml-agents/mlagents/trainers/tests/simple_test_envs.py b/ml-agents/mlagents/trainers/tests/simple_test_envs.py
index d4c25def4d..489fecfa46 100644
--- a/ml-agents/mlagents/trainers/tests/simple_test_envs.py
+++ b/ml-agents/mlagents/trainers/tests/simple_test_envs.py
@@ -121,7 +121,7 @@ def _make_observation_specs(self) -> List[ObservationSpec]:
obs_spec[i] = ObservationSpec(
shape=obs_spec[i].shape,
dimension_property=obs_spec[i].dimension_property,
- observation_type=ObservationType.GOAL,
+ observation_type=ObservationType.GOAL_SIGNAL,
name=obs_spec[i].name,
)
return obs_spec
diff --git a/ml-agents/mlagents/trainers/torch/networks.py b/ml-agents/mlagents/trainers/torch/networks.py
index d2e9131cca..ca6c8c9481 100644
--- a/ml-agents/mlagents/trainers/torch/networks.py
+++ b/ml-agents/mlagents/trainers/torch/networks.py
@@ -60,7 +60,7 @@ def __init__(
self._total_goal_enc_size = 0
self._goal_processor_indices: List[int] = []
for i in range(len(observation_specs)):
- if observation_specs[i].observation_type == ObservationType.GOAL:
+ if observation_specs[i].observation_type == ObservationType.GOAL_SIGNAL:
self._total_goal_enc_size += self.embedding_sizes[i]
self._goal_processor_indices.append(i)
diff --git a/protobuf-definitions/proto/mlagents_envs/communicator_objects/observation.proto b/protobuf-definitions/proto/mlagents_envs/communicator_objects/observation.proto
index 2d9e59b9af..94b65eaa8e 100644
--- a/protobuf-definitions/proto/mlagents_envs/communicator_objects/observation.proto
+++ b/protobuf-definitions/proto/mlagents_envs/communicator_objects/observation.proto
@@ -10,9 +10,9 @@ enum CompressionTypeProto {
enum ObservationTypeProto {
DEFAULT = 0;
- GOAL = 1;
- REWARD = 2;
- MESSAGE = 3;
+ GOAL_SIGNAL = 1;
+ reserved 2; // Reserved for potential "reward" type
+ reserved 3; // Reserved for potential "message" type
}
message ObservationProto {