Skip to content

Commit

Permalink
non-IEnumerable interface for action masking (#5060)
Browse files Browse the repository at this point in the history
  • Loading branch information
Chris Elion authored and surfnerd committed Mar 18, 2021
1 parent 27cbc39 commit 80f34b9
Show file tree
Hide file tree
Showing 12 changed files with 149 additions and 96 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,22 +48,22 @@ public override void WriteDiscreteActionMask(IDiscreteActionMask actionMask)

if (positionX == 0)
{
actionMask.WriteMask(0, new[] { k_Left });
actionMask.SetActionEnabled(0, k_Left, false);
}

if (positionX == maxPosition)
{
actionMask.WriteMask(0, new[] { k_Right });
actionMask.SetActionEnabled(0, k_Right, false);
}

if (positionZ == 0)
{
actionMask.WriteMask(0, new[] { k_Down });
actionMask.SetActionEnabled(0, k_Down, false);
}

if (positionZ == maxPosition)
{
actionMask.WriteMask(0, new[] { k_Up });
actionMask.SetActionEnabled(0, k_Up, false);
}
}
}
Expand Down
69 changes: 36 additions & 33 deletions com.unity.ml-agents.extensions/Runtime/Match3/Match3Actuator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ public class Match3Actuator : IActuator, IHeuristicProvider, IBuiltInActuator
/// <param name="agent"></param>
/// <param name="name"></param>
public Match3Actuator(AbstractBoard board,
bool forceHeuristic,
int seed,
Agent agent,
string name)
bool forceHeuristic,
int seed,
Agent agent,
string name)
{
m_Board = board;
m_Rows = board.Rows;
Expand Down Expand Up @@ -78,34 +78,27 @@ public void OnActionReceived(ActionBuffers actions)
/// <inheritdoc/>
public void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
{
const int branch = 0;
bool foundValidMove = false;
using (TimerStack.Instance.Scoped("WriteDiscreteActionMask"))
{
actionMask.WriteMask(0, InvalidMoveIndices());
}
}

/// <inheritdoc/>
public string Name { get; }
var numMoves = m_Board.NumMoves();

/// <inheritdoc/>
public void ResetData()
{
}

/// <inheritdoc/>
public BuiltInActuatorType GetBuiltInActuatorType()
{
return BuiltInActuatorType.Match3Actuator;
}

IEnumerable<int> InvalidMoveIndices()
{
var numValidMoves = m_Board.NumMoves();
var currentMove = Move.FromMoveIndex(0, m_Board.Rows, m_Board.Columns);
for (var i = 0; i < numMoves; i++)
{
if (m_Board.IsMoveValid(currentMove))
{
foundValidMove = true;
}
else
{
actionMask.SetActionEnabled(branch, i, false);
}
currentMove.Next(m_Board.Rows, m_Board.Columns);
}

foreach (var move in m_Board.InvalidMoves())
{
numValidMoves--;
if (numValidMoves == 0)
if (!foundValidMove)
{
// If all the moves are invalid and we mask all the actions out, this will cause an assert
// later on in IDiscreteActionMask. Instead, fire a callback to the user if they provided one,
Expand All @@ -122,23 +115,33 @@ IEnumerable<int> InvalidMoveIndices()
"an invalid move will be passed to AbstractBoard.MakeMove()."
);
}
// This means the last move won't be returned as an invalid index.
yield break;
actionMask.SetActionEnabled(branch, numMoves - 1, true);
}
yield return move.MoveIndex;
}
}

/// <inheritdoc/>
public string Name { get; }

/// <inheritdoc/>
public void ResetData()
{
}

/// <inheritdoc/>
public BuiltInActuatorType GetBuiltInActuatorType()
{
return BuiltInActuatorType.Match3Actuator;
}

public void Heuristic(in ActionBuffers actionsOut)
{
var discreteActions = actionsOut.DiscreteActions;
discreteActions[0] = GreedyMove();
}


