Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

VectorSensor and StackedSensor #2813

Merged
merged 21 commits into from
Nov 1, 2019
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public void TestStoreInitalize()
done = true,
id = 5,
maxStepReached = true,
stackedVectorObservation = new List<float>() { 1f, 1f, 1f },
floatObservations = new List<float>() { 1f, 1f, 1f },
storedTextActions = "TestAction",
storedVectorActions = new[] { 0f, 1f },
textObservation = "TestAction",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,36 +1,60 @@
using System.Collections.Generic;
using System.Linq;
using Barracuda;
using NUnit.Framework;
using UnityEngine;
using MLAgents.InferenceBrain;
using System.Reflection;


namespace MLAgents.Tests
{
public class EditModeTestInternalBrainTensorGenerator
{
static IEnumerable<Agent> GetFakeAgentInfos()
static IEnumerable<Agent> GetFakeAgents()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rename since it returns Agents. This changed a fair amount in order to initialize the agents before returning

{
// Builds two TestAgents (A and B) on their own GameObjects, initializes them against a
// TestAcademy, and feeds each one a 3-float vector observation.
var acaGo = new GameObject("TestAcademy");
acaGo.AddComponent<TestAcademy>();
var aca = acaGo.GetComponent<TestAcademy>();
aca.resetParameters = new ResetParameters();

var goA = new GameObject("goA");
var bpA = goA.AddComponent<BehaviorParameters>();
bpA.brainParameters.vectorObservationSize = 3;
bpA.brainParameters.numStackedVectorObservations = 1;
var agentA = goA.AddComponent<TestAgent>();

var goB = new GameObject("goB");
// Attach B's BehaviorParameters to goB (this previously used goA, which gave goA two
// components and left the sizes below unset on goB).
// NOTE(review): presumably the Agent reads BehaviorParameters from its own GameObject - confirm.
var bpB = goB.AddComponent<BehaviorParameters>();
bpB.brainParameters.vectorObservationSize = 3;
bpB.brainParameters.numStackedVectorObservations = 1;
var agentB = goB.AddComponent<TestAgent>();

var agents = new List<Agent> { agentA, agentB };
foreach (var agent in agents)
{
// OnEnableHelper is non-public on Agent, so it is looked up and invoked via reflection.
var agentEnableMethod = typeof(Agent).GetMethod("OnEnableHelper",
BindingFlags.Instance | BindingFlags.NonPublic);
agentEnableMethod?.Invoke(agent, new object[] { aca });
}
// NOTE(review): assumes the OnEnableHelper call above created collectObservationsSensor - confirm.
agentA.collectObservationsSensor.AddObservation(new Vector3(1, 2, 3));
agentB.collectObservationsSensor.AddObservation(new Vector3(4, 5, 6));

// NOTE(review): from here down to the first `return`, several lines re-declare goB/agentB and
// use stackedVectorObservation; they look like the removed side of the diff hunk. Left untouched.
var infoA = new AgentInfo
{
stackedVectorObservation = new[] { 1f, 2f, 3f }.ToList(),
storedVectorActions = new[] { 1f, 2f },
actionMasks = null
};
var goB = new GameObject("goB");
var agentB = goB.AddComponent<TestAgent>();

var infoB = new AgentInfo
{
stackedVectorObservation = new[] { 4f, 5f, 6f }.ToList(),
storedVectorActions = new[] { 3f, 4f },
actionMasks = new[] { true, false, false, false, false },
};

agentA.Info = infoA;
agentB.Info = infoB;

return new List<Agent> { agentA, agentB };
return agents;
}

[Test]
Expand Down Expand Up @@ -77,9 +101,12 @@ public void GenerateVectorObservation()
shape = new long[] { 2, 3 }
};
const int batchSize = 4;
var agentInfos = GetFakeAgentInfos();
var agentInfos = GetFakeAgents();
var alloc = new TensorCachingAllocator();
var generator = new VectorObservationGenerator(alloc);
generator.AddSensorIndex(0);
generator.AddSensorIndex(1);
generator.AddSensorIndex(2);
generator.Generate(inputTensor, batchSize, agentInfos);
Assert.IsNotNull(inputTensor.data);
Assert.AreEqual(inputTensor.data[0, 0], 1);
Expand All @@ -98,7 +125,7 @@ public void GeneratePreviousActionInput()
valueType = TensorProxy.TensorType.Integer
};
const int batchSize = 4;
var agentInfos = GetFakeAgentInfos();
var agentInfos = GetFakeAgents();
var alloc = new TensorCachingAllocator();
var generator = new PreviousActionInputGenerator(alloc);

