Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Develop remove memories #2795

Merged
merged 17 commits into from
Oct 25, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions UnitySDK/Assets/ML-Agents/Editor/Tests/DemonstrationTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ public void TestStoreInitalize()
{
vectorObservationSize = 3,
numStackedVectorObservations = 2,
vectorActionDescriptions = new[] {"TestActionA", "TestActionB"},
vectorActionSize = new[] {2, 2},
vectorActionDescriptions = new[] { "TestActionA", "TestActionB" },
vectorActionSize = new[] { 2, 2 },
vectorActionSpaceType = SpaceType.Discrete
};

Expand All @@ -46,14 +46,13 @@ public void TestStoreInitalize()
var agentInfo = new AgentInfo
{
reward = 1f,
actionMasks = new[] {false, true},
actionMasks = new[] { false, true },
done = true,
id = 5,
maxStepReached = true,
memories = new List<float>(),
stackedVectorObservation = new List<float>() {1f, 1f, 1f},
stackedVectorObservation = new List<float>() { 1f, 1f, 1f },
storedTextActions = "TestAction",
storedVectorActions = new[] {0f, 1f},
storedVectorActions = new[] { 0f, 1f },
textObservation = "TestAction",
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ private class TestAgent : Agent
{
public AgentAction GetAction()
{
var f = typeof(Agent).GetField(
var f = typeof(Agent).GetField(
"m_Action", BindingFlags.Instance | BindingFlags.NonPublic);
return (AgentAction)f.GetValue(this);
}
Expand All @@ -26,15 +26,16 @@ private List<Agent> GetFakeAgentInfos()
var goB = new GameObject("goB");
var agentB = goB.AddComponent<TestAgent>();

return new List<Agent> {agentA, agentB};
return new List<Agent> { agentA, agentB };
}

[Test]
public void Construction()
{
var bp = new BrainParameters();
var alloc = new TensorCachingAllocator();
var tensorGenerator = new TensorApplier(bp, 0, alloc);
var mem = new Dictionary<int, List<float>>();
var tensorGenerator = new TensorApplier(bp, 0, alloc, mem);
Assert.IsNotNull(tensorGenerator);
alloc.Dispose();
}
Expand All @@ -44,8 +45,8 @@ public void ApplyContinuousActionOutput()
{
var inputTensor = new TensorProxy()
{
shape = new long[] {2, 3},
data = new Tensor(2, 3, new float[] {1, 2, 3, 4, 5, 6})
shape = new long[] { 2, 3 },
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

gahhh. whitespace noise

data = new Tensor(2, 3, new float[] { 1, 2, 3, 4, 5, 6 })
};
var agentInfos = GetFakeAgentInfos();

Expand Down Expand Up @@ -73,15 +74,15 @@ public void ApplyDiscreteActionOutput()
{
var inputTensor = new TensorProxy()
{
shape = new long[] {2, 5},
shape = new long[] { 2, 5 },
data = new Tensor(
2,
5,
new[] {0.5f, 22.5f, 0.1f, 5f, 1f, 4f, 5f, 6f, 7f, 8f})
new[] { 0.5f, 22.5f, 0.1f, 5f, 1f, 4f, 5f, 6f, 7f, 8f })
};
var agentInfos = GetFakeAgentInfos();
var alloc = new TensorCachingAllocator();
var applier = new DiscreteActionOutputApplier(new[] {2, 3}, 0, alloc);
var applier = new DiscreteActionOutputApplier(new[] { 2, 3 }, 0, alloc);
applier.Apply(inputTensor, agentInfos);
var agents = agentInfos;

Expand All @@ -99,43 +100,13 @@ public void ApplyDiscreteActionOutput()
alloc.Dispose();
}

[Test]
public void ApplyMemoryOutput()
{
var inputTensor = new TensorProxy()
{
shape = new long[] {2, 5},
data = new Tensor(
2,
5,
new[] {0.5f, 22.5f, 0.1f, 5f, 1f, 4f, 5f, 6f, 7f, 8f})
};
var agentInfos = GetFakeAgentInfos();

var applier = new MemoryOutputApplier();
applier.Apply(inputTensor, agentInfos);
var agents = agentInfos;

var agent = agents[0] as TestAgent;
Assert.NotNull(agent);
var action = agent.GetAction();
Assert.AreEqual(action.memories[0], 0.5f);
Assert.AreEqual(action.memories[1], 22.5f);

agent = agents[1] as TestAgent;
Assert.NotNull(agent);
action = agent.GetAction();
Assert.AreEqual(action.memories[2], 6);
Assert.AreEqual(action.memories[3], 7);
}

[Test]
public void ApplyValueEstimate()
{
var inputTensor = new TensorProxy()
{
shape = new long[] {2, 1},
data = new Tensor(2, 1, new[] {0.5f, 8f})
shape = new long[] { 2, 1 },
data = new Tensor(2, 1, new[] { 0.5f, 8f })
};
var agentInfos = GetFakeAgentInfos();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ private static IEnumerable<Agent> GetFakeAgentInfos()
var infoA = new AgentInfo
{
stackedVectorObservation = new[] { 1f, 2f, 3f }.ToList(),
memories = null,
storedVectorActions = new[] { 1f, 2f },
actionMasks = null
};
Expand All @@ -25,7 +24,6 @@ private static IEnumerable<Agent> GetFakeAgentInfos()
var infoB = new AgentInfo
{
stackedVectorObservation = new[] { 4f, 5f, 6f }.ToList(),
memories = new[] { 1f, 1f, 1f }.ToList(),
storedVectorActions = new[] { 3f, 4f },
actionMasks = new[] { true, false, false, false, false },
};
Expand All @@ -40,7 +38,8 @@ public void Construction()
{
var bp = new BrainParameters();
var alloc = new TensorCachingAllocator();
var tensorGenerator = new TensorGenerator(bp, 0, alloc);
var mem = new Dictionary<int, List<float>>();
var tensorGenerator = new TensorGenerator(bp, 0, alloc, mem);
Assert.IsNotNull(tensorGenerator);
alloc.Dispose();
}
Expand Down Expand Up @@ -91,26 +90,6 @@ public void GenerateVectorObservation()
alloc.Dispose();
}

[Test]
public void GenerateRecurrentInput()
{
var inputTensor = new TensorProxy
{
shape = new long[] { 2, 5 }
};
const int batchSize = 4;
var agentInfos = GetFakeAgentInfos();
var alloc = new TensorCachingAllocator();
var generator = new RecurrentInputGenerator(alloc);
generator.Generate(inputTensor, batchSize, agentInfos);
Assert.IsNotNull(inputTensor.data);
Assert.AreEqual(inputTensor.data[0, 0], 0);
Assert.AreEqual(inputTensor.data[0, 4], 0);
Assert.AreEqual(inputTensor.data[1, 0], 1);
Assert.AreEqual(inputTensor.data[1, 4], 0);
alloc.Dispose();
}

[Test]
public void GeneratePreviousActionInput()
{
Expand Down
31 changes: 0 additions & 31 deletions UnitySDK/Assets/ML-Agents/Scripts/Agent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,6 @@ public struct AgentInfo
/// </summary>
public bool[] actionMasks;

/// <summary>
/// Used by the Trainer to store information about the agent. This data
/// structure is not consumed or modified by the agent directly, they are
/// just the owners of their trainer's memory. Currently, however, the
/// size of the memory is in the Brain properties.
/// </summary>
public List<float> memories;

/// <summary>
/// Current agent reward.
/// </summary>
Expand Down Expand Up @@ -96,7 +88,6 @@ public struct AgentAction
{
public float[] vectorActions;
public string textActions;
public List<float> memories;
public float value;
/// TODO(cgoy): All references to protobuf objects should be removed.
public CommunicatorObjects.CustomActionProto customAction;
Expand Down Expand Up @@ -484,8 +475,6 @@ void ResetData()
if (m_Info.textObservation == null)
m_Info.textObservation = "";
m_Action.textActions = "";
m_Info.memories = new List<float>();
m_Action.memories = new List<float>();
m_Info.vectorObservation =
new List<float>(param.vectorObservationSize);
m_Info.stackedVectorObservation =
Expand Down Expand Up @@ -563,7 +552,6 @@ void SendInfoToBrain()
return;
}

m_Info.memories = m_Action.memories;
m_Info.storedVectorActions = m_Action.vectorActions;
m_Info.storedTextActions = m_Action.textActions;
m_Info.vectorObservation.Clear();
Expand Down Expand Up @@ -902,25 +890,6 @@ public void UpdateVectorAction(float[] vectorActions)
m_Action.vectorActions = vectorActions;
}

/// <summary>
/// Updates the memories action.
/// </summary>
/// <param name="memories">Memories.</param>
public void UpdateMemoriesAction(List<float> memories)
{
m_Action.memories = memories;
}

public void AppendMemoriesAction(List<float> memories)
{
m_Action.memories.AddRange(memories);
}

public List<float> GetMemoriesAction()
{
return m_Action.memories;
}

/// <summary>
/// Updates the value of the agent.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,16 @@ static AgentActionReflection() {
string.Concat(
"CjVtbGFnZW50cy9lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2FnZW50X2Fj",
"dGlvbi5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMaNm1sYWdlbnRzL2Vu",
"dnMvY29tbXVuaWNhdG9yX29iamVjdHMvY3VzdG9tX2FjdGlvbi5wcm90byKh",
"dnMvY29tbXVuaWNhdG9yX29iamVjdHMvY3VzdG9tX2FjdGlvbi5wcm90byKV",
"AQoQQWdlbnRBY3Rpb25Qcm90bxIWCg52ZWN0b3JfYWN0aW9ucxgBIAMoAhIU",
"Cgx0ZXh0X2FjdGlvbnMYAiABKAkSEAoIbWVtb3JpZXMYAyADKAISDQoFdmFs",
"dWUYBCABKAISPgoNY3VzdG9tX2FjdGlvbhgFIAEoCzInLmNvbW11bmljYXRv",
"cl9vYmplY3RzLkN1c3RvbUFjdGlvblByb3RvQh+qAhxNTEFnZW50cy5Db21t",
"dW5pY2F0b3JPYmplY3RzYgZwcm90bzM="));
"Cgx0ZXh0X2FjdGlvbnMYAiABKAkSDQoFdmFsdWUYBCABKAISPgoNY3VzdG9t",
"X2FjdGlvbhgFIAEoCzInLmNvbW11bmljYXRvcl9vYmplY3RzLkN1c3RvbUFj",
"dGlvblByb3RvSgQIAxAEQh+qAhxNTEFnZW50cy5Db21tdW5pY2F0b3JPYmpl",
"Y3RzYgZwcm90bzM="));
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
new pbr::FileDescriptor[] { global::MLAgents.CommunicatorObjects.CustomActionReflection.Descriptor, },
new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.AgentActionProto), global::MLAgents.CommunicatorObjects.AgentActionProto.Parser, new[]{ "VectorActions", "TextActions", "Memories", "Value", "CustomAction" }, null, null, null)
new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.AgentActionProto), global::MLAgents.CommunicatorObjects.AgentActionProto.Parser, new[]{ "VectorActions", "TextActions", "Value", "CustomAction" }, null, null, null)
}));
}
#endregion
Expand Down Expand Up @@ -69,7 +69,6 @@ public AgentActionProto() {
public AgentActionProto(AgentActionProto other) : this() {
vectorActions_ = other.vectorActions_.Clone();
textActions_ = other.textActions_;
memories_ = other.memories_.Clone();
value_ = other.value_;
CustomAction = other.customAction_ != null ? other.CustomAction.Clone() : null;
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
Expand Down Expand Up @@ -101,16 +100,6 @@ public string TextActions {
}
}

/// <summary>Field number for the "memories" field.</summary>
public const int MemoriesFieldNumber = 3;
private static readonly pb::FieldCodec<float> _repeated_memories_codec
= pb::FieldCodec.ForFloat(26);
private readonly pbc::RepeatedField<float> memories_ = new pbc::RepeatedField<float>();
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public pbc::RepeatedField<float> Memories {
get { return memories_; }
}

/// <summary>Field number for the "value" field.</summary>
public const int ValueFieldNumber = 4;
private float value_;
Expand Down Expand Up @@ -148,7 +137,6 @@ public bool Equals(AgentActionProto other) {
}
if(!vectorActions_.Equals(other.vectorActions_)) return false;
if (TextActions != other.TextActions) return false;
if(!memories_.Equals(other.memories_)) return false;
if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(Value, other.Value)) return false;
if (!object.Equals(CustomAction, other.CustomAction)) return false;
return Equals(_unknownFields, other._unknownFields);
Expand All @@ -159,7 +147,6 @@ public override int GetHashCode() {
int hash = 1;
hash ^= vectorActions_.GetHashCode();
if (TextActions.Length != 0) hash ^= TextActions.GetHashCode();
hash ^= memories_.GetHashCode();
if (Value != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(Value);
if (customAction_ != null) hash ^= CustomAction.GetHashCode();
if (_unknownFields != null) {
Expand All @@ -180,7 +167,6 @@ public void WriteTo(pb::CodedOutputStream output) {
output.WriteRawTag(18);
output.WriteString(TextActions);
}
memories_.WriteTo(output, _repeated_memories_codec);
if (Value != 0F) {
output.WriteRawTag(37);
output.WriteFloat(Value);
Expand All @@ -201,7 +187,6 @@ public int CalculateSize() {
if (TextActions.Length != 0) {
size += 1 + pb::CodedOutputStream.ComputeStringSize(TextActions);
}
size += memories_.CalculateSize(_repeated_memories_codec);
if (Value != 0F) {
size += 1 + 4;
}
Expand All @@ -223,7 +208,6 @@ public void MergeFrom(AgentActionProto other) {
if (other.TextActions.Length != 0) {
TextActions = other.TextActions;
}
memories_.Add(other.memories_);
if (other.Value != 0F) {
Value = other.Value;
}
Expand Down Expand Up @@ -253,11 +237,6 @@ public void MergeFrom(pb::CodedInputStream input) {
TextActions = input.ReadString();
break;
}
case 26:
case 29: {
memories_.AddEntriesFrom(input, _repeated_memories_codec);
break;
}
case 37: {
Value = input.ReadFloat();
break;
Expand Down
Loading