Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLA-12] update protobuf for vector observations #2862

Merged
merged 21 commits into from
Nov 7, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions UnitySDK/Assets/ML-Agents/Editor/Tests/DemonstrationTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ public void TestStoreInitalize()
done = true,
id = 5,
maxStepReached = true,
floatObservations = new List<float>() { 1f, 1f, 1f },
storedVectorActions = new[] { 0f, 1f },
};

Expand Down Expand Up @@ -120,13 +119,17 @@ public void TestAgentWrite()
BrainParametersProto.Parser.ParseDelimitedFrom(reader);

var agentInfoProto = AgentInfoProto.Parser.ParseDelimitedFrom(reader);
var obs = agentInfoProto.StackedVectorObservation;
Assert.AreEqual(obs.Count, bpA.brainParameters.vectorObservationSize);
for (var i = 0; i < obs.Count; i++)
var obs = agentInfoProto.Observations[2]; // skip dummy sensors
{
Assert.AreEqual((float) i+1, obs[i]);
var vecObs = obs.FloatData.Data;
Assert.AreEqual(bpA.brainParameters.vectorObservationSize, vecObs.Count);
for (var i = 0; i < vecObs.Count; i++)
{
Assert.AreEqual((float) i+1, vecObs[i]);
}
}


}
}
}
64 changes: 45 additions & 19 deletions UnitySDK/Assets/ML-Agents/Scripts/Agent.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using System;
using System.Collections.Generic;
using UnityEngine;
using Barracuda;
Expand All @@ -13,12 +14,9 @@ namespace MLAgents
public struct AgentInfo
{
/// <summary>
/// Most recent compressed observations.
/// Most recent observations.
/// </summary>
public List<CompressedObservation> compressedObservations;

// TODO struct?
public List<float> floatObservations;
public List<Observation> observations;

/// <summary>
/// Keeps track of the last vector action taken by the Brain.
Expand Down Expand Up @@ -229,11 +227,23 @@ public AgentInfo Info
/// </summary>
DemonstrationRecorder m_Recorder;

/// <summary>
/// List of sensors used to generate observations.
/// Currently generated from attached SensorComponents, and a legacy VectorSensor
/// </summary>
[FormerlySerializedAs("m_Sensors")]
public List<ISensor> sensors;

/// <summary>
/// VectorSensor which is written to by AddVectorObs
/// </summary>
public VectorSensor collectObservationsSensor;

/// <summary>
/// Internal buffer used for generating float observations.
/// </summary>
float[] m_VectorSensorBuffer;
chriselion marked this conversation as resolved.
Show resolved Hide resolved

WriteAdapter m_WriteAdapter = new WriteAdapter();

/// MonoBehaviour function that is called when the attached GameObject
Expand Down Expand Up @@ -447,11 +457,7 @@ void ResetData()
}
}

m_Info.compressedObservations = new List<CompressedObservation>();
m_Info.floatObservations = new List<float>();
m_Info.floatObservations.AddRange(
new float[param.vectorObservationSize
* param.numStackedVectorObservations]);
m_Info.observations = new List<Observation>();
}

/// <summary>
Expand Down Expand Up @@ -523,6 +529,17 @@ public void InitializeSensors()
Debug.Assert(!sensors[i].GetName().Equals(sensors[i + 1].GetName()), "Sensor names must be unique.");
}
#endif
// Create a buffer for writing vector sensor data too
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might need to discuss alternatives to this. We need to have somewhere for non-compressed sensors to write before passing this to the proto objects; this seemed like a good way to prevent repeated allocations. Another option would be to have the Communicator have a cache of Observations (indexed by size), and it manages the Write() calls before converting to proto.

int numFloatObservations = 0;
for (var i = 0; i < sensors.Count; i++)
{
if (sensors[i].GetCompressionType() == SensorCompressionType.None)
{
numFloatObservations += sensors[i].ObservationSize();
}
}

m_VectorSensorBuffer = new float[numFloatObservations];
}

/// <summary>
Expand All @@ -536,7 +553,7 @@ void SendInfoToBrain()
}

m_Info.storedVectorActions = m_Action.vectorActions;
m_Info.compressedObservations.Clear();
m_Info.observations.Clear();
m_ActionMasker.ResetMask();
UpdateSensors();
using (TimerStack.Instance.Scoped("CollectObservations"))
Expand All @@ -556,9 +573,9 @@ void SendInfoToBrain()