Expand All @@ -120,7 +147,7 @@ public void GenerateActionMaskInput()
valueType = TensorProxy.TensorType.FloatingPoint
};
const int batchSize = 4;
var agentInfos = GetFakeAgentInfos();
var agentInfos = GetFakeAgents();
var alloc = new TensorCachingAllocator();
var generator = new ActionMaskInputGenerator(alloc);
generator.Generate(inputTensor, batchSize, agentInfos);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
using NUnit.Framework;
using System.Reflection;
using MLAgents.Sensor;
using MLAgents.InferenceBrain;

namespace MLAgents.Tests
{
Expand Down Expand Up @@ -83,10 +82,14 @@ public TestSensor(string n)

public int[] GetFloatObservationShape()
{
return new[] { 1 };
return new[] { 0 };
}

public void WriteToTensor(TensorProxy tensorProxy, int agentIndex) { }
// Test stub: a no-op that writes nothing through the adapter and returns 0.
public int Write(WriteAdapter adapter) => 0;

public byte[] GetCompressedObservation()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ public void TestPerception3D()
var go = new GameObject("MyGameObject");
var rayPer3D = go.AddComponent<RayPerception3D>();
var result = rayPer3D.Perceive(1f, angles, tags);
Debug.Log(result.Count);
Assert.IsTrue(result.Count == angles.Length * (tags.Length + 2));
}

Expand Down
4 changes: 2 additions & 2 deletions UnitySDK/Assets/ML-Agents/Editor/Tests/StandaloneBuildTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ static void BuildStandalonePlayerOSX()
}
#else
var error = buildResult;
var isOK = string.IsNullOrEmpty(error);
var isOk = string.IsNullOrEmpty(error);
surfnerd marked this conversation as resolved.
Show resolved Hide resolved
#endif
if (isOK)
if (isOk)
{
EditorApplication.Exit(0);
}
Expand Down
112 changes: 53 additions & 59 deletions UnitySDK/Assets/ML-Agents/Scripts/Agent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,14 @@ namespace MLAgents
/// </summary>
public struct AgentInfo
{
/// <summary>
/// Most recent agent vector (i.e. numeric) observation.
/// </summary>
public List<float> vectorObservation;

/// <summary>
/// The previous agent vector observations, stacked. The length of the
/// history (i.e. number of vector observations to stack) is specified
/// in the Brain parameters.
/// </summary>
public List<float> stackedVectorObservation;

/// <summary>
/// Most recent compressed observations.
/// </summary>
public List<CompressedObservation> compressedObservations;

// TODO struct?
public List<float> floatObservations;

/// <summary>
/// Most recent text observation.
/// </summary>
Expand Down Expand Up @@ -261,6 +252,10 @@ public AgentInfo Info
[FormerlySerializedAs("m_Sensors")]
public List<ISensor> sensors;

public VectorSensor collectObservationsSensor;

WriteAdapter m_WriteAdapter = new WriteAdapter();

/// MonoBehaviour function that is called when the attached GameObject
/// becomes enabled or active.
void OnEnable()
Expand Down Expand Up @@ -475,16 +470,12 @@ void ResetData()
if (m_Info.textObservation == null)
m_Info.textObservation = "";
m_Action.textActions = "";
m_Info.vectorObservation =
new List<float>(param.vectorObservationSize);
m_Info.stackedVectorObservation =
new List<float>(param.vectorObservationSize
* param.numStackedVectorObservations);
m_Info.stackedVectorObservation.AddRange(
new float[param.vectorObservationSize
* param.numStackedVectorObservations]);

m_Info.compressedObservations = new List<CompressedObservation>();
m_Info.floatObservations = new List<float>();
m_Info.floatObservations.AddRange(
new float[param.vectorObservationSize
* param.numStackedVectorObservations]);
m_Info.customObservation = null;
}