protected int GreedyMove()
{

var bestMoveIndex = 0;
var bestMovePoints = -1;
var numMovesAtCurrentScore = 0;
Expand Down
5 changes: 5 additions & 0 deletions com.unity.ml-agents/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ and this project adheres to
#### com.unity.ml-agents (C#)
======
- Some methods previously marked as `Obsolete` have been removed. If you were using these methods, you need to replace them with their supported counterpart.
- The interface for disabling discrete actions in `IDiscreteActionMask` has changed.
`WriteMask(int branch, IEnumerable<int> actionIndices)` was replaced with
`SetActionEnabled(int branch, int actionIndex, bool isEnabled)`. See the
[Migration Guide](https://github.com/Unity-Technologies/ml-agents/blob/release_14_docs/docs/Migrating.md) for more
details. (#5060)
#### ml-agents / ml-agents-envs / gym-unity (Python)

### Minor Changes
Expand Down
29 changes: 14 additions & 15 deletions com.unity.ml-agents/Runtime/Actuators/ActuatorDiscreteActionMask.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,22 +35,17 @@ internal ActuatorDiscreteActionMask(IList<IActuator> actuators, int sumOfDiscret
}

/// <inheritdoc/>
public void WriteMask(int branch, IEnumerable<int> actionIndices)
public void SetActionEnabled(int branch, int actionIndex, bool isEnabled)
{
LazyInitialize();

// Perform the masking
foreach (var actionIndex in actionIndices)
{
#if DEBUG
if (branch >= m_NumBranches || actionIndex >= m_BranchSizes[CurrentBranchOffset + branch])
{
throw new UnityAgentsException(
"Invalid Action Masking: Action Mask is too large for specified branch.");
}
#endif
m_CurrentMask[actionIndex + m_StartingActionIndices[CurrentBranchOffset + branch]] = true;
if (branch >= m_NumBranches || actionIndex >= m_BranchSizes[CurrentBranchOffset + branch])
{
throw new UnityAgentsException(
"Invalid Action Masking: Action Mask is too large for specified branch.");
}
#endif
m_CurrentMask[actionIndex + m_StartingActionIndices[CurrentBranchOffset + branch]] = !isEnabled;
}

void LazyInitialize()
Expand Down Expand Up @@ -83,8 +78,12 @@ void LazyInitialize()
}
}

/// <inheritdoc/>
public bool[] GetMask()
/// <summary>
/// Get the current mask for an agent.
/// </summary>
/// <returns>A mask for the agent. A boolean array of length equal to the total number of
/// actions.</returns>
internal bool[] GetMask()
{
#if DEBUG
if (m_CurrentMask != null)
Expand Down Expand Up @@ -116,7 +115,7 @@ void AssertMask()
/// <summary>
/// Resets the current mask for an agent.
/// </summary>
public void ResetMask()
internal void ResetMask()
{
if (m_CurrentMask != null)
{
Expand Down
2 changes: 1 addition & 1 deletion com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ public interface IActionReceiver
/// </param>
/// <remarks>
/// When using Discrete Control, you can prevent the Agent from using a certain
/// action by masking it with <see cref="IDiscreteActionMask.WriteMask"/>.
/// action by masking it with <see cref="IDiscreteActionMask.SetActionEnabled"/>.
///
/// See [Agents - Actions] for more information on masking actions.
///
Expand Down
25 changes: 7 additions & 18 deletions com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,31 +8,20 @@ namespace Unity.MLAgents.Actuators
public interface IDiscreteActionMask
{
/// <summary>
/// Modifies an action mask for discrete control agents.
/// Set whether or not the action index for the given branch is allowed.
/// </summary>
/// <remarks>
/// When used, the agent will not be able to perform the actions passed as argument
/// at the next decision for the specified action branch. The actionIndices correspond
/// By default, all discrete actions are allowed.
/// If isEnabled is false, the agent will not be able to perform the actions passed as argument
/// at the next decision for the specified action branch. The actionIndex correspond
/// to the action options the agent will be unable to perform.
///
/// See [Agents - Actions] for more information on masking actions.
///
/// [Agents - Actions]: https://github.com/Unity-Technologies/ml-agents/blob/release_15_docs/docs/Learning-Environment-Design-Agents.md#actions
/// </remarks>
/// <param name="branch">The branch for which the actions will be masked.</param>
/// <param name="actionIndices">The indices of the masked actions.</param>
void WriteMask(int branch, IEnumerable<int> actionIndices);

/// <summary>
/// Get the current mask for an agent.
/// </summary>
/// <returns>A mask for the agent. A boolean array of length equal to the total number of
/// actions.</returns>
bool[] GetMask();

/// <summary>
/// Resets the current mask for an agent.
/// </summary>
void ResetMask();
/// <param name="actionIndex">Index of the action</param>
/// <param name="isEnabled">Whether the action is allowed or now.</param>
void SetActionEnabled(int branch, int actionIndex, bool isEnabled);
}
}
2 changes: 1 addition & 1 deletion com.unity.ml-agents/Runtime/Agent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1178,7 +1178,7 @@ public ReadOnlyCollection<float> GetObservations()
/// </param>
/// <remarks>
/// When using Discrete Control, you can prevent the Agent from using a certain
/// action by masking it with <see cref="IDiscreteActionMask.WriteMask(int, IEnumerable{int})"/>.
/// action by masking it with <see cref="IDiscreteActionMask.SetActionEnabled"/>.
///
/// See [Agents - Actions] for more information on masking actions.
///
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ public void FirstBranchMask()
var masker = new ActuatorDiscreteActionMask(new IActuator[] { actuator1 }, 15, 3);
var mask = masker.GetMask();
Assert.IsNull(mask);
masker.WriteMask(0, new[] { 1, 2, 3 });
masker.SetActionEnabled(0, 1, false);
masker.SetActionEnabled(0, 2, false);
masker.SetActionEnabled(0, 3, false);
mask = masker.GetMask();
Assert.IsFalse(mask[0]);
Assert.IsTrue(mask[1]);
Expand All @@ -39,12 +41,27 @@ public void FirstBranchMask()
Assert.AreEqual(mask.Length, 15);
}

[Test]
public void CanOverwriteMask()
{
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 4, 5, 6 }), "actuator1");
var masker = new ActuatorDiscreteActionMask(new IActuator[] { actuator1 }, 15, 3);
masker.SetActionEnabled(0, 1, false);
var mask = masker.GetMask();
Assert.IsTrue(mask[1]);

masker.SetActionEnabled(0, 1, true);
Assert.IsFalse(mask[1]);
}