if (m_Recorder != null && m_Recorder.record && Application.isEditor)
{
// This is a bit of a hack - if we're in inference mode, compressed observations won't be generated
// This is a bit of a hack - if we're in inference mode, observations won't be generated
// But we need these to be generated for the recorder. So generate them here.
if (m_Info.compressedObservations.Count == 0)
if (m_Info.observations.Count == 0)
{
GenerateSensorData();
}
Expand All @@ -584,26 +601,35 @@ void UpdateSensors()
/// </summary>
public void GenerateSensorData()
{

int floatsWritten = 0;
// Generate data for all Sensors
for (var i = 0; i < sensors.Count; i++)
{
var sensor = sensors[i];
if (sensor.GetCompressionType() == SensorCompressionType.None)
{
m_WriteAdapter.SetTarget(m_Info.floatObservations, floatsWritten);
floatsWritten += sensor.Write(m_WriteAdapter);
// only handles 1D
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's currently an assumption that compressed observation <=> 3 dimensional. We should (later) add an option to write visual obs as floats, just to test all the code paths (and maybe a compressed vector obs too).

// TODO handle in communicator code instead
m_WriteAdapter.SetTarget(m_VectorSensorBuffer, floatsWritten);
var numFloats = sensor.Write(m_WriteAdapter);
var floatObs = new Observation
{
FloatData = new ArraySegment<float>(m_VectorSensorBuffer, floatsWritten, numFloats),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cool

Shape = sensor.GetFloatObservationShape(),
CompressionType = sensor.GetCompressionType()
};
m_Info.observations.Add(floatObs);
floatsWritten += numFloats;
}
else
{
var compressedObs = new CompressedObservation
var compressedObs = new Observation
{
Data = sensor.GetCompressedObservation(),
CompressedData = sensor.GetCompressedObservation(),
Shape = sensor.GetFloatObservationShape(),
CompressionType = sensor.GetCompressionType()
};
m_Info.compressedObservations.Add(compressedObs);
m_Info.observations.Add(compressedObs);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,19 @@ static AgentInfoReflection() {
byte[] descriptorData = global::System.Convert.FromBase64String(
string.Concat(
"CjNtbGFnZW50cy9lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2FnZW50X2lu",
"Zm8ucHJvdG8SFGNvbW11bmljYXRvcl9vYmplY3RzGj9tbGFnZW50cy9lbnZz",
"L2NvbW11bmljYXRvcl9vYmplY3RzL2NvbXByZXNzZWRfb2JzZXJ2YXRpb24u",
"cHJvdG8inQIKDkFnZW50SW5mb1Byb3RvEiIKGnN0YWNrZWRfdmVjdG9yX29i",
"c2VydmF0aW9uGAEgAygCEh0KFXN0b3JlZF92ZWN0b3JfYWN0aW9ucxgEIAMo",
"AhIOCgZyZXdhcmQYByABKAISDAoEZG9uZRgIIAEoCBIYChBtYXhfc3RlcF9y",
"ZWFjaGVkGAkgASgIEgoKAmlkGAogASgFEhMKC2FjdGlvbl9tYXNrGAsgAygI",
"ElEKF2NvbXByZXNzZWRfb2JzZXJ2YXRpb25zGA0gAygLMjAuY29tbXVuaWNh",
"dG9yX29iamVjdHMuQ29tcHJlc3NlZE9ic2VydmF0aW9uUHJvdG9KBAgCEANK",
"BAgDEARKBAgFEAZKBAgGEAdKBAgMEA1CH6oCHE1MQWdlbnRzLkNvbW11bmlj",
"YXRvck9iamVjdHNiBnByb3RvMw=="));
"Zm8ucHJvdG8SFGNvbW11bmljYXRvcl9vYmplY3RzGjRtbGFnZW50cy9lbnZz",
"L2NvbW11bmljYXRvcl9vYmplY3RzL29ic2VydmF0aW9uLnByb3RvIuoBCg5B",
"Z2VudEluZm9Qcm90bxIdChVzdG9yZWRfdmVjdG9yX2FjdGlvbnMYBCADKAIS",
"DgoGcmV3YXJkGAcgASgCEgwKBGRvbmUYCCABKAgSGAoQbWF4X3N0ZXBfcmVh",
"Y2hlZBgJIAEoCBIKCgJpZBgKIAEoBRITCgthY3Rpb25fbWFzaxgLIAMoCBI8",
"CgxvYnNlcnZhdGlvbnMYDSADKAsyJi5jb21tdW5pY2F0b3Jfb2JqZWN0cy5P",
"YnNlcnZhdGlvblByb3RvSgQIARACSgQIAhADSgQIAxAESgQIBRAGSgQIBhAH",
"SgQIDBANQh+qAhxNTEFnZW50cy5Db21tdW5pY2F0b3JPYmplY3RzYgZwcm90",
"bzM="));
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
new pbr::FileDescriptor[] { global::MLAgents.CommunicatorObjects.CompressedObservationReflection.Descriptor, },
new pbr::FileDescriptor[] { global::MLAgents.CommunicatorObjects.ObservationReflection.Descriptor, },
new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.AgentInfoProto), global::MLAgents.CommunicatorObjects.AgentInfoProto.Parser, new[]{ "StackedVectorObservation", "StoredVectorActions", "Reward", "Done", "MaxStepReached", "Id", "ActionMask", "CompressedObservations" }, null, null, null)
new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.AgentInfoProto), global::MLAgents.CommunicatorObjects.AgentInfoProto.Parser, new[]{ "StoredVectorActions", "Reward", "Done", "MaxStepReached", "Id", "ActionMask", "Observations" }, null, null, null)
}));
}
#endregion
Expand Down Expand Up @@ -70,14 +69,13 @@ public AgentInfoProto() {

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public AgentInfoProto(AgentInfoProto other) : this() {
stackedVectorObservation_ = other.stackedVectorObservation_.Clone();
storedVectorActions_ = other.storedVectorActions_.Clone();
reward_ = other.reward_;
done_ = other.done_;
maxStepReached_ = other.maxStepReached_;
id_ = other.id_;
actionMask_ = other.actionMask_.Clone();
compressedObservations_ = other.compressedObservations_.Clone();
observations_ = other.observations_.Clone();
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}

Expand All @@ -86,16 +84,6 @@ public AgentInfoProto Clone() {
return new AgentInfoProto(this);
}

/// <summary>Field number for the "stacked_vector_observation" field.</summary>
public const int StackedVectorObservationFieldNumber = 1;
private static readonly pb::FieldCodec<float> _repeated_stackedVectorObservation_codec
= pb::FieldCodec.ForFloat(10);
private readonly pbc::RepeatedField<float> stackedVectorObservation_ = new pbc::RepeatedField<float>();
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public pbc::RepeatedField<float> StackedVectorObservation {
get { return stackedVectorObservation_; }
}

/// <summary>Field number for the "stored_vector_actions" field.</summary>
public const int StoredVectorActionsFieldNumber = 4;
private static readonly pb::FieldCodec<float> _repeated_storedVectorActions_codec
Expand Down Expand Up @@ -160,14 +148,14 @@ public int Id {
get { return actionMask_; }
}

/// <summary>Field number for the "compressed_observations" field.</summary>
public const int CompressedObservationsFieldNumber = 13;
private static readonly pb::FieldCodec<global::MLAgents.CommunicatorObjects.CompressedObservationProto> _repeated_compressedObservations_codec
= pb::FieldCodec.ForMessage(106, global::MLAgents.CommunicatorObjects.CompressedObservationProto.Parser);
private readonly pbc::RepeatedField<global::MLAgents.CommunicatorObjects.CompressedObservationProto> compressedObservations_ = new pbc::RepeatedField<global::MLAgents.CommunicatorObjects.CompressedObservationProto>();
/// <summary>Field number for the "observations" field.</summary>
public const int ObservationsFieldNumber = 13;
private static readonly pb::FieldCodec<global::MLAgents.CommunicatorObjects.ObservationProto> _repeated_observations_codec
= pb::FieldCodec.ForMessage(106, global::MLAgents.CommunicatorObjects.ObservationProto.Parser);
private readonly pbc::RepeatedField<global::MLAgents.CommunicatorObjects.ObservationProto> observations_ = new pbc::RepeatedField<global::MLAgents.CommunicatorObjects.ObservationProto>();
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public pbc::RepeatedField<global::MLAgents.CommunicatorObjects.CompressedObservationProto> CompressedObservations {
get { return compressedObservations_; }
public pbc::RepeatedField<global::MLAgents.CommunicatorObjects.ObservationProto> Observations {
get { return observations_; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
Expand All @@ -183,28 +171,26 @@ public bool Equals(AgentInfoProto other) {
if (ReferenceEquals(other, this)) {
return true;
}
if(!stackedVectorObservation_.Equals(other.stackedVectorObservation_)) return false;
if(!storedVectorActions_.Equals(other.storedVectorActions_)) return false;
if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(Reward, other.Reward)) return false;
if (Done != other.Done) return false;
if (MaxStepReached != other.MaxStepReached) return false;
if (Id != other.Id) return false;
if(!actionMask_.Equals(other.actionMask_)) return false;
if(!compressedObservations_.Equals(other.compressedObservations_)) return false;
if(!observations_.Equals(other.observations_)) return false;
return Equals(_unknownFields, other._unknownFields);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override int GetHashCode() {
int hash = 1;
hash ^= stackedVectorObservation_.GetHashCode();
hash ^= storedVectorActions_.GetHashCode();
if (Reward != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(Reward);
if (Done != false) hash ^= Done.GetHashCode();
if (MaxStepReached != false) hash ^= MaxStepReached.GetHashCode();
if (Id != 0) hash ^= Id.GetHashCode();
hash ^= actionMask_.GetHashCode();
hash ^= compressedObservations_.GetHashCode();
hash ^= observations_.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}
Expand All @@ -218,7 +204,6 @@ public override string ToString() {

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void WriteTo(pb::CodedOutputStream output) {
stackedVectorObservation_.WriteTo(output, _repeated_stackedVectorObservation_codec);
storedVectorActions_.WriteTo(output, _repeated_storedVectorActions_codec);
if (Reward != 0F) {
output.WriteRawTag(61);
Expand All @@ -237,7 +222,7 @@ public void WriteTo(pb::CodedOutputStream output) {
output.WriteInt32(Id);
}
actionMask_.WriteTo(output, _repeated_actionMask_codec);
compressedObservations_.WriteTo(output, _repeated_compressedObservations_codec);
observations_.WriteTo(output, _repeated_observations_codec);
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
Expand All @@ -246,7 +231,6 @@ public void WriteTo(pb::CodedOutputStream output) {
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public int CalculateSize() {
int size = 0;
size += stackedVectorObservation_.CalculateSize(_repeated_stackedVectorObservation_codec);
size += storedVectorActions_.CalculateSize(_repeated_storedVectorActions_codec);
if (Reward != 0F) {
size += 1 + 4;
Expand All @@ -261,7 +245,7 @@ public int CalculateSize() {
size += 1 + pb::CodedOutputStream.ComputeInt32Size(Id);
}
size += actionMask_.CalculateSize(_repeated_actionMask_codec);
size += compressedObservations_.CalculateSize(_repeated_compressedObservations_codec);
size += observations_.CalculateSize(_repeated_observations_codec);
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}
Expand All @@ -273,7 +257,6 @@ public void MergeFrom(AgentInfoProto other) {
if (other == null) {
return;
}
stackedVectorObservation_.Add(other.stackedVectorObservation_);
storedVectorActions_.Add(other.storedVectorActions_);
if (other.Reward != 0F) {
Reward = other.Reward;
Expand All @@ -288,7 +271,7 @@ public void MergeFrom(AgentInfoProto other) {
Id = other.Id;
}
actionMask_.Add(other.actionMask_);
compressedObservations_.Add(other.compressedObservations_);
observations_.Add(other.observations_);
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}

Expand All @@ -300,11 +283,6 @@ public void MergeFrom(pb::CodedInputStream input) {
default:
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
break;
case 10:
case 13: {
stackedVectorObservation_.AddEntriesFrom(input, _repeated_stackedVectorObservation_codec);
break;
}
case 34:
case 37: {
storedVectorActions_.AddEntriesFrom(input, _repeated_storedVectorActions_codec);
Expand Down Expand Up @@ -332,7 +310,7 @@ public void MergeFrom(pb::CodedInputStream input) {
break;
}
case 106: {
compressedObservations_.AddEntriesFrom(input, _repeated_compressedObservations_codec);
observations_.AddEntriesFrom(input, _repeated_observations_codec);
break;
}
}
Expand Down
Loading