Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

non-IEnumerable interface for action masking #5060

Merged
merged 5 commits into from
Mar 10, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,22 +48,22 @@ public override void WriteDiscreteActionMask(IDiscreteActionMask actionMask)

if (positionX == 0)
{
actionMask.WriteMask(0, new[] { k_Left });
actionMask.SetActionEnabled(0, k_Left, false);
}

if (positionX == maxPosition)
{
actionMask.WriteMask(0, new[] { k_Right });
actionMask.SetActionEnabled(0, k_Right, false);
}

if (positionZ == 0)
{
actionMask.WriteMask(0, new[] { k_Down });
actionMask.SetActionEnabled(0, k_Down, false);
}

if (positionZ == maxPosition)
{
actionMask.WriteMask(0, new[] { k_Up });
actionMask.SetActionEnabled(0, k_Up, false);
}
}
}
Expand Down
69 changes: 36 additions & 33 deletions com.unity.ml-agents.extensions/Runtime/Match3/Match3Actuator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ public class Match3Actuator : IActuator, IHeuristicProvider, IBuiltInActuator
/// <param name="agent"></param>
/// <param name="name"></param>
public Match3Actuator(AbstractBoard board,
bool forceHeuristic,
int seed,
Agent agent,
string name)
bool forceHeuristic,
int seed,
Agent agent,
string name)
{
m_Board = board;
m_Rows = board.Rows;
Expand Down Expand Up @@ -78,34 +78,27 @@ public void OnActionReceived(ActionBuffers actions)
/// <inheritdoc/>
public void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
{
const int branch = 0;
bool foundValidMove = false;
using (TimerStack.Instance.Scoped("WriteDiscreteActionMask"))
{
actionMask.WriteMask(0, InvalidMoveIndices());
}
}

/// <inheritdoc/>
public string Name { get; }
var numMoves = m_Board.NumMoves();

/// <inheritdoc/>
public void ResetData()
{
}

/// <inheritdoc/>
public BuiltInActuatorType GetBuiltInActuatorType()
{
return BuiltInActuatorType.Match3Actuator;
}

IEnumerable<int> InvalidMoveIndices()
{
var numValidMoves = m_Board.NumMoves();
var currentMove = Move.FromMoveIndex(0, m_Board.Rows, m_Board.Columns);
for (var i = 0; i < numMoves; i++)
{
if (m_Board.IsMoveValid(currentMove))
{
foundValidMove = true;
}
else
{
actionMask.SetActionEnabled(branch, i, false);
}
currentMove.Next(m_Board.Rows, m_Board.Columns);
}

foreach (var move in m_Board.InvalidMoves())
{
numValidMoves--;
if (numValidMoves == 0)
if (!foundValidMove)
{
// If all the moves are invalid and we mask all the actions out, this will cause an assert
// later on in IDiscreteActionMask. Instead, fire a callback to the user if they provided one,
Expand All @@ -122,23 +115,33 @@ IEnumerable<int> InvalidMoveIndices()
"an invalid move will be passed to AbstractBoard.MakeMove()."
);
}
// This means the last move won't be returned as an invalid index.
yield break;
actionMask.SetActionEnabled(branch, numMoves - 1, false);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs to be true

}
yield return move.MoveIndex;
}
}

/// <inheritdoc/>
public string Name { get; }

/// <inheritdoc/>
public void ResetData()
{
}

/// <inheritdoc/>
public BuiltInActuatorType GetBuiltInActuatorType()
{
return BuiltInActuatorType.Match3Actuator;
}

public void Heuristic(in ActionBuffers actionsOut)
{
var discreteActions = actionsOut.DiscreteActions;
discreteActions[0] = GreedyMove();
}


protected int GreedyMove()
{

var bestMoveIndex = 0;
var bestMovePoints = -1;
var numMovesAtCurrentScore = 0;
Expand Down
31 changes: 21 additions & 10 deletions com.unity.ml-agents/Runtime/Actuators/ActuatorDiscreteActionMask.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,22 @@ public void WriteMask(int branch, IEnumerable<int> actionIndices)
// Perform the masking
foreach (var actionIndex in actionIndices)
{
SetActionEnabled(branch, actionIndex, false);
}
}

/// <inheritdoc/>
public void SetActionEnabled(int branch, int actionIndex, bool isEnabled)
{
LazyInitialize();
#if DEBUG
if (branch >= m_NumBranches || actionIndex >= m_BranchSizes[CurrentBranchOffset + branch])
{
throw new UnityAgentsException(
"Invalid Action Masking: Action Mask is too large for specified branch.");
}
#endif
m_CurrentMask[actionIndex + m_StartingActionIndices[CurrentBranchOffset + branch]] = true;
if (branch >= m_NumBranches || actionIndex >= m_BranchSizes[CurrentBranchOffset + branch])
{
throw new UnityAgentsException(
"Invalid Action Masking: Action Mask is too large for specified branch.");
}
#endif
m_CurrentMask[actionIndex + m_StartingActionIndices[CurrentBranchOffset + branch]] = !isEnabled;
}

void LazyInitialize()
Expand Down Expand Up @@ -83,8 +90,12 @@ void LazyInitialize()
}
}

/// <inheritdoc/>
public bool[] GetMask()
/// <summary>
/// Get the current mask for an agent.
/// </summary>
/// <returns>A mask for the agent. A boolean array of length equal to the total number of
/// actions.</returns>
internal bool[] GetMask()
{
#if DEBUG
if (m_CurrentMask != null)
Expand Down Expand Up @@ -116,7 +127,7 @@ void AssertMask()
/// <summary>
/// Resets the current mask for an agent.
/// </summary>
public void ResetMask()
internal void ResetMask()
{
if (m_CurrentMask != null)
{
Expand Down
23 changes: 14 additions & 9 deletions com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,20 @@ public interface IDiscreteActionMask
void WriteMask(int branch, IEnumerable<int> actionIndices);

/// <summary>
/// Get the current mask for an agent.
/// Set whether or not the action index for the given branch is allowed.
/// </summary>
/// <returns>A mask for the agent. A boolean array of length equal to the total number of
/// actions.</returns>
bool[] GetMask();

/// <summary>
/// Resets the current mask for an agent.
/// </summary>
void ResetMask();
/// By default, all discrete actions are allowed.
/// If isEnabled is false, the agent will not be able to perform the actions passed as argument
/// at the next decision for the specified action branch. The actionIndex correspond
/// to the action options the agent will be unable to perform.
///
/// See [Agents - Actions] for more information on masking actions.
///
/// [Agents - Actions]: https://github.com/Unity-Technologies/ml-agents/blob/release_13_docs/docs/Learning-Environment-Design-Agents.md#actions
/// </remarks>
/// <param name="branch">The branch for which the actions will be masked.</param>
/// <param name="actionIndex">Index of the action</param>
/// <param name="isEnabled">Whether the action is allowed or now.</param>
void SetActionEnabled(int branch, int actionIndex, bool isEnabled);
}
}