Skip to content

Commit

Permalink
Develop remove memories (#2795)
Browse files Browse the repository at this point in the history
* Initial commit removing memories from C# and deprecating memory fields in proto

* initial changes to Python

* Adding functionalities

* Fixes

* adding the memories to the dictionary

* Fixing bugs

* tweeks

* Resolving bugs

* Recreating the proto

* Addressing comments

* Passing by reference does not work. Do not merge

* Fixing huge bug in Inference

* Applying patches

* fixing tests

* Addressing comments

* Renaming variable to reflect type

* test
  • Loading branch information
vincentpierre authored Oct 25, 2019
1 parent 7e80f7c commit 58cee7e
Show file tree
Hide file tree
Showing 34 changed files with 182 additions and 447 deletions.
11 changes: 5 additions & 6 deletions UnitySDK/Assets/ML-Agents/Editor/Tests/DemonstrationTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ public void TestStoreInitalize()
{
vectorObservationSize = 3,
numStackedVectorObservations = 2,
vectorActionDescriptions = new[] {"TestActionA", "TestActionB"},
vectorActionSize = new[] {2, 2},
vectorActionDescriptions = new[] { "TestActionA", "TestActionB" },
vectorActionSize = new[] { 2, 2 },
vectorActionSpaceType = SpaceType.Discrete
};

Expand All @@ -46,14 +46,13 @@ public void TestStoreInitalize()
var agentInfo = new AgentInfo
{
reward = 1f,
actionMasks = new[] {false, true},
actionMasks = new[] { false, true },
done = true,
id = 5,
maxStepReached = true,
memories = new List<float>(),
stackedVectorObservation = new List<float>() {1f, 1f, 1f},
stackedVectorObservation = new List<float>() { 1f, 1f, 1f },
storedTextActions = "TestAction",
storedVectorActions = new[] {0f, 1f},
storedVectorActions = new[] { 0f, 1f },
textObservation = "TestAction",
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ private class TestAgent : Agent
{
public AgentAction GetAction()
{
var f = typeof(Agent).GetField(
var f = typeof(Agent).GetField(
"m_Action", BindingFlags.Instance | BindingFlags.NonPublic);
return (AgentAction)f.GetValue(this);
}
Expand All @@ -26,15 +26,16 @@ private List<Agent> GetFakeAgentInfos()
var goB = new GameObject("goB");
var agentB = goB.AddComponent<TestAgent>();

return new List<Agent> {agentA, agentB};
return new List<Agent> { agentA, agentB };
}

[Test]
public void Construction()
{
var bp = new BrainParameters();
var alloc = new TensorCachingAllocator();
var tensorGenerator = new TensorApplier(bp, 0, alloc);
var mem = new Dictionary<int, List<float>>();
var tensorGenerator = new TensorApplier(bp, 0, alloc, mem);
Assert.IsNotNull(tensorGenerator);
alloc.Dispose();
}
Expand All @@ -44,8 +45,8 @@ public void ApplyContinuousActionOutput()
{
var inputTensor = new TensorProxy()
{
shape = new long[] {2, 3},
data = new Tensor(2, 3, new float[] {1, 2, 3, 4, 5, 6})
shape = new long[] { 2, 3 },
data = new Tensor(2, 3, new float[] { 1, 2, 3, 4, 5, 6 })
};
var agentInfos = GetFakeAgentInfos();

Expand Down Expand Up @@ -73,15 +74,15 @@ public void ApplyDiscreteActionOutput()
{
var inputTensor = new TensorProxy()
{
shape = new long[] {2, 5},
shape = new long[] { 2, 5 },
data = new Tensor(
2,
5,
new[] {0.5f, 22.5f, 0.1f, 5f, 1f, 4f, 5f, 6f, 7f, 8f})
new[] { 0.5f, 22.5f, 0.1f, 5f, 1f, 4f, 5f, 6f, 7f, 8f })
};
var agentInfos = GetFakeAgentInfos();
var alloc = new TensorCachingAllocator();
var applier = new DiscreteActionOutputApplier(new[] {2, 3}, 0, alloc);
var applier = new DiscreteActionOutputApplier(new[] { 2, 3 }, 0, alloc);
applier.Apply(inputTensor, agentInfos);
var agents = agentInfos;

Expand All @@ -99,43 +100,13 @@ public void ApplyDiscreteActionOutput()
alloc.Dispose();
}

[Test]
public void ApplyMemoryOutput()
{
var inputTensor = new TensorProxy()
{
shape = new long[] {2, 5},
data = new Tensor(
2,
5,
new[] {0.5f, 22.5f, 0.1f, 5f, 1f, 4f, 5f, 6f, 7f, 8f})
};
var agentInfos = GetFakeAgentInfos();

var applier = new MemoryOutputApplier();
applier.Apply(inputTensor, agentInfos);
var agents = agentInfos;

var agent = agents[0] as TestAgent;
Assert.NotNull(agent);
var action = agent.GetAction();
Assert.AreEqual(action.memories[0], 0.5f);
Assert.AreEqual(action.memories[1], 22.5f);

agent = agents[1] as TestAgent;
Assert.NotNull(agent);
action = agent.GetAction();
Assert.AreEqual(action.memories[2], 6);
Assert.AreEqual(action.memories[3], 7);
}

[Test]
public void ApplyValueEstimate()
{
var inputTensor = new TensorProxy()
{
shape = new long[] {2, 1},
data = new Tensor(2, 1, new[] {0.5f, 8f})
shape = new long[] { 2, 1 },
data = new Tensor(2, 1, new[] { 0.5f, 8f })
};
var agentInfos = GetFakeAgentInfos();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ private static IEnumerable<Agent> GetFakeAgentInfos()
var infoA = new AgentInfo
{
stackedVectorObservation = new[] { 1f, 2f, 3f }.ToList(),
memories = null,
storedVectorActions = new[] { 1f, 2f },
actionMasks = null
};
Expand All @@ -25,7 +24,6 @@ private static IEnumerable<Agent> GetFakeAgentInfos()
var infoB = new AgentInfo
{
stackedVectorObservation = new[] { 4f, 5f, 6f }.ToList(),
memories = new[] { 1f, 1f, 1f }.ToList(),
storedVectorActions = new[] { 3f, 4f },
actionMasks = new[] { true, false, false, false, false },
};
Expand All @@ -40,7 +38,8 @@ public void Construction()
{
var bp = new BrainParameters();
var alloc = new TensorCachingAllocator();
var tensorGenerator = new TensorGenerator(bp, 0, alloc);
var mem = new Dictionary<int, List<float>>();
var tensorGenerator = new TensorGenerator(bp, 0, alloc, mem);
Assert.IsNotNull(tensorGenerator);
alloc.Dispose();
}
Expand Down Expand Up @@ -91,26 +90,6 @@ public void GenerateVectorObservation()
alloc.Dispose();
}

[Test]
public void GenerateRecurrentInput()
{
var inputTensor = new TensorProxy
{
shape = new long[] { 2, 5 }
};
const int batchSize = 4;
var agentInfos = GetFakeAgentInfos();
var alloc = new TensorCachingAllocator();
var generator = new RecurrentInputGenerator(alloc);
generator.Generate(inputTensor, batchSize, agentInfos);
Assert.IsNotNull(inputTensor.data);
Assert.AreEqual(inputTensor.data[0, 0], 0);
Assert.AreEqual(inputTensor.data[0, 4], 0);
Assert.AreEqual(inputTensor.data[1, 0], 1);
Assert.AreEqual(inputTensor.data[1, 4], 0);
alloc.Dispose();
}

[Test]
public void GeneratePreviousActionInput()
{
Expand Down
31 changes: 0 additions & 31 deletions UnitySDK/Assets/ML-Agents/Scripts/Agent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,6 @@ public struct AgentInfo
/// </summary>
public bool[] actionMasks;

/// <summary>
/// Used by the Trainer to store information about the agent. This data
/// structure is not consumed or modified by the agent directly, they are
/// just the owners of their trainier's memory. Currently, however, the
/// size of the memory is in the Brain properties.
/// </summary>
public List<float> memories;

/// <summary>
/// Current agent reward.
/// </summary>
Expand Down Expand Up @@ -96,7 +88,6 @@ public struct AgentAction
{
public float[] vectorActions;
public string textActions;
public List<float> memories;
public float value;
/// TODO(cgoy): All references to protobuf objects should be removed.
public CommunicatorObjects.CustomActionProto customAction;
Expand Down Expand Up @@ -484,8 +475,6 @@ void ResetData()
if (m_Info.textObservation == null)
m_Info.textObservation = "";
m_Action.textActions = "";
m_Info.memories = new List<float>();
m_Action.memories = new List<float>();
m_Info.vectorObservation =
new List<float>(param.vectorObservationSize);
m_Info.stackedVectorObservation =
Expand Down Expand Up @@ -563,7 +552,6 @@ void SendInfoToBrain()
return;
}

m_Info.memories = m_Action.memories;
m_Info.storedVectorActions = m_Action.vectorActions;
m_Info.storedTextActions = m_Action.textActions;
m_Info.vectorObservation.Clear();
Expand Down Expand Up @@ -902,25 +890,6 @@ public void UpdateVectorAction(float[] vectorActions)
m_Action.vectorActions = vectorActions;
}

/// <summary>
/// Updates the memories action.
/// </summary>
/// <param name="memories">Memories.</param>
public void UpdateMemoriesAction(List<float> memories)
{
m_Action.memories = memories;
}

public void AppendMemoriesAction(List<float> memories)
{
m_Action.memories.AddRange(memories);
}

public List<float> GetMemoriesAction()
{
return m_Action.memories;
}

/// <summary>
/// Updates the value of the agent.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,16 @@ static AgentActionReflection() {
string.Concat(
"CjVtbGFnZW50cy9lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2FnZW50X2Fj",
"dGlvbi5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMaNm1sYWdlbnRzL2Vu",
"dnMvY29tbXVuaWNhdG9yX29iamVjdHMvY3VzdG9tX2FjdGlvbi5wcm90byKh",
"dnMvY29tbXVuaWNhdG9yX29iamVjdHMvY3VzdG9tX2FjdGlvbi5wcm90byKV",
"AQoQQWdlbnRBY3Rpb25Qcm90bxIWCg52ZWN0b3JfYWN0aW9ucxgBIAMoAhIU",
"Cgx0ZXh0X2FjdGlvbnMYAiABKAkSEAoIbWVtb3JpZXMYAyADKAISDQoFdmFs",
"dWUYBCABKAISPgoNY3VzdG9tX2FjdGlvbhgFIAEoCzInLmNvbW11bmljYXRv",
"cl9vYmplY3RzLkN1c3RvbUFjdGlvblByb3RvQh+qAhxNTEFnZW50cy5Db21t",
"dW5pY2F0b3JPYmplY3RzYgZwcm90bzM="));
"Cgx0ZXh0X2FjdGlvbnMYAiABKAkSDQoFdmFsdWUYBCABKAISPgoNY3VzdG9t",
"X2FjdGlvbhgFIAEoCzInLmNvbW11bmljYXRvcl9vYmplY3RzLkN1c3RvbUFj",
"dGlvblByb3RvSgQIAxAEQh+qAhxNTEFnZW50cy5Db21tdW5pY2F0b3JPYmpl",
"Y3RzYgZwcm90bzM="));
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
new pbr::FileDescriptor[] { global::MLAgents.CommunicatorObjects.CustomActionReflection.Descriptor, },
new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.AgentActionProto), global::MLAgents.CommunicatorObjects.AgentActionProto.Parser, new[]{ "VectorActions", "TextActions", "Memories", "Value", "CustomAction" }, null, null, null)
new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.AgentActionProto), global::MLAgents.CommunicatorObjects.AgentActionProto.Parser, new[]{ "VectorActions", "TextActions", "Value", "CustomAction" }, null, null, null)
}));
}
#endregion
Expand Down Expand Up @@ -69,7 +69,6 @@ public AgentActionProto() {
public AgentActionProto(AgentActionProto other) : this() {
vectorActions_ = other.vectorActions_.Clone();
textActions_ = other.textActions_;
memories_ = other.memories_.Clone();
value_ = other.value_;
CustomAction = other.customAction_ != null ? other.CustomAction.Clone() : null;
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
Expand Down Expand Up @@ -101,16 +100,6 @@ public string TextActions {
}
}

/// <summary>Field number for the "memories" field.</summary>
public const int MemoriesFieldNumber = 3;
private static readonly pb::FieldCodec<float> _repeated_memories_codec
= pb::FieldCodec.ForFloat(26);
private readonly pbc::RepeatedField<float> memories_ = new pbc::RepeatedField<float>();
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public pbc::RepeatedField<float> Memories {
get { return memories_; }
}

/// <summary>Field number for the "value" field.</summary>
public const int ValueFieldNumber = 4;
private float value_;
Expand Down Expand Up @@ -148,7 +137,6 @@ public bool Equals(AgentActionProto other) {
}
if(!vectorActions_.Equals(other.vectorActions_)) return false;
if (TextActions != other.TextActions) return false;
if(!memories_.Equals(other.memories_)) return false;
if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(Value, other.Value)) return false;
if (!object.Equals(CustomAction, other.CustomAction)) return false;
return Equals(_unknownFields, other._unknownFields);
Expand All @@ -159,7 +147,6 @@ public override int GetHashCode() {
int hash = 1;
hash ^= vectorActions_.GetHashCode();
if (TextActions.Length != 0) hash ^= TextActions.GetHashCode();
hash ^= memories_.GetHashCode();
if (Value != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(Value);
if (customAction_ != null) hash ^= CustomAction.GetHashCode();
if (_unknownFields != null) {
Expand All @@ -180,7 +167,6 @@ public void WriteTo(pb::CodedOutputStream output) {
output.WriteRawTag(18);
output.WriteString(TextActions);
}
memories_.WriteTo(output, _repeated_memories_codec);
if (Value != 0F) {
output.WriteRawTag(37);
output.WriteFloat(Value);
Expand All @@ -201,7 +187,6 @@ public int CalculateSize() {
if (TextActions.Length != 0) {
size += 1 + pb::CodedOutputStream.ComputeStringSize(TextActions);
}
size += memories_.CalculateSize(_repeated_memories_codec);
if (Value != 0F) {
size += 1 + 4;
}
Expand All @@ -223,7 +208,6 @@ public void MergeFrom(AgentActionProto other) {
if (other.TextActions.Length != 0) {
TextActions = other.TextActions;
}
memories_.Add(other.memories_);
if (other.Value != 0F) {
Value = other.Value;
}
Expand Down Expand Up @@ -253,11 +237,6 @@ public void MergeFrom(pb::CodedInputStream input) {
TextActions = input.ReadString();
break;
}
case 26:
case 29: {
memories_.AddEntriesFrom(input, _repeated_memories_codec);
break;
}
case 37: {
Value = input.ReadFloat();
break;
Expand Down
Loading

0 comments on commit 58cee7e

Please sign in to comment.