Expand Down Expand Up @@ -523,13 +514,30 @@ public virtual float[] Heuristic()
/// </summary>
public void InitializeSensors()
{
// Get all attached sensor components
var attachedSensorComponents = GetComponents<SensorComponent>();
sensors.Capacity += attachedSensorComponents.Length;
foreach (var component in attachedSensorComponents)
{
sensors.Add(component.CreateSensor());
}

// Support legacy CollectObservations
var param = m_PolicyFactory.brainParameters;
if (param.vectorObservationSize > 0)
{
collectObservationsSensor = new VectorSensor(param.vectorObservationSize);
if (param.numStackedVectorObservations > 1)
{
var stackingSensor = new StackingSensor(collectObservationsSensor, param.numStackedVectorObservations);
sensors.Add(stackingSensor);
}
else
{
sensors.Add(collectObservationsSensor);
}
}

// Sort the Sensors by name to ensure determinism
sensors.Sort((x, y) => x.GetName().CompareTo(y.GetName()));

Expand All @@ -554,7 +562,6 @@ void SendInfoToBrain()

m_Info.storedVectorActions = m_Action.vectorActions;
m_Info.storedTextActions = m_Action.textActions;
m_Info.vectorObservation.Clear();
m_Info.compressedObservations.Clear();
m_ActionMasker.ResetMask();
using (TimerStack.Instance.Scoped("CollectObservations"))
Expand All @@ -563,21 +570,7 @@ void SendInfoToBrain()
}
m_Info.actionMasks = m_ActionMasker.GetMask();

var param = m_PolicyFactory.brainParameters;
if (m_Info.vectorObservation.Count != param.vectorObservationSize)
{
throw new UnityAgentsException(string.Format(
"Vector Observation size mismatch in continuous " +
"agent {0}. " +
"Was Expecting {1} but received {2}. ",
gameObject.name,
param.vectorObservationSize,
m_Info.vectorObservation.Count));
}

Utilities.ShiftLeft(m_Info.stackedVectorObservation, param.vectorObservationSize);
Utilities.ReplaceRange(m_Info.stackedVectorObservation, m_Info.vectorObservation,
m_Info.stackedVectorObservation.Count - m_Info.vectorObservation.Count);
// var param = m_PolicyFactory.brainParameters; // look, no brain params!

m_Info.reward = m_Reward;
m_Info.done = m_Done;
Expand Down Expand Up @@ -609,18 +602,27 @@ void SendInfoToBrain()
/// </summary>
public void GenerateSensorData()
{

// Running offset into m_Info.floatObservations; advanced by the value each
// uncompressed sensor's Write() returns.
int floatsWritten = 0;
// Generate data for all Sensors
// TODO add bool argument indicating when to compress? For now, we always will compress.
for (var i = 0; i < sensors.Count; i++)
{
var sensor = sensors[i];
var compressedObs = new CompressedObservation
if (sensor.GetCompressionType() == SensorCompressionType.None)
{
Data = sensor.GetCompressedObservation(),
Shape = sensor.GetFloatObservationShape(),
CompressionType = sensor.GetCompressionType()
};
m_Info.compressedObservations.Add(compressedObs);
// Uncompressed path: point the shared write adapter at floatObservations,
// starting at the current offset, and let the sensor write through it.
m_WriteAdapter.SetTarget(m_Info.floatObservations, floatsWritten);
floatsWritten += sensor.Write(m_WriteAdapter);
}
else
{
// Compressed path: store the sensor's compressed bytes together with its
// shape and compression type.
var compressedObs = new CompressedObservation
{
Data = sensor.GetCompressedObservation(),
Shape = sensor.GetFloatObservationShape(),
CompressionType = sensor.GetCompressionType()
};
m_Info.compressedObservations.Add(compressedObs);
}
}
}

Expand Down Expand Up @@ -719,7 +721,7 @@ protected void SetActionMask(int branch, IEnumerable<int> actionIndices)
/// <param name="observation">Observation.</param>
protected void AddVectorObs(float observation)
{
m_Info.vectorObservation.Add(observation);
collectObservationsSensor.AddObservation(observation);
}

/// <summary>
Expand All @@ -729,7 +731,7 @@ protected void AddVectorObs(float observation)
/// <param name="observation">Observation.</param>
protected void AddVectorObs(int observation)
{
m_Info.vectorObservation.Add(observation);
collectObservationsSensor.AddObservation(observation);
}

/// <summary>
Expand All @@ -739,9 +741,7 @@ protected void AddVectorObs(int observation)
/// <param name="observation">Observation.</param>
protected void AddVectorObs(Vector3 observation)
{
m_Info.vectorObservation.Add(observation.x);
m_Info.vectorObservation.Add(observation.y);
m_Info.vectorObservation.Add(observation.z);
collectObservationsSensor.AddObservation(observation);
}

/// <summary>
Expand All @@ -751,8 +751,7 @@ protected void AddVectorObs(Vector3 observation)
/// <param name="observation">Observation.</param>
protected void AddVectorObs(Vector2 observation)
{
m_Info.vectorObservation.Add(observation.x);
m_Info.vectorObservation.Add(observation.y);
collectObservationsSensor.AddObservation(observation);
}

/// <summary>
Expand All @@ -762,7 +761,7 @@ protected void AddVectorObs(Vector2 observation)
/// <param name="observation">Observation.</param>
protected void AddVectorObs(IEnumerable<float> observation)
{
m_Info.vectorObservation.AddRange(observation);
collectObservationsSensor.AddObservation(observation);
}

/// <summary>
Expand All @@ -772,10 +771,7 @@ protected void AddVectorObs(IEnumerable<float> observation)
/// <param name="observation">Observation.</param>
protected void AddVectorObs(Quaternion observation)
{
m_Info.vectorObservation.Add(observation.x);
m_Info.vectorObservation.Add(observation.y);
m_Info.vectorObservation.Add(observation.z);
m_Info.vectorObservation.Add(observation.w);
collectObservationsSensor.AddObservation(observation);
}

/// <summary>
Expand All @@ -785,14 +781,12 @@ protected void AddVectorObs(Quaternion observation)
/// <param name="observation"></param>
protected void AddVectorObs(bool observation)
{
m_Info.vectorObservation.Add(observation ? 1f : 0f);
collectObservationsSensor.AddObservation(observation);
}

/// <summary>
/// Adds a one-hot encoded observation: a vector of <paramref name="range"/> floats
/// in which the entry at index <paramref name="observation"/> is 1 and the rest are 0.
/// </summary>
/// <param name="observation">Index of the entry to set to 1.</param>
/// <param name="range">Length of the one-hot vector.</param>
protected void AddVectorObs(int observation, int range)
{
var oneHotVector = new float[range];
oneHotVector[observation] = 1;
m_Info.vectorObservation.AddRange(oneHotVector);
// NOTE(review): presumably AddOneHotObservation performs the same encoding inside the
// VectorSensor - confirm against VectorSensor.
collectObservationsSensor.AddOneHotObservation(observation, range);
}

/// <summary>
Expand Down
2 changes: 1 addition & 1 deletion UnitySDK/Assets/ML-Agents/Scripts/Grpc/GrpcExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ public static AgentInfoProto ToProto(this AgentInfo ai)
{
var agentInfoProto = new AgentInfoProto
{
StackedVectorObservation = { ai.stackedVectorObservation },
StackedVectorObservation = { ai.floatObservations },
StoredVectorActions = { ai.storedVectorActions },
StoredTextActions = ai.storedTextActions,
TextObservation = ai.textObservation,
Expand Down
Loading