Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLA-1129] Clear ActionBuffers before Heuristic calls #5227

Merged
merged 1 commit into from
Apr 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ void OnTriggerEnter(Collider col)
public override void Heuristic(in ActionBuffers actionsOut)
{
var discreteActionsOut = actionsOut.DiscreteActions;
discreteActionsOut[0] = 0;
if (Input.GetKey(KeyCode.D))
{
discreteActionsOut[0] = 3;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,9 +192,6 @@ public override void OnActionReceived(ActionBuffers actionBuffers)
public override void Heuristic(in ActionBuffers actionsOut)
{
var continuousActionsOut = actionsOut.ContinuousActions;
continuousActionsOut[0] = 0;
continuousActionsOut[1] = 0;
continuousActionsOut[2] = 0;
if (Input.GetKey(KeyCode.D))
{
continuousActionsOut[2] = 1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@ void OnCollisionEnter(Collision col)
public override void Heuristic(in ActionBuffers actionsOut)
{
var discreteActionsOut = actionsOut.DiscreteActions;
discreteActionsOut[0] = 0;
if (Input.GetKey(KeyCode.D))
{
discreteActionsOut[0] = 3;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,6 @@ public override void OnActionReceived(ActionBuffers actionBuffers)
public override void Heuristic(in ActionBuffers actionsOut)
{
var discreteActionsOut = actionsOut.DiscreteActions;
discreteActionsOut[0] = 0;
if (Input.GetKey(KeyCode.D))
{
discreteActionsOut[0] = 3;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ public override void OnActionReceived(ActionBuffers actionBuffers)
public override void Heuristic(in ActionBuffers actionsOut)
{
var discreteActionsOut = actionsOut.DiscreteActions;
discreteActionsOut[0] = 0;
if (Input.GetKey(KeyCode.D))
{
discreteActionsOut[0] = 3;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ public override void OnActionReceived(ActionBuffers actionBuffers)
public override void Heuristic(in ActionBuffers actionsOut)
{
var discreteActionsOut = actionsOut.DiscreteActions;
discreteActionsOut[0] = 0;
if (Input.GetKey(KeyCode.D))
{
discreteActionsOut[0] = 3;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,6 @@ public override void OnActionReceived(ActionBuffers actionBuffers)
public override void Heuristic(in ActionBuffers actionsOut)
{
var discreteActionsOut = actionsOut.DiscreteActions;
discreteActionsOut.Clear();
//forward
if (Input.GetKey(KeyCode.W))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,6 @@ public override void OnActionReceived(ActionBuffers actionBuffers)
public override void Heuristic(in ActionBuffers actionsOut)
{
var discreteActionsOut = actionsOut.DiscreteActions;
discreteActionsOut.Clear();
//forward
if (Input.GetKey(KeyCode.W))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,6 @@ public override void OnActionReceived(ActionBuffers actionBuffers)
public override void Heuristic(in ActionBuffers actionsOut)
{
var discreteActionsOut = actionsOut.DiscreteActions;
discreteActionsOut.Clear();
if (Input.GetKey(KeyCode.D))
{
discreteActionsOut[1] = 2;
Expand Down
2 changes: 2 additions & 0 deletions com.unity.ml-agents/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ depend on the previous behavior, you can explicitly set the Agent's `InferenceDe
- Added support for `Goal Signal` as a type of observation. Trainers can now use HyperNetworks to process `Goal Signal`. Trainers with HyperNetworks are more effective at solving multiple tasks. (#5142, #5159, #5149)
- Modified the [GridWorld environment](https://github.com/Unity-Technologies/ml-agents/blob/main/docs/Learning-Environment-Examples.md#gridworld) to use the new `Goal Signal` feature. (#5193)
- `RaycastPerceptionSensor` now caches its raycast results; they can be accessed via `RayPerceptionSensor.RayPerceptionOutput`. (#5222)
- `ActionBuffers` are now reset to zero before being passed to `Agent.Heuristic()` and
`IHeuristicProvider.Heuristic()`. (#5227)

#### ml-agents / ml-agents-envs / gym-unity (Python)
- Some console output have been moved from `info` to `debug` and will not be printed by default. If you want all messages to be printed, you can run `mlagents-learn` with the `--debug` option or add the line `debug: true` at the top of the yaml config file. (#5211)
Expand Down
1 change: 1 addition & 0 deletions com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ public ref readonly ActionBuffers DecideAction()
{
if (!m_Done && m_DecisionRequested)
{
m_ActionBuffers.Clear();
m_ActuatorManager.ApplyHeuristic(m_ActionBuffers);
}
m_DecisionRequested = false;
Expand Down
3 changes: 3 additions & 0 deletions com.unity.ml-agents/Tests/Editor/Policies.meta

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

125 changes: 125 additions & 0 deletions com.unity.ml-agents/Tests/Editor/Policies/HeuristicPolicyTest.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
using NUnit.Framework;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Policies;
using UnityEngine;

namespace Unity.MLAgents.Tests.Policies
{
[TestFixture]
public class HeuristicPolicyTest
{
[SetUp]
public void SetUp()
{
if (Academy.IsInitialized)
{
Academy.Instance.Dispose();
}
}

/// <summary>
/// Assert that the action buffers are initialized to zero, and then set them to non-zero values.
/// </summary>
/// <param name="actionsOut"></param>
static void CheckAndSetBuffer(in ActionBuffers actionsOut)
{
var continuousActions = actionsOut.ContinuousActions;
for (var continuousIndex = 0; continuousIndex < continuousActions.Length; continuousIndex++)
{
Assert.AreEqual(continuousActions[continuousIndex], 0.0f);
continuousActions[continuousIndex] = 1.0f;
}

var discreteActions = actionsOut.DiscreteActions;
for (var discreteIndex = 0; discreteIndex < discreteActions.Length; discreteIndex++)
{
Assert.AreEqual(discreteActions[discreteIndex], 0);
discreteActions[discreteIndex] = 1;
}
}


class ActionClearedAgent : Agent
{
public int HeuristicCalls = 0;
public override void Heuristic(in ActionBuffers actionsOut)
{
CheckAndSetBuffer(actionsOut);
HeuristicCalls++;
}
}

class ActionClearedActuator : IActuator
{
public int HeuristicCalls = 0;
public ActionClearedActuator(ActionSpec actionSpec)
{
ActionSpec = actionSpec;
Name = GetType().Name;
}

public void OnActionReceived(ActionBuffers actionBuffers)
{
}

public void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
{
}

public void Heuristic(in ActionBuffers actionBuffersOut)
{
CheckAndSetBuffer(actionBuffersOut);
HeuristicCalls++;
}

public ActionSpec ActionSpec { get; }
public string Name { get; }

public void ResetData()
{

}
}

class ActionClearedActuatorComponent : ActuatorComponent
{
public ActionClearedActuator ActionClearedActuator;
public ActionClearedActuatorComponent()
{
ActionSpec = new ActionSpec(2, new[] { 3, 3 });
}

public override IActuator[] CreateActuators()
{
ActionClearedActuator = new ActionClearedActuator(ActionSpec);
return new IActuator[] { ActionClearedActuator };
}

public override ActionSpec ActionSpec { get; }
}

[Test]
public void TestActionsCleared()
{
var gameObj = new GameObject();
var agent = gameObj.AddComponent<ActionClearedAgent>();
var behaviorParameters = agent.GetComponent<BehaviorParameters>();
behaviorParameters.BrainParameters.ActionSpec = new ActionSpec(1, new[] { 4 });
behaviorParameters.BrainParameters.VectorObservationSize = 0;
behaviorParameters.BehaviorType = BehaviorType.HeuristicOnly;

var actuatorComponent = gameObj.AddComponent<ActionClearedActuatorComponent>();
agent.LazyInitialize();

const int k_NumSteps = 5;
for (var i = 0; i < k_NumSteps; i++)
{
agent.RequestDecision();
Academy.Instance.EnvironmentStep();
}

Assert.AreEqual(agent.HeuristicCalls, k_NumSteps);
Assert.AreEqual(actuatorComponent.ActionClearedActuator.HeuristicCalls, k_NumSteps);
}
}
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.