Skip to content

Commit

Permalink
[MLA-12] update protobuf for vector observations (#2862)
Browse files Browse the repository at this point in the history
  • Loading branch information
Chris Elion authored Nov 7, 2019
1 parent 51032d3 commit 720679a
Show file tree
Hide file tree
Showing 52 changed files with 983 additions and 677 deletions.
13 changes: 8 additions & 5 deletions UnitySDK/Assets/ML-Agents/Editor/Tests/DemonstrationTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ public void TestStoreInitalize()
done = true,
id = 5,
maxStepReached = true,
floatObservations = new List<float>() { 1f, 1f, 1f },
storedVectorActions = new[] { 0f, 1f },
};

Expand Down Expand Up @@ -120,13 +119,17 @@ public void TestAgentWrite()
BrainParametersProto.Parser.ParseDelimitedFrom(reader);

var agentInfoProto = AgentInfoProto.Parser.ParseDelimitedFrom(reader);
var obs = agentInfoProto.StackedVectorObservation;
Assert.AreEqual(obs.Count, bpA.brainParameters.vectorObservationSize);
for (var i = 0; i < obs.Count; i++)
var obs = agentInfoProto.Observations[2]; // skip dummy sensors
{
Assert.AreEqual((float) i+1, obs[i]);
var vecObs = obs.FloatData.Data;
Assert.AreEqual(bpA.brainParameters.vectorObservationSize, vecObs.Count);
for (var i = 0; i < vecObs.Count; i++)
{
Assert.AreEqual((float) i+1, vecObs[i]);
}
}


}
}
}
64 changes: 45 additions & 19 deletions UnitySDK/Assets/ML-Agents/Scripts/Agent.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using System;
using System.Collections.Generic;
using UnityEngine;
using Barracuda;
Expand All @@ -13,12 +14,9 @@ namespace MLAgents
public struct AgentInfo
{
/// <summary>
/// Most recent compressed observations.
/// Most recent observations.
/// </summary>
public List<CompressedObservation> compressedObservations;

// TODO struct?
public List<float> floatObservations;
public List<Observation> observations;

/// <summary>
/// Keeps track of the last vector action taken by the Brain.
Expand Down Expand Up @@ -229,11 +227,23 @@ public AgentInfo Info
/// </summary>
DemonstrationRecorder m_Recorder;

/// <summary>
/// List of sensors used to generate observations.
/// Currently generated from attached SensorComponents, and a legacy VectorSensor
/// </summary>
[FormerlySerializedAs("m_Sensors")]
public List<ISensor> sensors;

/// <summary>
/// VectorSensor which is written to by AddVectorObs
/// </summary>
public VectorSensor collectObservationsSensor;

/// <summary>
/// Internal buffer used for generating float observations.
/// </summary>
float[] m_VectorSensorBuffer;

WriteAdapter m_WriteAdapter = new WriteAdapter();

/// MonoBehaviour function that is called when the attached GameObject
Expand Down Expand Up @@ -447,11 +457,7 @@ void ResetData()
}
}

m_Info.compressedObservations = new List<CompressedObservation>();
m_Info.floatObservations = new List<float>();
m_Info.floatObservations.AddRange(
new float[param.vectorObservationSize
* param.numStackedVectorObservations]);
m_Info.observations = new List<Observation>();
}

/// <summary>
Expand Down Expand Up @@ -523,6 +529,17 @@ public void InitializeSensors()
Debug.Assert(!sensors[i].GetName().Equals(sensors[i + 1].GetName()), "Sensor names must be unique.");
}
#endif
// Create a buffer for writing vector sensor data too
int numFloatObservations = 0;
for (var i = 0; i < sensors.Count; i++)
{
if (sensors[i].GetCompressionType() == SensorCompressionType.None)
{
numFloatObservations += sensors[i].ObservationSize();
}
}

m_VectorSensorBuffer = new float[numFloatObservations];
}

/// <summary>
Expand All @@ -536,7 +553,7 @@ void SendInfoToBrain()
}

m_Info.storedVectorActions = m_Action.vectorActions;
m_Info.compressedObservations.Clear();
m_Info.observations.Clear();
m_ActionMasker.ResetMask();
UpdateSensors();
using (TimerStack.Instance.Scoped("CollectObservations"))
Expand All @@ -556,9 +573,9 @@ void SendInfoToBrain()

