Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Develop remove memories #2795

Merged
merged 17 commits into from
Oct 25, 2019
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions UnitySDK/Assets/ML-Agents/Editor/Tests/DemonstrationTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ public void TestStoreInitalize()
{
vectorObservationSize = 3,
numStackedVectorObservations = 2,
vectorActionDescriptions = new[] {"TestActionA", "TestActionB"},
vectorActionSize = new[] {2, 2},
vectorActionDescriptions = new[] { "TestActionA", "TestActionB" },
vectorActionSize = new[] { 2, 2 },
vectorActionSpaceType = SpaceType.Discrete
};

Expand All @@ -46,14 +46,13 @@ public void TestStoreInitalize()
var agentInfo = new AgentInfo
{
reward = 1f,
actionMasks = new[] {false, true},
actionMasks = new[] { false, true },
done = true,
id = 5,
maxStepReached = true,
memories = new List<float>(),
stackedVectorObservation = new List<float>() {1f, 1f, 1f},
stackedVectorObservation = new List<float>() { 1f, 1f, 1f },
storedTextActions = "TestAction",
storedVectorActions = new[] {0f, 1f},
storedVectorActions = new[] { 0f, 1f },
textObservation = "TestAction",
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ private class TestAgent : Agent
{
public AgentAction GetAction()
{
var f = typeof(Agent).GetField(
var f = typeof(Agent).GetField(
"m_Action", BindingFlags.Instance | BindingFlags.NonPublic);
return (AgentAction)f.GetValue(this);
}
Expand All @@ -26,7 +26,7 @@ private List<Agent> GetFakeAgentInfos()
var goB = new GameObject("goB");
var agentB = goB.AddComponent<TestAgent>();

return new List<Agent> {agentA, agentB};
return new List<Agent> { agentA, agentB };
}

[Test]
Expand All @@ -44,13 +44,13 @@ public void ApplyContinuousActionOutput()
{
var inputTensor = new TensorProxy()
{
shape = new long[] {2, 3},
data = new Tensor(2, 3, new float[] {1, 2, 3, 4, 5, 6})
shape = new long[] { 2, 3 },
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

gahhh. whitespace noise

data = new Tensor(2, 3, new float[] { 1, 2, 3, 4, 5, 6 })
};
var agentInfos = GetFakeAgentInfos();

var applier = new ContinuousActionOutputApplier();
applier.Apply(inputTensor, agentInfos);
applier.Apply(inputTensor, agentInfos, null);
var agents = agentInfos;

var agent = agents[0] as TestAgent;
Expand All @@ -73,16 +73,16 @@ public void ApplyDiscreteActionOutput()
{
var inputTensor = new TensorProxy()
{
shape = new long[] {2, 5},
shape = new long[] { 2, 5 },
data = new Tensor(
2,
5,
new[] {0.5f, 22.5f, 0.1f, 5f, 1f, 4f, 5f, 6f, 7f, 8f})
new[] { 0.5f, 22.5f, 0.1f, 5f, 1f, 4f, 5f, 6f, 7f, 8f })
};
var agentInfos = GetFakeAgentInfos();
var alloc = new TensorCachingAllocator();
var applier = new DiscreteActionOutputApplier(new[] {2, 3}, 0, alloc);
applier.Apply(inputTensor, agentInfos);
var applier = new DiscreteActionOutputApplier(new[] { 2, 3 }, 0, alloc);
applier.Apply(inputTensor, agentInfos, null);
var agents = agentInfos;

var agent = agents[0] as TestAgent;
Expand All @@ -99,48 +99,18 @@ public void ApplyDiscreteActionOutput()
alloc.Dispose();
}

[Test]
public void ApplyMemoryOutput()
{
var inputTensor = new TensorProxy()
{
shape = new long[] {2, 5},
data = new Tensor(
2,
5,
new[] {0.5f, 22.5f, 0.1f, 5f, 1f, 4f, 5f, 6f, 7f, 8f})
};
var agentInfos = GetFakeAgentInfos();

var applier = new MemoryOutputApplier();
applier.Apply(inputTensor, agentInfos);
var agents = agentInfos;

var agent = agents[0] as TestAgent;
Assert.NotNull(agent);
var action = agent.GetAction();
Assert.AreEqual(action.memories[0], 0.5f);
Assert.AreEqual(action.memories[1], 22.5f);

agent = agents[1] as TestAgent;
Assert.NotNull(agent);
action = agent.GetAction();
Assert.AreEqual(action.memories[2], 6);
Assert.AreEqual(action.memories[3], 7);
}

[Test]
public void ApplyValueEstimate()
{
var inputTensor = new TensorProxy()
{
shape = new long[] {2, 1},
data = new Tensor(2, 1, new[] {0.5f, 8f})
shape = new long[] { 2, 1 },
data = new Tensor(2, 1, new[] { 0.5f, 8f })
};
var agentInfos = GetFakeAgentInfos();

var applier = new ValueEstimateApplier();
applier.Apply(inputTensor, agentInfos);
applier.Apply(inputTensor, agentInfos, null);
var agents = agentInfos;

var agent = agents[0] as TestAgent;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ private static IEnumerable<Agent> GetFakeAgentInfos()
var infoA = new AgentInfo
{
stackedVectorObservation = new[] { 1f, 2f, 3f }.ToList(),
memories = null,
storedVectorActions = new[] { 1f, 2f },
actionMasks = null
};
Expand All @@ -25,7 +24,6 @@ private static IEnumerable<Agent> GetFakeAgentInfos()
var infoB = new AgentInfo
{
stackedVectorObservation = new[] { 4f, 5f, 6f }.ToList(),
memories = new[] { 1f, 1f, 1f }.ToList(),
storedVectorActions = new[] { 3f, 4f },
actionMasks = new[] { true, false, false, false, false },
};
Expand All @@ -52,7 +50,7 @@ public void GenerateBatchSize()
var alloc = new TensorCachingAllocator();
const int batchSize = 4;
var generator = new BatchSizeGenerator(alloc);
generator.Generate(inputTensor, batchSize, null);
generator.Generate(inputTensor, batchSize, null, null);
Assert.IsNotNull(inputTensor.data);
Assert.AreEqual(inputTensor.data[0], batchSize);
alloc.Dispose();
Expand All @@ -65,7 +63,7 @@ public void GenerateSequenceLength()
var alloc = new TensorCachingAllocator();
const int batchSize = 4;
var generator = new SequenceLengthGenerator(alloc);
generator.Generate(inputTensor, batchSize, null);
generator.Generate(inputTensor, batchSize, null, null);
Assert.IsNotNull(inputTensor.data);
Assert.AreEqual(inputTensor.data[0], 1);
alloc.Dispose();
Expand All @@ -82,7 +80,7 @@ public void GenerateVectorObservation()
var agentInfos = GetFakeAgentInfos();
var alloc = new TensorCachingAllocator();
var generator = new VectorObservationGenerator(alloc);
generator.Generate(inputTensor, batchSize, agentInfos);
generator.Generate(inputTensor, batchSize, agentInfos, null);
Assert.IsNotNull(inputTensor.data);
Assert.AreEqual(inputTensor.data[0, 0], 1);
Assert.AreEqual(inputTensor.data[0, 2], 3);
Expand All @@ -91,26 +89,6 @@ public void GenerateVectorObservation()
alloc.Dispose();
}

[Test]
public void GenerateRecurrentInput()
{
var inputTensor = new TensorProxy
{
shape = new long[] { 2, 5 }
};
const int batchSize = 4;
var agentInfos = GetFakeAgentInfos();
var alloc = new TensorCachingAllocator();
var generator = new RecurrentInputGenerator(alloc);
generator.Generate(inputTensor, batchSize, agentInfos);
Assert.IsNotNull(inputTensor.data);
Assert.AreEqual(inputTensor.data[0, 0], 0);
Assert.AreEqual(inputTensor.data[0, 4], 0);
Assert.AreEqual(inputTensor.data[1, 0], 1);
Assert.AreEqual(inputTensor.data[1, 4], 0);
alloc.Dispose();
}

[Test]
public void GeneratePreviousActionInput()
{
Expand All @@ -124,7 +102,7 @@ public void GeneratePreviousActionInput()
var alloc = new TensorCachingAllocator();
var generator = new PreviousActionInputGenerator(alloc);

generator.Generate(inputTensor, batchSize, agentInfos);
generator.Generate(inputTensor, batchSize, agentInfos, null);
Assert.IsNotNull(inputTensor.data);
Assert.AreEqual(inputTensor.data[0, 0], 1);
Assert.AreEqual(inputTensor.data[0, 1], 2);
Expand All @@ -145,7 +123,7 @@ public void GenerateActionMaskInput()
var agentInfos = GetFakeAgentInfos();
var alloc = new TensorCachingAllocator();
var generator = new ActionMaskInputGenerator(alloc);
generator.Generate(inputTensor, batchSize, agentInfos);
generator.Generate(inputTensor, batchSize, agentInfos, null);
Assert.IsNotNull(inputTensor.data);
Assert.AreEqual(inputTensor.data[0, 0], 1);
Assert.AreEqual(inputTensor.data[0, 4], 1);
Expand Down
31 changes: 0 additions & 31 deletions UnitySDK/Assets/ML-Agents/Scripts/Agent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,6 @@ public struct AgentInfo
/// </summary>
public bool[] actionMasks;

/// <summary>
/// Used by the Trainer to store information about the agent. This data
/// structure is not consumed or modified by the agent directly, they are
/// just the owners of their trainier's memory. Currently, however, the
/// size of the memory is in the Brain properties.
/// </summary>
public List<float> memories;

/// <summary>
/// Current agent reward.
/// </summary>
Expand Down Expand Up @@ -96,7 +88,6 @@ public struct AgentAction
{
public float[] vectorActions;
public string textActions;
public List<float> memories;
public float value;
/// TODO(cgoy): All references to protobuf objects should be removed.
public CommunicatorObjects.CustomActionProto customAction;
Expand Down Expand Up @@ -484,8 +475,6 @@ void ResetData()
if (m_Info.textObservation == null)
m_Info.textObservation = "";
m_Action.textActions = "";
m_Info.memories = new List<float>();
m_Action.memories = new List<float>();
m_Info.vectorObservation =
new List<float>(param.vectorObservationSize);
m_Info.stackedVectorObservation =
Expand Down Expand Up @@ -563,7 +552,6 @@ void SendInfoToBrain()
return;
}

m_Info.memories = m_Action.memories;
m_Info.storedVectorActions = m_Action.vectorActions;
m_Info.storedTextActions = m_Action.textActions;
m_Info.vectorObservation.Clear();
Expand Down Expand Up @@ -902,25 +890,6 @@ public void UpdateVectorAction(float[] vectorActions)
m_Action.vectorActions = vectorActions;
}

/// <summary>
/// Updates the memories action.
/// </summary>
/// <param name="memories">Memories.</param>
public void UpdateMemoriesAction(List<float> memories)
{
m_Action.memories = memories;
}

public void AppendMemoriesAction(List<float> memories)
{
m_Action.memories.AddRange(memories);
}

public List<float> GetMemoriesAction()
{
return m_Action.memories;
}

/// <summary>
/// Updates the value of the agent.
/// </summary>
Expand Down
7 changes: 1 addition & 6 deletions UnitySDK/Assets/ML-Agents/Scripts/Grpc/GrpcExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,6 @@ public static AgentInfoProto ToProto(this AgentInfo ai)
Id = ai.id,
CustomObservation = ai.customObservation
};
if (ai.memories != null)
{
agentInfoProto.Memories.Add(ai.memories);
}

if (ai.actionMasks != null)
{
Expand Down Expand Up @@ -164,7 +160,6 @@ public static AgentAction ToAgentAction(this AgentActionProto aap)
{
vectorActions = aap.VectorActions.ToArray(),
textActions = aap.TextActions,
memories = aap.Memories.ToList(),
value = aap.Value,
customAction = aap.CustomAction
};
Expand All @@ -185,7 +180,7 @@ public static CompressedObservationProto ToProto(this CompressedObservation obs)
var obsProto = new CompressedObservationProto
{
Data = ByteString.CopyFrom(obs.Data),
CompressionType = (CompressionTypeProto) obs.CompressionType,
CompressionType = (CompressionTypeProto)obs.CompressionType,
};
obsProto.Shape.AddRange(obs.Shape);
return obsProto;
Expand Down
17 changes: 10 additions & 7 deletions UnitySDK/Assets/ML-Agents/Scripts/ICommunicator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ Since the messages are sent back and forth with exchange and simultaneously when
UnityOutput and UnityInput can be extended to provide functionalities beyond RL
UnityRLOutput and UnityRLInput can be extended to provide new RL functionalities
*/
public interface ICommunicator : IBatchedDecisionMaker
public interface ICommunicator
{
/// <summary>
/// Quit was received by the communicator.
Expand Down Expand Up @@ -141,17 +141,20 @@ public interface ICommunicator : IBatchedDecisionMaker
/// <param name="brainParameters">The Parameters for the Brain being registered</param>
void SubscribeBrain(string name, BrainParameters brainParameters);

/// <summary>
/// Sends the observations of one Agent.
/// </summary>
/// <param name="key">Batch Key.</param>
/// <param name="agents">Agent info.</param>
void PutObservations(string brainKey, Agent agent);

void DecideBatch();
surfnerd marked this conversation as resolved.
Show resolved Hide resolved

/// <summary>
/// Gets the AgentActions based on the batching key.
/// </summary>
/// <param name="key">A key to identify which actions to get</param>
/// <returns></returns>
Dictionary<Agent, AgentAction> GetActions(string key);
}

public interface IBatchedDecisionMaker : IDisposable
{
void PutObservations(string key, Agent agent);
void DecideBatch();
}
}
Loading