Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

VectorSensor and StackedSensor #2813

Merged
merged 21 commits into from
Nov 1, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public void TestStoreInitalize()
done = true,
id = 5,
maxStepReached = true,
stackedVectorObservation = new List<float>() { 1f, 1f, 1f },
floatObservations = new List<float>() { 1f, 1f, 1f },
storedTextActions = "TestAction",
storedVectorActions = new[] { 0f, 1f },
textObservation = "TestAction",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,36 +1,60 @@
using System.Collections.Generic;
using System.Linq;
using Barracuda;
using NUnit.Framework;
using UnityEngine;
using MLAgents.InferenceBrain;
using System.Reflection;


namespace MLAgents.Tests
{
public class EditModeTestInternalBrainTensorGenerator
{
static IEnumerable<Agent> GetFakeAgentInfos()
static IEnumerable<Agent> GetFakeAgents()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rename since it returns Agents. This changed a fair amount in order to initialize the agents before returning

{
var acaGo = new GameObject("TestAcademy");
acaGo.AddComponent<TestAcademy>();
var aca = acaGo.GetComponent<TestAcademy>();
aca.resetParameters = new ResetParameters();

var goA = new GameObject("goA");
var bpA = goA.AddComponent<BehaviorParameters>();
bpA.brainParameters.vectorObservationSize = 3;
bpA.brainParameters.numStackedVectorObservations = 1;
var agentA = goA.AddComponent<TestAgent>();

var goB = new GameObject("goB");
var bpB = goB.AddComponent<BehaviorParameters>();
bpB.brainParameters.vectorObservationSize = 3;
bpB.brainParameters.numStackedVectorObservations = 1;
var agentB = goB.AddComponent<TestAgent>();

var agents = new List<Agent> { agentA, agentB };
foreach (var agent in agents)
{
var agentEnableMethod = typeof(Agent).GetMethod("OnEnableHelper",
BindingFlags.Instance | BindingFlags.NonPublic);
agentEnableMethod?.Invoke(agent, new object[] { aca });
}
agentA.collectObservationsSensor.AddObservation(new Vector3(1, 2, 3));
agentB.collectObservationsSensor.AddObservation(new Vector3(4, 5, 6));

var infoA = new AgentInfo
{
stackedVectorObservation = new[] { 1f, 2f, 3f }.ToList(),
storedVectorActions = new[] { 1f, 2f },
actionMasks = null
};
var goB = new GameObject("goB");
var agentB = goB.AddComponent<TestAgent>();

var infoB = new AgentInfo
{
stackedVectorObservation = new[] { 4f, 5f, 6f }.ToList(),
storedVectorActions = new[] { 3f, 4f },
actionMasks = new[] { true, false, false, false, false },
};

agentA.Info = infoA;
agentB.Info = infoB;

return new List<Agent> { agentA, agentB };
return agents;
}

[Test]
Expand Down Expand Up @@ -77,9 +101,12 @@ public void GenerateVectorObservation()
shape = new long[] { 2, 3 }
};
const int batchSize = 4;
var agentInfos = GetFakeAgentInfos();
var agentInfos = GetFakeAgents();
var alloc = new TensorCachingAllocator();
var generator = new VectorObservationGenerator(alloc);
generator.AddSensorIndex(0);
generator.AddSensorIndex(1);
generator.AddSensorIndex(2);
generator.Generate(inputTensor, batchSize, agentInfos);
Assert.IsNotNull(inputTensor.data);
Assert.AreEqual(inputTensor.data[0, 0], 1);
Expand All @@ -98,7 +125,7 @@ public void GeneratePreviousActionInput()
valueType = TensorProxy.TensorType.Integer
};
const int batchSize = 4;
var agentInfos = GetFakeAgentInfos();
var agentInfos = GetFakeAgents();
var alloc = new TensorCachingAllocator();
var generator = new PreviousActionInputGenerator(alloc);

Expand All @@ -120,7 +147,7 @@ public void GenerateActionMaskInput()
valueType = TensorProxy.TensorType.FloatingPoint
};
const int batchSize = 4;
var agentInfos = GetFakeAgentInfos();
var agentInfos = GetFakeAgents();
var alloc = new TensorCachingAllocator();
var generator = new ActionMaskInputGenerator(alloc);
generator.Generate(inputTensor, batchSize, agentInfos);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
using NUnit.Framework;
using System.Reflection;
using MLAgents.Sensor;
using MLAgents.InferenceBrain;

namespace MLAgents.Tests
{
Expand Down Expand Up @@ -83,10 +82,14 @@ public TestSensor(string n)

public int[] GetFloatObservationShape()
{
return new[] { 1 };
return new[] { 0 };
}

public void WriteToTensor(TensorProxy tensorProxy, int agentIndex) { }
public int Write(WriteAdapter adapter)
{
// No-op
return 0;
}

public byte[] GetCompressedObservation()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ public void TestPerception3D()
var go = new GameObject("MyGameObject");
var rayPer3D = go.AddComponent<RayPerception3D>();
var result = rayPer3D.Perceive(1f, angles, tags);
Debug.Log(result.Count);
Assert.IsTrue(result.Count == angles.Length * (tags.Length + 2));
}

Expand Down
8 changes: 8 additions & 0 deletions UnitySDK/Assets/ML-Agents/Editor/Tests/Sensor.meta

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
using NUnit.Framework;
using UnityEngine;
using MLAgents.Sensor;