if (m_Recorder != null && m_Recorder.record && Application.isEditor)
{
// This is a bit of a hack - if we're in inference mode, compressed observations won't be generated
// This is a bit of a hack - if we're in inference mode, observations won't be generated
// But we need these to be generated for the recorder. So generate them here.
if (m_Info.compressedObservations.Count == 0)
if (m_Info.observations.Count == 0)
{
GenerateSensorData();
}
Expand All @@ -584,26 +601,35 @@ void UpdateSensors()
/// </summary>
public void GenerateSensorData()
{

int floatsWritten = 0;
// Generate data for all Sensors
for (var i = 0; i < sensors.Count; i++)
{
var sensor = sensors[i];
if (sensor.GetCompressionType() == SensorCompressionType.None)
{
m_WriteAdapter.SetTarget(m_Info.floatObservations, floatsWritten);
floatsWritten += sensor.Write(m_WriteAdapter);
// only handles 1D
// TODO handle in communicator code instead
m_WriteAdapter.SetTarget(m_VectorSensorBuffer, floatsWritten);
var numFloats = sensor.Write(m_WriteAdapter);
var floatObs = new Observation
{
FloatData = new ArraySegment<float>(m_VectorSensorBuffer, floatsWritten, numFloats),
Shape = sensor.GetFloatObservationShape(),
CompressionType = sensor.GetCompressionType()
};
m_Info.observations.Add(floatObs);
floatsWritten += numFloats;
}
else
{
var compressedObs = new CompressedObservation
var compressedObs = new Observation
{
Data = sensor.GetCompressedObservation(),
CompressedData = sensor.GetCompressedObservation(),
Shape = sensor.GetFloatObservationShape(),
CompressionType = sensor.GetCompressionType()
};
m_Info.compressedObservations.Add(compressedObs);
m_Info.observations.Add(compressedObs);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,19 @@ static AgentInfoReflection() {
byte[] descriptorData = global::System.Convert.FromBase64String(
string.Concat(
"CjNtbGFnZW50cy9lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2FnZW50X2lu",
"Zm8ucHJvdG8SFGNvbW11bmljYXRvcl9vYmplY3RzGj9tbGFnZW50cy9lbnZz",
"L2NvbW11bmljYXRvcl9vYmplY3RzL2NvbXByZXNzZWRfb2JzZXJ2YXRpb24u",
"cHJvdG8inQIKDkFnZW50SW5mb1Byb3RvEiIKGnN0YWNrZWRfdmVjdG9yX29i",
"c2VydmF0aW9uGAEgAygCEh0KFXN0b3JlZF92ZWN0b3JfYWN0aW9ucxgEIAMo",
"AhIOCgZyZXdhcmQYByABKAISDAoEZG9uZRgIIAEoCBIYChBtYXhfc3RlcF9y",
"ZWFjaGVkGAkgASgIEgoKAmlkGAogASgFEhMKC2FjdGlvbl9tYXNrGAsgAygI",
"ElEKF2NvbXByZXNzZWRfb2JzZXJ2YXRpb25zGA0gAygLMjAuY29tbXVuaWNh",
"dG9yX29iamVjdHMuQ29tcHJlc3NlZE9ic2VydmF0aW9uUHJvdG9KBAgCEANK",
"BAgDEARKBAgFEAZKBAgGEAdKBAgMEA1CH6oCHE1MQWdlbnRzLkNvbW11bmlj",
"YXRvck9iamVjdHNiBnByb3RvMw=="));
"Zm8ucHJvdG8SFGNvbW11bmljYXRvcl9vYmplY3RzGjRtbGFnZW50cy9lbnZz",
"L2NvbW11bmljYXRvcl9vYmplY3RzL29ic2VydmF0aW9uLnByb3RvIuoBCg5B",
"Z2VudEluZm9Qcm90bxIdChVzdG9yZWRfdmVjdG9yX2FjdGlvbnMYBCADKAIS",
"DgoGcmV3YXJkGAcgASgCEgwKBGRvbmUYCCABKAgSGAoQbWF4X3N0ZXBfcmVh",
"Y2hlZBgJIAEoCBIKCgJpZBgKIAEoBRITCgthY3Rpb25fbWFzaxgLIAMoCBI8",
"CgxvYnNlcnZhdGlvbnMYDSADKAsyJi5jb21tdW5pY2F0b3Jfb2JqZWN0cy5P",
"YnNlcnZhdGlvblByb3RvSgQIARACSgQIAhADSgQIAxAESgQIBRAGSgQIBhAH",
"SgQIDBANQh+qAhxNTEFnZW50cy5Db21tdW5pY2F0b3JPYmplY3RzYgZwcm90",
"bzM="));
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
new pbr::FileDescriptor[] { global::MLAgents.CommunicatorObjects.CompressedObservationReflection.Descriptor, },
new pbr::FileDescriptor[] { global::MLAgents.CommunicatorObjects.ObservationReflection.Descriptor, },
new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.AgentInfoProto), global::MLAgents.CommunicatorObjects.AgentInfoProto.Parser, new[]{ "StackedVectorObservation", "StoredVectorActions", "Reward", "Done", "MaxStepReached", "Id", "ActionMask", "CompressedObservations" }, null, null, null)
new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.AgentInfoProto), global::MLAgents.CommunicatorObjects.AgentInfoProto.Parser, new[]{ "StoredVectorActions", "Reward", "Done", "MaxStepReached", "Id", "ActionMask", "Observations" }, null, null, null)
}));
}
#endregion
Expand Down Expand Up @@ -70,14 +69,13 @@ public AgentInfoProto() {

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public AgentInfoProto(AgentInfoProto other) : this() {
stackedVectorObservation_ = other.stackedVectorObservation_.Clone();
storedVectorActions_ = other.storedVectorActions_.Clone();
reward_ = other.reward_;
done_ = other.done_;
maxStepReached_ = other.maxStepReached_;
id_ = other.id_;
actionMask_ = other.actionMask_.Clone();
compressedObservations_ = other.compressedObservations_.Clone();
observations_ = other.observations_.Clone();
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}

Expand All @@ -86,16 +84,6 @@ public AgentInfoProto Clone() {
return new AgentInfoProto(this);
}

/// <summary>Field number for the "stacked_vector_observation" field.</summary>
public const int StackedVectorObservationFieldNumber = 1;
private static readonly pb::FieldCodec<float> _repeated_stackedVectorObservation_codec
= pb::FieldCodec.ForFloat(10);
private readonly pbc::RepeatedField<float> stackedVectorObservation_ = new pbc::RepeatedField<float>();
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public pbc::RepeatedField<float> StackedVectorObservation {
get { return stackedVectorObservation_; }
}

/// <summary>Field number for the "stored_vector_actions" field.</summary>
public const int StoredVectorActionsFieldNumber = 4;
private static readonly pb::FieldCodec<float> _repeated_storedVectorActions_codec
Expand Down Expand Up @@ -160,14 +148,14 @@ public int Id {
get { return actionMask_; }
}

/// <summary>Field number for the "compressed_observations" field.</summary>
public const int CompressedObservationsFieldNumber = 13;
private static readonly pb::FieldCodec<global::MLAgents.CommunicatorObjects.CompressedObservationProto> _repeated_compressedObservations_codec
= pb::FieldCodec.ForMessage(106, global::MLAgents.CommunicatorObjects.CompressedObservationProto.Parser);
private readonly pbc::RepeatedField<global::MLAgents.CommunicatorObjects.CompressedObservationProto> compressedObservations_ = new pbc::RepeatedField<global::MLAgents.CommunicatorObjects.CompressedObservationProto>();
/// <summary>Field number for the "observations" field.</summary>
public const int ObservationsFieldNumber = 13;
private static readonly pb::FieldCodec<global::MLAgents.CommunicatorObjects.ObservationProto> _repeated_observations_codec
= pb::FieldCodec.ForMessage(106, global::MLAgents.CommunicatorObjects.ObservationProto.Parser);
private readonly pbc::RepeatedField<global::MLAgents.CommunicatorObjects.ObservationProto> observations_ = new pbc::RepeatedField<global::MLAgents.CommunicatorObjects.ObservationProto>();
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public pbc::RepeatedField<global::MLAgents.CommunicatorObjects.CompressedObservationProto> CompressedObservations {
get { return compressedObservations_; }
public pbc::RepeatedField<global::MLAgents.CommunicatorObjects.ObservationProto> Observations {
get { return observations_; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
Expand All @@ -183,28 +171,26 @@ public bool Equals(AgentInfoProto other) {
if (ReferenceEquals(other, this)) {
return true;
}
if(!stackedVectorObservation_.Equals(other.stackedVectorObservation_)) return false;
if(!storedVectorActions_.Equals(other.storedVectorActions_)) return false;
if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(Reward, other.Reward)) return false;
if (Done != other.Done) return false;
if (MaxStepReached != other.MaxStepReached) return false;
if (Id != other.Id) return false;
if(!actionMask_.Equals(other.actionMask_)) return false;
if(!compressedObservations_.Equals(other.compressedObservations_)) return false;
if(!observations_.Equals(other.observations_)) return false;
return Equals(_unknownFields, other._unknownFields);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override int GetHashCode() {
int hash = 1;
hash ^= stackedVectorObservation_.GetHashCode();
hash ^= storedVectorActions_.GetHashCode();
if (Reward != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(Reward);
if (Done != false) hash ^= Done.GetHashCode();
if (MaxStepReached != false) hash ^= MaxStepReached.GetHashCode();
if (Id != 0) hash ^= Id.GetHashCode();
hash ^= actionMask_.GetHashCode();
hash ^= compressedObservations_.GetHashCode();
hash ^= observations_.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}
Expand All @@ -218,7 +204,6 @@ public override string ToString() {

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void WriteTo(pb::CodedOutputStream output) {
stackedVectorObservation_.WriteTo(output, _repeated_stackedVectorObservation_codec);
storedVectorActions_.WriteTo(output, _repeated_storedVectorActions_codec);
if (Reward != 0F) {
output.WriteRawTag(61);
Expand All @@ -237,7 +222,7 @@ public void WriteTo(pb::CodedOutputStream output) {
output.WriteInt32(Id);
}
actionMask_.WriteTo(output, _repeated_actionMask_codec);
compressedObservations_.WriteTo(output, _repeated_compressedObservations_codec);
observations_.WriteTo(output, _repeated_observations_codec);
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
Expand All @@ -246,7 +231,6 @@ public void WriteTo(pb::CodedOutputStream output) {
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public int CalculateSize() {
int size = 0;
size += stackedVectorObservation_.CalculateSize(_repeated_stackedVectorObservation_codec);
size += storedVectorActions_.CalculateSize(_repeated_storedVectorActions_codec);
if (Reward != 0F) {
size += 1 + 4;
Expand All @@ -261,7 +245,7 @@ public int CalculateSize() {
size += 1 + pb::CodedOutputStream.ComputeInt32Size(Id);
}
size += actionMask_.CalculateSize(_repeated_actionMask_codec);
size += compressedObservations_.CalculateSize(_repeated_compressedObservations_codec);
size += observations_.CalculateSize(_repeated_observations_codec);
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}
Expand All @@ -273,7 +257,6 @@ public void MergeFrom(AgentInfoProto other) {
if (other == null) {
return;
}
stackedVectorObservation_.Add(other.stackedVectorObservation_);
storedVectorActions_.Add(other.storedVectorActions_);
if (other.Reward != 0F) {
Reward = other.Reward;
Expand All @@ -288,7 +271,7 @@ public void MergeFrom(AgentInfoProto other) {
Id = other.Id;
}
actionMask_.Add(other.actionMask_);
compressedObservations_.Add(other.compressedObservations_);
observations_.Add(other.observations_);
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}

Expand All @@ -300,11 +283,6 @@ public void MergeFrom(pb::CodedInputStream input) {
default:
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
break;
case 10:
case 13: {
stackedVectorObservation_.AddEntriesFrom(input, _repeated_stackedVectorObservation_codec);
break;
}
case 34:
case 37: {
storedVectorActions_.AddEntriesFrom(input, _repeated_storedVectorActions_codec);
Expand Down Expand Up @@ -332,7 +310,7 @@ public void MergeFrom(pb::CodedInputStream input) {
break;
}
case 106: {
compressedObservations_.AddEntriesFrom(input, _repeated_compressedObservations_codec);
observations_.AddEntriesFrom(input, _repeated_observations_codec);
break;
}
}
Expand Down
Loading

0 comments on commit 720679a

Please sign in to comment.