[Test]
public void SecondBranchMask()
{
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 4, 5, 6 }), "actuator1");
var masker = new ActuatorDiscreteActionMask(new[] { actuator1 }, 15, 3);
masker.WriteMask(1, new[] { 1, 2, 3 });
masker.SetActionEnabled(1, 1, false);
masker.SetActionEnabled(1, 2, false);
masker.SetActionEnabled(1, 3, false);
var mask = masker.GetMask();
Assert.IsFalse(mask[0]);
Assert.IsFalse(mask[4]);
Expand All @@ -60,7 +77,9 @@ public void MaskReset()
{
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 4, 5, 6 }), "actuator1");
var masker = new ActuatorDiscreteActionMask(new IActuator[] { actuator1 }, 15, 3);
masker.WriteMask(1, new[] { 1, 2, 3 });
masker.SetActionEnabled(1, 1, false);
masker.SetActionEnabled(1, 2, false);
masker.SetActionEnabled(1, 3, false);
masker.ResetMask();
var mask = masker.GetMask();
for (var i = 0; i < 15; i++)
Expand All @@ -75,15 +94,18 @@ public void ThrowsError()
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 4, 5, 6 }), "actuator1");
var masker = new ActuatorDiscreteActionMask(new IActuator[] { actuator1 }, 15, 3);
Assert.Catch<UnityAgentsException>(
() => masker.WriteMask(0, new[] { 5 }));
() => masker.SetActionEnabled(0, 5, false));
Assert.Catch<UnityAgentsException>(
() => masker.WriteMask(1, new[] { 5 }));
masker.WriteMask(2, new[] { 5 });
() => masker.SetActionEnabled(1, 5, false));
masker.SetActionEnabled(2, 5, false);
Assert.Catch<UnityAgentsException>(
() => masker.WriteMask(3, new[] { 1 }));
() => masker.SetActionEnabled(3, 1, false));
masker.GetMask();
masker.ResetMask();
masker.WriteMask(0, new[] { 0, 1, 2, 3 });
masker.SetActionEnabled(0, 0, false);
masker.SetActionEnabled(0, 1, false);
masker.SetActionEnabled(0, 2, false);
masker.SetActionEnabled(0, 3, false);
Assert.Catch<UnityAgentsException>(
() => masker.GetMask());
}
Expand All @@ -93,9 +115,10 @@ public void MultipleMaskEdit()
{
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 4, 5, 6 }), "actuator1");
var masker = new ActuatorDiscreteActionMask(new IActuator[] { actuator1 }, 15, 3);
masker.WriteMask(0, new[] { 0, 1 });
masker.WriteMask(0, new[] { 3 });
masker.WriteMask(2, new[] { 1 });
masker.SetActionEnabled(0, 0, false);
masker.SetActionEnabled(0, 1, false);
masker.SetActionEnabled(0, 3, false);
masker.SetActionEnabled(2, 1, false);
var mask = masker.GetMask();
for (var i = 0; i < 15; i++)
{
Expand Down
6 changes: 5 additions & 1 deletion com.unity.ml-agents/Tests/Editor/Actuators/TestActuator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,13 @@ public void OnActionReceived(ActionBuffers actionBuffers)

public void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
{

for (var i = 0; i < Masks.Length; i++)
{
actionMask.WriteMask(i, Masks[i]);
foreach (var actionIndex in Masks[i])
{
actionMask.SetActionEnabled(i, actionIndex, false);
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@ public void OnActionReceived(ActionBuffers actionBuffers)

public void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
{
actionMask.WriteMask(Branch, Mask);
foreach (var actionIndex in Mask)
{
actionMask.SetActionEnabled(Branch, actionIndex, false);
}
}

public void Heuristic(in ActionBuffers actionBuffersOut)
Expand Down
Loading

0 comments on commit 80f34b9

Please sign in to comment.