namespace MLAgents.Tests
{
public class StackingSensorTests
{
[Test]
public void TestCtor()
{
ISensor wrapped = new VectorSensor(4);
ISensor sensor = new StackingSensor(wrapped, 4);
Assert.AreEqual("StackingSensor_size4_VectorSensor_size4", sensor.GetName());
Assert.AreEqual(sensor.GetFloatObservationShape(), new [] {16});
}

[Test]
public void TestStacking()
{
VectorSensor wrapped = new VectorSensor(2);
ISensor sensor = new StackingSensor(wrapped, 3);

wrapped.AddObservation(new [] {1f, 2f});
SensorTestHelper.CompareObservation(sensor, new [] {0f, 0f, 0f, 0f, 1f, 2f});

wrapped.AddObservation(new [] {3f, 4f});
SensorTestHelper.CompareObservation(sensor, new [] {0f, 0f, 1f, 2f, 3f, 4f});

wrapped.AddObservation(new [] {5f, 6f});
SensorTestHelper.CompareObservation(sensor, new [] {1f, 2f, 3f, 4f, 5f, 6f});

wrapped.AddObservation(new [] {7f, 8f});
SensorTestHelper.CompareObservation(sensor, new [] {3f, 4f, 5f, 6f, 7f, 8f});

wrapped.AddObservation(new [] {9f, 10f});
SensorTestHelper.CompareObservation(sensor, new [] {5f, 6f, 7f, 8f, 9f, 10f});
}


}
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

138 changes: 138 additions & 0 deletions UnitySDK/Assets/ML-Agents/Editor/Tests/Sensor/VectorSensorTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
using NUnit.Framework;
using UnityEngine;
using MLAgents.Sensor;

namespace MLAgents.Tests
{
public class SensorTestHelper
{
public static void CompareObservation(ISensor sensor, float[] expected)
{
var numExpected = expected.Length;
const float fill = -1337f;
var output = new float[numExpected];
for (var i = 0; i < numExpected; i++)
{
output[i] = fill;
}
Assert.AreEqual(fill, output[0]);

WriteAdapter writer = new WriteAdapter();
writer.SetTarget(output, 0);

// Make sure WriteAdapter didn't touch anything
Assert.AreEqual(fill, output[0]);

sensor.Write(writer);
for (var i = 0; i < numExpected; i++)
{
Assert.AreEqual(expected[i], output[i]);
}
}
}

public class VectorSensorTests
{
[Test]
public void TestCtor()
{
ISensor sensor = new VectorSensor(4);
Assert.AreEqual("VectorSensor_size4", sensor.GetName());

sensor = new VectorSensor(3, "test_sensor");
Assert.AreEqual("test_sensor", sensor.GetName());
}

[Test]
public void TestWrite()
{
var sensor = new VectorSensor(4);
sensor.AddObservation(1f);
sensor.AddObservation(2f);
sensor.AddObservation(3f);
sensor.AddObservation(4f);

SensorTestHelper.CompareObservation(sensor, new[] { 1f, 2f, 3f, 4f });
}

[Test]
public void TestAddObservationFloat()
{
var sensor = new VectorSensor(1);
sensor.AddObservation(1.2f);
SensorTestHelper.CompareObservation(sensor, new []{1.2f});
}

[Test]
public void TestAddObservationInt()
{
var sensor = new VectorSensor(1);
sensor.AddObservation(42);
SensorTestHelper.CompareObservation(sensor, new []{42f});
}

[Test]
public void TestAddObservationVec()
{
var sensor = new VectorSensor(3);
sensor.AddObservation(new Vector3(1,2,3));
SensorTestHelper.CompareObservation(sensor, new []{1f, 2f, 3f});

sensor = new VectorSensor(2);
sensor.AddObservation(new Vector2(4,5));
SensorTestHelper.CompareObservation(sensor, new[] { 4f, 5f });
}

[Test]
public void TestAddObservationQuaternion()
{
var sensor = new VectorSensor(4);
sensor.AddObservation(Quaternion.identity);
SensorTestHelper.CompareObservation(sensor, new []{0f, 0f, 0f, 1f});
}

[Test]
public void TestWriteEnumerable()
{
var sensor = new VectorSensor(4);
sensor.AddObservation(new [] {1f, 2f, 3f, 4f});

SensorTestHelper.CompareObservation(sensor, new[] { 1f, 2f, 3f, 4f });
}

[Test]
public void TestAddObservationBool()
{
var sensor = new VectorSensor(1);
sensor.AddObservation(true);
SensorTestHelper.CompareObservation(sensor, new []{1f});
}

[Test]
public void TestAddObservationOneHot()
{
var sensor = new VectorSensor(4);
sensor.AddOneHotObservation(2, 4);
SensorTestHelper.CompareObservation(sensor, new []{0f, 0f, 1f, 0f});
}

[Test]
public void TestWriteTooMany()
{
var sensor = new VectorSensor(2);
sensor.AddObservation(new [] {1f, 2f, 3f, 4f});

SensorTestHelper.CompareObservation(sensor, new[] { 1f, 2f});
}

[Test]
public void TestWriteNotEnough()
{
var sensor = new VectorSensor(4);
sensor.AddObservation(new [] {1f, 2f});

// Make sure extra zeros are added
SensorTestHelper.CompareObservation(sensor, new[] { 1f, 2f, 0f, 0f});
}
}
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading