diff --git a/Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs b/Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs index edb70fee9e..66a40a02b7 100644 --- a/Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs +++ b/Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs @@ -48,22 +48,22 @@ public override void WriteDiscreteActionMask(IDiscreteActionMask actionMask) if (positionX == 0) { - actionMask.WriteMask(0, new[] { k_Left }); + actionMask.SetActionEnabled(0, k_Left, false); } if (positionX == maxPosition) { - actionMask.WriteMask(0, new[] { k_Right }); + actionMask.SetActionEnabled(0, k_Right, false); } if (positionZ == 0) { - actionMask.WriteMask(0, new[] { k_Down }); + actionMask.SetActionEnabled(0, k_Down, false); } if (positionZ == maxPosition) { - actionMask.WriteMask(0, new[] { k_Up }); + actionMask.SetActionEnabled(0, k_Up, false); } } } diff --git a/com.unity.ml-agents.extensions/Runtime/Match3/Match3Actuator.cs b/com.unity.ml-agents.extensions/Runtime/Match3/Match3Actuator.cs index b2c0289dfd..661a69c15e 100644 --- a/com.unity.ml-agents.extensions/Runtime/Match3/Match3Actuator.cs +++ b/com.unity.ml-agents.extensions/Runtime/Match3/Match3Actuator.cs @@ -31,10 +31,10 @@ public class Match3Actuator : IActuator, IHeuristicProvider, IBuiltInActuator /// /// public Match3Actuator(AbstractBoard board, - bool forceHeuristic, - int seed, - Agent agent, - string name) + bool forceHeuristic, + int seed, + Agent agent, + string name) { m_Board = board; m_Rows = board.Rows; @@ -78,34 +78,27 @@ public void OnActionReceived(ActionBuffers actions) /// public void WriteDiscreteActionMask(IDiscreteActionMask actionMask) { + const int branch = 0; + bool foundValidMove = false; using (TimerStack.Instance.Scoped("WriteDiscreteActionMask")) { - actionMask.WriteMask(0, InvalidMoveIndices()); - } - } - - /// - public string Name { get; } + var numMoves = m_Board.NumMoves(); - /// - public void ResetData() - { - } - - /// - 
public BuiltInActuatorType GetBuiltInActuatorType() - { - return BuiltInActuatorType.Match3Actuator; - } - - IEnumerable InvalidMoveIndices() - { - var numValidMoves = m_Board.NumMoves(); + var currentMove = Move.FromMoveIndex(0, m_Board.Rows, m_Board.Columns); + for (var i = 0; i < numMoves; i++) + { + if (m_Board.IsMoveValid(currentMove)) + { + foundValidMove = true; + } + else + { + actionMask.SetActionEnabled(branch, i, false); + } + currentMove.Next(m_Board.Rows, m_Board.Columns); + } - foreach (var move in m_Board.InvalidMoves()) - { - numValidMoves--; - if (numValidMoves == 0) + if (!foundValidMove) { // If all the moves are invalid and we mask all the actions out, this will cause an assert // later on in IDiscreteActionMask. Instead, fire a callback to the user if they provided one, @@ -122,23 +115,33 @@ IEnumerable InvalidMoveIndices() "an invalid move will be passed to AbstractBoard.MakeMove()." ); } - // This means the last move won't be returned as an invalid index. - yield break; + actionMask.SetActionEnabled(branch, numMoves - 1, true); } - yield return move.MoveIndex; } } + /// + public string Name { get; } + + /// + public void ResetData() + { + } + + /// + public BuiltInActuatorType GetBuiltInActuatorType() + { + return BuiltInActuatorType.Match3Actuator; + } + public void Heuristic(in ActionBuffers actionsOut) { var discreteActions = actionsOut.DiscreteActions; discreteActions[0] = GreedyMove(); } - protected int GreedyMove() { - var bestMoveIndex = 0; var bestMovePoints = -1; var numMovesAtCurrentScore = 0; diff --git a/com.unity.ml-agents/CHANGELOG.md b/com.unity.ml-agents/CHANGELOG.md index 4dccc17be1..568eaf9c9d 100755 --- a/com.unity.ml-agents/CHANGELOG.md +++ b/com.unity.ml-agents/CHANGELOG.md @@ -11,6 +11,11 @@ and this project adheres to #### com.unity.ml-agents (C#) ====== - Some methods previously marked as `Obsolete` have been removed. If you were using these methods, you need to replace them with their supported counterpart. 
+- The interface for disabling discrete actions in `IDiscreteActionMask` has changed. +`WriteMask(int branch, IEnumerable actionIndices)` was replaced with +`SetActionEnabled(int branch, int actionIndex, bool isEnabled)`. See the +[Migration Guide](https://github.com/Unity-Technologies/ml-agents/blob/release_14_docs/docs/Migrating.md) for more +details. (#5060) #### ml-agents / ml-agents-envs / gym-unity (Python) ### Minor Changes diff --git a/com.unity.ml-agents/Runtime/Actuators/ActuatorDiscreteActionMask.cs b/com.unity.ml-agents/Runtime/Actuators/ActuatorDiscreteActionMask.cs index bf37c2bf82..d44532b16f 100644 --- a/com.unity.ml-agents/Runtime/Actuators/ActuatorDiscreteActionMask.cs +++ b/com.unity.ml-agents/Runtime/Actuators/ActuatorDiscreteActionMask.cs @@ -35,22 +35,17 @@ internal ActuatorDiscreteActionMask(IList actuators, int sumOfDiscret } /// - public void WriteMask(int branch, IEnumerable actionIndices) + public void SetActionEnabled(int branch, int actionIndex, bool isEnabled) { LazyInitialize(); - - // Perform the masking - foreach (var actionIndex in actionIndices) - { #if DEBUG - if (branch >= m_NumBranches || actionIndex >= m_BranchSizes[CurrentBranchOffset + branch]) - { - throw new UnityAgentsException( - "Invalid Action Masking: Action Mask is too large for specified branch."); - } -#endif - m_CurrentMask[actionIndex + m_StartingActionIndices[CurrentBranchOffset + branch]] = true; + if (branch >= m_NumBranches || actionIndex >= m_BranchSizes[CurrentBranchOffset + branch]) + { + throw new UnityAgentsException( + "Invalid Action Masking: Action Mask is too large for specified branch."); } +#endif + m_CurrentMask[actionIndex + m_StartingActionIndices[CurrentBranchOffset + branch]] = !isEnabled; } void LazyInitialize() @@ -83,8 +78,12 @@ void LazyInitialize() } } - /// - public bool[] GetMask() + /// + /// Get the current mask for an agent. + /// + /// A mask for the agent. A boolean array of length equal to the total number of + /// actions. 
+ internal bool[] GetMask() { #if DEBUG if (m_CurrentMask != null) @@ -116,7 +115,7 @@ void AssertMask() /// /// Resets the current mask for an agent. /// - public void ResetMask() + internal void ResetMask() { if (m_CurrentMask != null) { diff --git a/com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs b/com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs index f56fe7f776..86c204bee2 100644 --- a/com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs +++ b/com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs @@ -173,7 +173,7 @@ public interface IActionReceiver /// /// /// When using Discrete Control, you can prevent the Agent from using a certain - /// action by masking it with . + /// action by masking it with . /// /// See [Agents - Actions] for more information on masking actions. /// diff --git a/com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs b/com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs index bb64c5c34e..bb82dce711 100644 --- a/com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs +++ b/com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs @@ -8,11 +8,11 @@ namespace Unity.MLAgents.Actuators public interface IDiscreteActionMask { /// - /// Modifies an action mask for discrete control agents. + /// Set whether or not the action index for the given branch is allowed. /// - /// - /// When used, the agent will not be able to perform the actions passed as argument - /// at the next decision for the specified action branch. The actionIndices correspond + /// By default, all discrete actions are allowed. + /// If isEnabled is false, the agent will not be able to perform the actions passed as argument + /// at the next decision for the specified action branch. The actionIndex correspond /// to the action options the agent will be unable to perform. /// /// See [Agents - Actions] for more information on masking actions. 
@@ -20,19 +20,8 @@ public interface IDiscreteActionMask /// [Agents - Actions]: https://github.com/Unity-Technologies/ml-agents/blob/release_15_docs/docs/Learning-Environment-Design-Agents.md#actions /// /// The branch for which the actions will be masked. - /// The indices of the masked actions. - void WriteMask(int branch, IEnumerable actionIndices); - - /// - /// Get the current mask for an agent. - /// - /// A mask for the agent. A boolean array of length equal to the total number of - /// actions. - bool[] GetMask(); - - /// - /// Resets the current mask for an agent. - /// - void ResetMask(); + /// Index of the action + /// Whether the action is allowed or not. + void SetActionEnabled(int branch, int actionIndex, bool isEnabled); } } diff --git a/com.unity.ml-agents/Runtime/Agent.cs b/com.unity.ml-agents/Runtime/Agent.cs index 65e564a10a..201d31dd16 100644 --- a/com.unity.ml-agents/Runtime/Agent.cs +++ b/com.unity.ml-agents/Runtime/Agent.cs @@ -1178,7 +1178,7 @@ public ReadOnlyCollection GetObservations() /// /// /// When using Discrete Control, you can prevent the Agent from using a certain - /// action by masking it with . + /// action by masking it with . /// /// See [Agents - Actions] for more information on masking actions. 
/// diff --git a/com.unity.ml-agents/Tests/Editor/Actuators/ActuatorDiscreteActionMaskTests.cs b/com.unity.ml-agents/Tests/Editor/Actuators/ActuatorDiscreteActionMaskTests.cs index 3a9af66e40..1c486af483 100644 --- a/com.unity.ml-agents/Tests/Editor/Actuators/ActuatorDiscreteActionMaskTests.cs +++ b/com.unity.ml-agents/Tests/Editor/Actuators/ActuatorDiscreteActionMaskTests.cs @@ -29,7 +29,9 @@ public void FirstBranchMask() var masker = new ActuatorDiscreteActionMask(new IActuator[] { actuator1 }, 15, 3); var mask = masker.GetMask(); Assert.IsNull(mask); - masker.WriteMask(0, new[] { 1, 2, 3 }); + masker.SetActionEnabled(0, 1, false); + masker.SetActionEnabled(0, 2, false); + masker.SetActionEnabled(0, 3, false); mask = masker.GetMask(); Assert.IsFalse(mask[0]); Assert.IsTrue(mask[1]); @@ -39,12 +41,27 @@ public void FirstBranchMask() Assert.AreEqual(mask.Length, 15); } + [Test] + public void CanOverwriteMask() + { + var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 4, 5, 6 }), "actuator1"); + var masker = new ActuatorDiscreteActionMask(new IActuator[] { actuator1 }, 15, 3); + masker.SetActionEnabled(0, 1, false); + var mask = masker.GetMask(); + Assert.IsTrue(mask[1]); + + masker.SetActionEnabled(0, 1, true); + Assert.IsFalse(mask[1]); + } + [Test] public void SecondBranchMask() { var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 4, 5, 6 }), "actuator1"); var masker = new ActuatorDiscreteActionMask(new[] { actuator1 }, 15, 3); - masker.WriteMask(1, new[] { 1, 2, 3 }); + masker.SetActionEnabled(1, 1, false); + masker.SetActionEnabled(1, 2, false); + masker.SetActionEnabled(1, 3, false); var mask = masker.GetMask(); Assert.IsFalse(mask[0]); Assert.IsFalse(mask[4]); @@ -60,7 +77,9 @@ public void MaskReset() { var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 4, 5, 6 }), "actuator1"); var masker = new ActuatorDiscreteActionMask(new IActuator[] { actuator1 }, 15, 3); - masker.WriteMask(1, new[] { 1, 2, 3 }); + 
masker.SetActionEnabled(1, 1, false); + masker.SetActionEnabled(1, 2, false); + masker.SetActionEnabled(1, 3, false); masker.ResetMask(); var mask = masker.GetMask(); for (var i = 0; i < 15; i++) @@ -75,15 +94,18 @@ public void ThrowsError() var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 4, 5, 6 }), "actuator1"); var masker = new ActuatorDiscreteActionMask(new IActuator[] { actuator1 }, 15, 3); Assert.Catch( - () => masker.WriteMask(0, new[] { 5 })); + () => masker.SetActionEnabled(0, 5, false)); Assert.Catch( - () => masker.WriteMask(1, new[] { 5 })); - masker.WriteMask(2, new[] { 5 }); + () => masker.SetActionEnabled(1, 5, false)); + masker.SetActionEnabled(2, 5, false); Assert.Catch( - () => masker.WriteMask(3, new[] { 1 })); + () => masker.SetActionEnabled(3, 1, false)); masker.GetMask(); masker.ResetMask(); - masker.WriteMask(0, new[] { 0, 1, 2, 3 }); + masker.SetActionEnabled(0, 0, false); + masker.SetActionEnabled(0, 1, false); + masker.SetActionEnabled(0, 2, false); + masker.SetActionEnabled(0, 3, false); Assert.Catch( () => masker.GetMask()); } @@ -93,9 +115,10 @@ public void MultipleMaskEdit() { var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 4, 5, 6 }), "actuator1"); var masker = new ActuatorDiscreteActionMask(new IActuator[] { actuator1 }, 15, 3); - masker.WriteMask(0, new[] { 0, 1 }); - masker.WriteMask(0, new[] { 3 }); - masker.WriteMask(2, new[] { 1 }); + masker.SetActionEnabled(0, 0, false); + masker.SetActionEnabled(0, 1, false); + masker.SetActionEnabled(0, 3, false); + masker.SetActionEnabled(2, 1, false); var mask = masker.GetMask(); for (var i = 0; i < 15; i++) { diff --git a/com.unity.ml-agents/Tests/Editor/Actuators/TestActuator.cs b/com.unity.ml-agents/Tests/Editor/Actuators/TestActuator.cs index 649a643320..31ba0bf28e 100644 --- a/com.unity.ml-agents/Tests/Editor/Actuators/TestActuator.cs +++ b/com.unity.ml-agents/Tests/Editor/Actuators/TestActuator.cs @@ -22,9 +22,13 @@ public void 
OnActionReceived(ActionBuffers actionBuffers) public void WriteDiscreteActionMask(IDiscreteActionMask actionMask) { + for (var i = 0; i < Masks.Length; i++) { - actionMask.WriteMask(i, Masks[i]); + foreach (var actionIndex in Masks[i]) + { + actionMask.SetActionEnabled(i, actionIndex, false); + } } } diff --git a/com.unity.ml-agents/Tests/Editor/Actuators/VectorActuatorTests.cs b/com.unity.ml-agents/Tests/Editor/Actuators/VectorActuatorTests.cs index 2b3dcabfef..7fe52951c8 100644 --- a/com.unity.ml-agents/Tests/Editor/Actuators/VectorActuatorTests.cs +++ b/com.unity.ml-agents/Tests/Editor/Actuators/VectorActuatorTests.cs @@ -25,7 +25,10 @@ public void OnActionReceived(ActionBuffers actionBuffers) public void WriteDiscreteActionMask(IDiscreteActionMask actionMask) { - actionMask.WriteMask(Branch, Mask); + foreach (var actionIndex in Mask) + { + actionMask.SetActionEnabled(Branch, actionIndex, false); + } } public void Heuristic(in ActionBuffers actionBuffersOut) diff --git a/docs/Learning-Environment-Design-Agents.md b/docs/Learning-Environment-Design-Agents.md index d0eab75405..ef13fd4053 100644 --- a/docs/Learning-Environment-Design-Agents.md +++ b/docs/Learning-Environment-Design-Agents.md @@ -667,38 +667,40 @@ When using Discrete Actions, it is possible to specify that some actions are impossible for the next decision. When the Agent is controlled by a neural network, the Agent will be unable to perform the specified action. Note that when the Agent is controlled by its Heuristic, the Agent will still be able to -decide to perform the masked action. In order to mask an action, override the -`Agent.WriteDiscreteActionMask()` virtual method, and call -`WriteMask()` on the provided `IDiscreteActionMask`: +decide to perform the masked action. 
In order to disallow an action, override +the `Agent.WriteDiscreteActionMask()` virtual method, and call +`SetActionEnabled()` on the provided `IDiscreteActionMask`: ```csharp public override void WriteDiscreteActionMask(IDiscreteActionMask actionMask) { - actionMask.WriteMask(branch, actionIndices); + actionMask.SetActionEnabled(branch, actionIndex, isEnabled); } ``` Where: -- `branch` is the index (starting at 0) of the branch on which you want to mask - the action -- `actionIndices` is a list of `int` corresponding to the indices of the actions - that the Agent **cannot** perform. +- `branch` is the index (starting at 0) of the branch on which you want to +allow or disallow the action +- `actionIndex` is the index of the action that you want to allow or disallow. +- `isEnabled` is a bool indicating whether the action should be allowed or not. For example, if you have an Agent with 2 branches and on the first branch (branch 0) there are 4 possible actions : _"do nothing"_, _"jump"_, _"shoot"_ and _"change weapon"_. Then with the code bellow, the Agent will either _"do -nothing"_ or _"change weapon"_ for his next decision (since action index 1 and 2 +nothing"_ or _"change weapon"_ for their next decision (since action index 1 and 2 are masked) ```csharp -WriteMask(0, new int[2]{1,2}); +actionMask.SetActionEnabled(0, 1, false); +actionMask.SetActionEnabled(0, 2, false); ``` Notes: -- You can call `WriteMask` multiple times if you want to put masks on multiple +- You can call `SetActionEnabled` multiple times if you want to put masks on multiple branches. +- At each step, the state of an action is reset and enabled by default. - You cannot mask all the actions of a branch. - You cannot mask actions in continuous control. diff --git a/docs/Migrating.md b/docs/Migrating.md index dd367621e5..a9bd15b496 100644 --- a/docs/Migrating.md +++ b/docs/Migrating.md @@ -16,6 +16,31 @@ double-check that the versions are in the same. 
The versions can be found in # Migrating ## Migrating the package to version 2.0 - If you used any of the APIs that were deprecated before version 2.0, you need to use their replacement. These deprecated APIs have been removed. See the migration steps bellow for specific API replacements. +### IDiscreteActionMask changes +- The interface for disabling specific discrete actions has changed. `IDiscreteActionMask.WriteMask()` was removed, +and replaced with `SetActionEnabled()`. Instead of passing an IEnumerable of indices to disable, you can +now call `SetActionEnabled` for each index to disable (or enable). As an example, if you overrode +`Agent.WriteDiscreteActionMask()` with something that looked like: + +```csharp +public override void WriteDiscreteActionMask(IDiscreteActionMask actionMask) +{ + var branch = 2; + var actionsToDisable = new[] {1, 3}; + actionMask.WriteMask(branch, actionsToDisable); +} +``` + +the equivalent code would now be: + +```csharp +public override void WriteDiscreteActionMask(IDiscreteActionMask actionMask) +{ + var branch = 2; + actionMask.SetActionEnabled(branch, 1, false); + actionMask.SetActionEnabled(branch, 3, false); +} +``` ## Migrating to Release 13 ### Implementing IHeuristic in your IActuator implementations