From 3f11c0a4afc1fea4b50eebc5d4b83e422ba0a156 Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Wed, 7 Apr 2021 11:44:28 -0700 Subject: [PATCH] Clear ActionBuffers before Heuristic calls --- .../DungeonEscape/Scripts/PushAgentEscape.cs | 1 - .../Scripts/FoodCollectorAgent.cs | 3 - .../Examples/Hallway/Scripts/HallwayAgent.cs | 1 - .../PushBlock/Scripts/PushAgentBasic.cs | 1 - .../PushBlock/Scripts/PushAgentCollab.cs | 1 - .../Examples/Pyramids/Scripts/PyramidAgent.cs | 1 - .../Examples/Soccer/Scripts/AgentSoccer.cs | 1 - .../Examples/Sorter/Scripts/SorterAgent.cs | 1 - .../WallJump/Scripts/WallJumpAgent.cs | 1 - com.unity.ml-agents/CHANGELOG.md | 2 + .../Runtime/Policies/HeuristicPolicy.cs | 1 + .../Tests/Editor/Policies.meta | 3 + .../Editor/Policies/HeuristicPolicyTest.cs | 125 ++++++++++++++++++ .../Policies/HeuristicPolicyTest.cs.meta | 3 + 14 files changed, 134 insertions(+), 11 deletions(-) create mode 100644 com.unity.ml-agents/Tests/Editor/Policies.meta create mode 100644 com.unity.ml-agents/Tests/Editor/Policies/HeuristicPolicyTest.cs create mode 100644 com.unity.ml-agents/Tests/Editor/Policies/HeuristicPolicyTest.cs.meta diff --git a/Project/Assets/ML-Agents/Examples/DungeonEscape/Scripts/PushAgentEscape.cs b/Project/Assets/ML-Agents/Examples/DungeonEscape/Scripts/PushAgentEscape.cs index 6f31ce29b1..5f7c426e39 100644 --- a/Project/Assets/ML-Agents/Examples/DungeonEscape/Scripts/PushAgentEscape.cs +++ b/Project/Assets/ML-Agents/Examples/DungeonEscape/Scripts/PushAgentEscape.cs @@ -116,7 +116,6 @@ void OnTriggerEnter(Collider col) public override void Heuristic(in ActionBuffers actionsOut) { var discreteActionsOut = actionsOut.DiscreteActions; - discreteActionsOut[0] = 0; if (Input.GetKey(KeyCode.D)) { discreteActionsOut[0] = 3; diff --git a/Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorAgent.cs b/Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorAgent.cs index 1a8801f0f0..e9e173e56f 100644 --- a/Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorAgent.cs +++ b/Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorAgent.cs @@ -192,9 +192,6 @@ public override void OnActionReceived(ActionBuffers actionBuffers) public override void Heuristic(in ActionBuffers actionsOut) { var continuousActionsOut = actionsOut.ContinuousActions; - continuousActionsOut[0] = 0; - continuousActionsOut[1] = 0; - continuousActionsOut[2] = 0; if (Input.GetKey(KeyCode.D)) { continuousActionsOut[2] = 1; diff --git a/Project/Assets/ML-Agents/Examples/Hallway/Scripts/HallwayAgent.cs b/Project/Assets/ML-Agents/Examples/Hallway/Scripts/HallwayAgent.cs index aa7daf1a57..86ed43c533 100644 --- a/Project/Assets/ML-Agents/Examples/Hallway/Scripts/HallwayAgent.cs +++ b/Project/Assets/ML-Agents/Examples/Hallway/Scripts/HallwayAgent.cs @@ -100,7 +100,6 @@ void OnCollisionEnter(Collision col) public override void Heuristic(in ActionBuffers actionsOut) { var discreteActionsOut = actionsOut.DiscreteActions; - discreteActionsOut[0] = 0; if (Input.GetKey(KeyCode.D)) { discreteActionsOut[0] = 3; diff --git a/Project/Assets/ML-Agents/Examples/PushBlock/Scripts/PushAgentBasic.cs b/Project/Assets/ML-Agents/Examples/PushBlock/Scripts/PushAgentBasic.cs index 7ed48de608..488c96b022 100644 --- a/Project/Assets/ML-Agents/Examples/PushBlock/Scripts/PushAgentBasic.cs +++ b/Project/Assets/ML-Agents/Examples/PushBlock/Scripts/PushAgentBasic.cs @@ -177,7 +177,6 @@ public override void OnActionReceived(ActionBuffers actionBuffers) public override void Heuristic(in ActionBuffers actionsOut) { var discreteActionsOut = actionsOut.DiscreteActions; - discreteActionsOut[0] = 0; if (Input.GetKey(KeyCode.D)) { discreteActionsOut[0] = 3; diff --git a/Project/Assets/ML-Agents/Examples/PushBlock/Scripts/PushAgentCollab.cs b/Project/Assets/ML-Agents/Examples/PushBlock/Scripts/PushAgentCollab.cs index 5e09e5c20e..fd2fe5fd72 100644 --- a/Project/Assets/ML-Agents/Examples/PushBlock/Scripts/PushAgentCollab.cs +++ b/Project/Assets/ML-Agents/Examples/PushBlock/Scripts/PushAgentCollab.cs @@ -71,7 +71,6 @@ public override void OnActionReceived(ActionBuffers actionBuffers) public override void Heuristic(in ActionBuffers actionsOut) { var discreteActionsOut = actionsOut.DiscreteActions; - discreteActionsOut[0] = 0; if (Input.GetKey(KeyCode.D)) { discreteActionsOut[0] = 3; diff --git a/Project/Assets/ML-Agents/Examples/Pyramids/Scripts/PyramidAgent.cs b/Project/Assets/ML-Agents/Examples/Pyramids/Scripts/PyramidAgent.cs index bc848c48d4..5371a33d84 100644 --- a/Project/Assets/ML-Agents/Examples/Pyramids/Scripts/PyramidAgent.cs +++ b/Project/Assets/ML-Agents/Examples/Pyramids/Scripts/PyramidAgent.cs @@ -66,7 +66,6 @@ public override void OnActionReceived(ActionBuffers actionBuffers) public override void Heuristic(in ActionBuffers actionsOut) { var discreteActionsOut = actionsOut.DiscreteActions; - discreteActionsOut[0] = 0; if (Input.GetKey(KeyCode.D)) { discreteActionsOut[0] = 3; diff --git a/Project/Assets/ML-Agents/Examples/Soccer/Scripts/AgentSoccer.cs b/Project/Assets/ML-Agents/Examples/Soccer/Scripts/AgentSoccer.cs index abddb5485f..a794971162 100644 --- a/Project/Assets/ML-Agents/Examples/Soccer/Scripts/AgentSoccer.cs +++ b/Project/Assets/ML-Agents/Examples/Soccer/Scripts/AgentSoccer.cs @@ -153,7 +153,6 @@ public override void OnActionReceived(ActionBuffers actionBuffers) public override void Heuristic(in ActionBuffers actionsOut) { var discreteActionsOut = actionsOut.DiscreteActions; - discreteActionsOut.Clear(); //forward if (Input.GetKey(KeyCode.W)) { diff --git a/Project/Assets/ML-Agents/Examples/Sorter/Scripts/SorterAgent.cs b/Project/Assets/ML-Agents/Examples/Sorter/Scripts/SorterAgent.cs index 96b7a9781c..c74b6e3f19 100644 --- a/Project/Assets/ML-Agents/Examples/Sorter/Scripts/SorterAgent.cs +++ b/Project/Assets/ML-Agents/Examples/Sorter/Scripts/SorterAgent.cs @@ -238,7 +238,6 @@ public override void OnActionReceived(ActionBuffers actionBuffers) public override void Heuristic(in ActionBuffers actionsOut) { var discreteActionsOut = actionsOut.DiscreteActions; - discreteActionsOut.Clear(); //forward if (Input.GetKey(KeyCode.W)) { diff --git a/Project/Assets/ML-Agents/Examples/WallJump/Scripts/WallJumpAgent.cs b/Project/Assets/ML-Agents/Examples/WallJump/Scripts/WallJumpAgent.cs index 0a34632578..d9379c6629 100644 --- a/Project/Assets/ML-Agents/Examples/WallJump/Scripts/WallJumpAgent.cs +++ b/Project/Assets/ML-Agents/Examples/WallJump/Scripts/WallJumpAgent.cs @@ -264,7 +264,6 @@ public override void OnActionReceived(ActionBuffers actionBuffers) public override void Heuristic(in ActionBuffers actionsOut) { var discreteActionsOut = actionsOut.DiscreteActions; - discreteActionsOut.Clear(); if (Input.GetKey(KeyCode.D)) { discreteActionsOut[1] = 2; diff --git a/com.unity.ml-agents/CHANGELOG.md b/com.unity.ml-agents/CHANGELOG.md index 4c5aafd9a9..a8f1c7dd5e 100755 --- a/com.unity.ml-agents/CHANGELOG.md +++ b/com.unity.ml-agents/CHANGELOG.md @@ -48,6 +48,8 @@ depend on the previous behavior, you can explicitly set the Agent's `InferenceDe - Added support for `Goal Signal` as a type of observation. Trainers can now use HyperNetworks to process `Goal Signal`. Trainers with HyperNetworks are more effective at solving multiple tasks. (#5142, #5159, #5149) - Modified the [GridWorld environment](https://github.com/Unity-Technologies/ml-agents/blob/main/docs/Learning-Environment-Examples.md#gridworld) to use the new `Goal Signal` feature. (#5193) - `RaycastPerceptionSensor` now caches its raycast results; they can be accessed via `RayPerceptionSensor.RayPerceptionOutput`. (#5222) +- `ActionBuffers` are now reset to zero before being passed to `Agent.Heuristic()` and +`IHeuristicProvider.Heuristic()`. (#5227) #### ml-agents / ml-agents-envs / gym-unity (Python) - Some console output have been moved from `info` to `debug` and will not be printed by default. If you want all messages to be printed, you can run `mlagents-learn` with the `--debug` option or add the line `debug: true` at the top of the yaml config file. (#5211) diff --git a/com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs b/com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs index ae7273001c..b18956833a 100644 --- a/com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs +++ b/com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs @@ -46,6 +46,7 @@ public ref readonly ActionBuffers DecideAction() { if (!m_Done && m_DecisionRequested) { + m_ActionBuffers.Clear(); m_ActuatorManager.ApplyHeuristic(m_ActionBuffers); } m_DecisionRequested = false; diff --git a/com.unity.ml-agents/Tests/Editor/Policies.meta b/com.unity.ml-agents/Tests/Editor/Policies.meta new file mode 100644 index 0000000000..be3f189b91 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Policies.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: df271cac120e4d6893b14599fa8eb64d +timeCreated: 1617813392 \ No newline at end of file diff --git a/com.unity.ml-agents/Tests/Editor/Policies/HeuristicPolicyTest.cs b/com.unity.ml-agents/Tests/Editor/Policies/HeuristicPolicyTest.cs new file mode 100644 index 0000000000..944b7ff907 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Policies/HeuristicPolicyTest.cs @@ -0,0 +1,125 @@ +using NUnit.Framework; +using Unity.MLAgents.Actuators; +using Unity.MLAgents.Policies; +using UnityEngine; + +namespace Unity.MLAgents.Tests.Policies +{ + [TestFixture] + public class HeuristicPolicyTest + { + [SetUp] + public void SetUp() + { + if (Academy.IsInitialized) + { + Academy.Instance.Dispose(); + } + } + + /// + /// Assert that the action buffers are initialized to zero, and then set them to non-zero values. + /// + /// + static void CheckAndSetBuffer(in ActionBuffers actionsOut) + { + var continuousActions = actionsOut.ContinuousActions; + for (var continuousIndex = 0; continuousIndex < continuousActions.Length; continuousIndex++) + { + Assert.AreEqual(continuousActions[continuousIndex], 0.0f); + continuousActions[continuousIndex] = 1.0f; + } + + var discreteActions = actionsOut.DiscreteActions; + for (var discreteIndex = 0; discreteIndex < discreteActions.Length; discreteIndex++) + { + Assert.AreEqual(discreteActions[discreteIndex], 0); + discreteActions[discreteIndex] = 1; + } + } + + + class ActionClearedAgent : Agent + { + public int HeuristicCalls = 0; + public override void Heuristic(in ActionBuffers actionsOut) + { + CheckAndSetBuffer(actionsOut); + HeuristicCalls++; + } + } + + class ActionClearedActuator : IActuator + { + public int HeuristicCalls = 0; + public ActionClearedActuator(ActionSpec actionSpec) + { + ActionSpec = actionSpec; + Name = GetType().Name; + } + + public void OnActionReceived(ActionBuffers actionBuffers) + { + } + + public void WriteDiscreteActionMask(IDiscreteActionMask actionMask) + { + } + + public void Heuristic(in ActionBuffers actionBuffersOut) + { + CheckAndSetBuffer(actionBuffersOut); + HeuristicCalls++; + } + + public ActionSpec ActionSpec { get; } + public string Name { get; } + + public void ResetData() + { + + } + } + + class ActionClearedActuatorComponent : ActuatorComponent + { + public ActionClearedActuator ActionClearedActuator; + public ActionClearedActuatorComponent() + { + ActionSpec = new ActionSpec(2, new[] { 3, 3 }); + } + + public override IActuator[] CreateActuators() + { + ActionClearedActuator = new ActionClearedActuator(ActionSpec); + return new IActuator[] { ActionClearedActuator }; + } + + public override ActionSpec ActionSpec { get; } + } + + [Test] + public void TestActionsCleared() + { + var gameObj = new GameObject(); + var agent = gameObj.AddComponent(); + var behaviorParameters = agent.GetComponent(); + behaviorParameters.BrainParameters.ActionSpec = new ActionSpec(1, new[] { 4 }); + behaviorParameters.BrainParameters.VectorObservationSize = 0; + behaviorParameters.BehaviorType = BehaviorType.HeuristicOnly; + + var actuatorComponent = gameObj.AddComponent(); + agent.LazyInitialize(); + + const int k_NumSteps = 5; + for (var i = 0; i < k_NumSteps; i++) + { + agent.RequestDecision(); + Academy.Instance.EnvironmentStep(); + } + + Assert.AreEqual(agent.HeuristicCalls, k_NumSteps); + Assert.AreEqual(actuatorComponent.ActionClearedActuator.HeuristicCalls, k_NumSteps); + } + } +} diff --git a/com.unity.ml-agents/Tests/Editor/Policies/HeuristicPolicyTest.cs.meta b/com.unity.ml-agents/Tests/Editor/Policies/HeuristicPolicyTest.cs.meta new file mode 100644 index 0000000000..682a64b746 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Policies/HeuristicPolicyTest.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: 5108e92f91a04ddab9d628c9bc57cadb +timeCreated: 1617813411 \ No newline at end of file