Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial implementation using IHeuristicProvider. #4849

Merged
merged 27 commits into from
Jan 20, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
1cd526a
Initial implementation using IHeuristic.
surfnerd Jan 12, 2021
6a45d35
Remove some 'style' changes.
surfnerd Jan 12, 2021
8382acf
Fix the fix.
surfnerd Jan 12, 2021
866e826
Send ActuatorManager to HeuristicPolicy. Have VectorActuator call in…
surfnerd Jan 13, 2021
5917d20
Update com.unity.ml-agents/Runtime/Actuators/IHeuristicProvider.cs
surfnerd Jan 13, 2021
9313864
Update com.unity.ml-agents/Runtime/Actuators/VectorActuator.cs
surfnerd Jan 13, 2021
5c028e6
Pass ActuatorManager directly to HeuristicPolicy and PolicyFactory.
surfnerd Jan 13, 2021
e30f643
Revert some logic.
surfnerd Jan 13, 2021
4d72bcb
Update changelog.
surfnerd Jan 13, 2021
3bb3f17
Add tests.
surfnerd Jan 13, 2021
b336c72
Merge branch 'master' into develop-heuristic-interface
surfnerd Jan 14, 2021
35dd22b
Use actuator heuristic for basic scene.
surfnerd Jan 14, 2021
c9245e0
Rename Match3 -> Board for extension classes.
surfnerd Jan 14, 2021
2542348
Fix comments to remove Match3.
surfnerd Jan 14, 2021
640a1bf
Update match3 examples and extensions to use actuators.
surfnerd Jan 15, 2021
6c1761e
Update tests.
surfnerd Jan 15, 2021
55d4348
Update changelog and migrating docs.
surfnerd Jan 15, 2021
511ffc3
Install wheel package to fix error in yamato.
surfnerd Jan 15, 2021
f771e3b
Do not clear heuristic action buffer.
surfnerd Jan 15, 2021
550c74c
Order of operations.
surfnerd Jan 15, 2021
4b1c3e6
Undo renames.
surfnerd Jan 15, 2021
5639930
Make sure all references to 'Board' namespace/directories/comments ar…
surfnerd Jan 19, 2021
67571f4
Remove Heuristic Quality. Move the random seed out of the extension …
surfnerd Jan 19, 2021
797c2fd
Rename ActuatorManager.Heuristic to ActuatorManager.ApplyHeuristic.
surfnerd Jan 19, 2021
7b3be45
More Board reverts.
surfnerd Jan 19, 2021
a872085
Merge branch 'master' into develop-heuristic-interface
surfnerd Jan 19, 2021
8d91dde
Fix comment doc.
surfnerd Jan 19, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions com.unity.ml-agents/Runtime/Actuators/ActionSpec.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,18 @@ public struct ActionSpec
/// <summary>
/// The number of continuous actions that an Agent can take.
/// </summary>
public int NumContinuousActions { get { return m_NumContinuousActions; } set { m_NumContinuousActions = value; } }
public int NumContinuousActions { get => m_NumContinuousActions; set => m_NumContinuousActions = value; }

/// <summary>
surfnerd marked this conversation as resolved.
Show resolved Hide resolved
/// The number of branches for discrete actions that an Agent can take.
/// </summary>
public int NumDiscreteActions { get { return BranchSizes == null ? 0 : BranchSizes.Length; } }
public int NumDiscreteActions => BranchSizes?.Length ?? 0;

/// <summary>
/// Get the total number of Discrete Actions that can be taken by calculating the Sum
/// of all of the Discrete Action branch sizes.
/// </summary>
public int SumOfDiscreteBranchSizes { get { return BranchSizes == null ? 0 : BranchSizes.Sum(); } }
public int SumOfDiscreteBranchSizes => BranchSizes?.Sum() ?? 0;

/// <summary>
/// Creates a Continuous <see cref="ActionSpec"/> with the number of actions available.
Expand Down
39 changes: 39 additions & 0 deletions com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,45 @@ public void WriteActionMask()
}
}

/// <summary>
/// Walks every IActuator held by this manager and, for each one that also implements
/// <see cref="IHeuristic"/>, invokes its <see cref="IHeuristic.Heuristic"/> method with
/// <see cref="ActionSegment{T}"/>s sized according to that actuator's <see cref="ActionSpec"/>.
/// The segments are consecutive windows over the stored continuous and discrete action arrays.
/// </summary>
public void ExecuteHeuristic()
{
    ReadyActuatorsForExecution();

    // Running offsets into the stored continuous / discrete action arrays.
    var continuousOffset = 0;
    var discreteOffset = 0;
    for (var index = 0; index < m_Actuators.Count; index++)
    {
        var current = m_Actuators[index];
        var continuousCount = current.ActionSpec.NumContinuousActions;
        var discreteCount = current.ActionSpec.NumDiscreteActions;

        // An actuator with no actions of a given kind is handed the empty segment.
        var continuousSlice = continuousCount > 0
            ? new ActionSegment<float>(StoredActions.ContinuousActions.Array,
                continuousOffset,
                continuousCount)
            : ActionSegment<float>.Empty;

        var discreteSlice = discreteCount > 0
            ? new ActionSegment<int>(StoredActions.DiscreteActions.Array,
                discreteOffset,
                discreteCount)
            : ActionSegment<int>.Empty;

        // Only actuators that opt in by implementing IHeuristic are called.
        if (current is IHeuristic heuristicSource)
        {
            heuristicSource.Heuristic(new ActionBuffers(continuousSlice, discreteSlice));
        }

        // Offsets advance for every actuator, whether or not it provided a heuristic,
        // so each actuator keeps its own window of the stored arrays.
        continuousOffset += continuousCount;
        discreteOffset += discreteCount;
    }
}

/// <summary>
/// Iterates through all of the IActuators in this list and calls their
/// <see cref="IActionReceiver.OnActionReceived"/> method on them with the appropriate
Expand Down
7 changes: 7 additions & 0 deletions com.unity.ml-agents/Runtime/Actuators/IHeuristic.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
namespace Unity.MLAgents.Actuators
{
/// <summary>
/// Contract for objects that can supply actions themselves instead of receiving them
/// from a trained model or trainer. It is consumed by <c>HeuristicPolicy</c> and by
/// <c>ActuatorManager.ExecuteHeuristic</c>, which calls it on any actuator that implements it.
/// </summary>
public interface IHeuristic
{
/// <summary>
/// Called to let the implementer choose its actions.
/// </summary>
/// <param name="actionBuffersOut">
/// The <see cref="ActionBuffers"/> handed to the implementer. Judging by the parameter name and
/// by <c>Agent.Heuristic</c>, implementers are expected to write their chosen actions into its
/// continuous/discrete segments. NOTE(review): confirm this is the intended contract for actuators too.
/// </param>
void Heuristic(in ActionBuffers actionBuffersOut);
}
}
3 changes: 3 additions & 0 deletions com.unity.ml-agents/Runtime/Actuators/IHeuristic.cs.meta

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions com.unity.ml-agents/Runtime/Actuators/VectorActuator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ internal ActionBuffers ActionBuffers
/// <param name="actionSpec"></param>
/// <param name="name"></param>
public VectorActuator(IActionReceiver actionReceiver,

ActionSpec actionSpec,
string name = "VectorActuator")
{
Expand Down
40 changes: 30 additions & 10 deletions com.unity.ml-agents/Runtime/Agent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ public void CopyActions(ActionBuffers actionBuffers)
"docs/Learning-Environment-Design-Agents.md")]
[Serializable]
[RequireComponent(typeof(BehaviorParameters))]
public partial class Agent : MonoBehaviour, ISerializationCallbackReceiver, IActionReceiver
public partial class Agent : MonoBehaviour, ISerializationCallbackReceiver, IActionReceiver, IHeuristic
{
IPolicy m_Brain;
BehaviorParameters m_PolicyFactory;
Expand Down Expand Up @@ -312,6 +312,11 @@ internal struct AgentParameters
/// </summary>
float[] m_LegacyActionCache;

/// <summary>
/// This is used to avoid allocation of a float array during legacy calls to Heuristic.
/// </summary>
float[] m_LegacyHeuristicCache;

/// <summary>
/// Called when the attached [GameObject] becomes enabled and active.
/// [GameObject]: https://docs.unity3d.com/Manual/GameObjects.html
Expand Down Expand Up @@ -429,7 +434,7 @@ public void LazyInitialize()
InitializeActuators();
}

m_Brain = m_PolicyFactory.GeneratePolicy(m_ActuatorManager.GetCombinedActionSpec(), Heuristic);
m_Brain = m_PolicyFactory.GeneratePolicy(m_ActuatorManager.GetCombinedActionSpec(), this);
ResetData();
Initialize();

Expand Down Expand Up @@ -606,7 +611,7 @@ internal void ReloadPolicy()
return;
}
m_Brain?.Dispose();
m_Brain = m_PolicyFactory.GeneratePolicy(m_ActuatorManager.GetCombinedActionSpec(), Heuristic);
m_Brain = m_PolicyFactory.GeneratePolicy(m_ActuatorManager.GetCombinedActionSpec(), this);
}

/// <summary>
Expand Down Expand Up @@ -889,27 +894,30 @@ public virtual void Heuristic(in ActionBuffers actionsOut)
// Disable deprecation warnings so we can call the legacy overload.
#pragma warning disable CS0618

Array.Clear(m_LegacyHeuristicCache, 0, m_LegacyHeuristicCache.Length);
surfnerd marked this conversation as resolved.
Show resolved Hide resolved
// The default implementation of Heuristic calls the
// obsolete version for backward compatibility
switch (m_PolicyFactory.BrainParameters.VectorActionSpaceType)
{
case SpaceType.Continuous:
Heuristic(actionsOut.ContinuousActions.Array);
Heuristic(m_LegacyHeuristicCache);
Array.Copy(m_LegacyHeuristicCache, actionsOut.ContinuousActions.Array, m_LegacyActionCache.Length);
actionsOut.DiscreteActions.Clear();
break;
case SpaceType.Discrete:
var convertedOut = Array.ConvertAll(actionsOut.DiscreteActions.Array, x => (float)x);
Heuristic(convertedOut);
Heuristic(m_LegacyHeuristicCache);
var discreteActionSegment = actionsOut.DiscreteActions;
for (var i = 0; i < actionsOut.DiscreteActions.Length; i++)
{
discreteActionSegment[i] = (int)convertedOut[i];
discreteActionSegment[i] = (int)m_LegacyHeuristicCache[i];
}
actionsOut.ContinuousActions.Clear();
break;
}
#pragma warning restore CS0618

// Send heuristic buffers to actuators if they implement IHeuristic
m_ActuatorManager.ExecuteHeuristic();
surfnerd marked this conversation as resolved.
Show resolved Hide resolved
}

/// <summary>
Expand Down Expand Up @@ -995,7 +1003,12 @@ void InitializeActuators()
var param = m_PolicyFactory.BrainParameters;
m_VectorActuator = new VectorActuator(this, param.ActionSpec);
m_ActuatorManager = new ActuatorManager(attachedActuators.Length + 1);
m_LegacyActionCache = new float[m_VectorActuator.TotalNumberOfActions()];
#pragma warning disable 618
m_LegacyActionCache = new float[param.VectorActionSpaceType == SpaceType.Continuous
surfnerd marked this conversation as resolved.
Show resolved Hide resolved
? param.ActionSpec.NumContinuousActions
: param.ActionSpec.NumDiscreteActions];
m_LegacyHeuristicCache = new float[m_LegacyActionCache.Length];
#pragma warning restore 618

m_ActuatorManager.Add(m_VectorActuator);

Expand Down Expand Up @@ -1241,11 +1254,18 @@ public virtual void OnActionReceived(ActionBuffers actions)

if (!actions.ContinuousActions.IsEmpty())
{
m_LegacyActionCache = actions.ContinuousActions.Array;
Array.Copy(actions.ContinuousActions.Array,
m_LegacyActionCache,
actionSpec.NumContinuousActions);
}
else
{
m_LegacyActionCache = Array.ConvertAll(actions.DiscreteActions.Array, x => (float)x);
var discreteArray = Array.ConvertAll(actions.DiscreteActions.Array, x => (float)x);
surfnerd marked this conversation as resolved.
Show resolved Hide resolved
Array.Copy(discreteArray,
0,
m_LegacyActionCache,
actionSpec.NumContinuousActions,
actionSpec.NumDiscreteActions);
}
// Disable deprecation warnings so we can call the legacy overload.
#pragma warning disable CS0618
Expand Down
3 changes: 1 addition & 2 deletions com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ public string FullyQualifiedBehaviorName
get { return m_BehaviorName + "?team=" + TeamId; }
}

internal IPolicy GeneratePolicy(ActionSpec actionSpec, HeuristicPolicy.ActionGenerator heuristic)
internal IPolicy GeneratePolicy(ActionSpec actionSpec, IHeuristic heuristic)
{
switch (m_BehaviorType)
{
Expand Down Expand Up @@ -241,6 +241,5 @@ internal void UpdateAgentPolicy()
}
agent.ReloadPolicy();
}

}
}
7 changes: 3 additions & 4 deletions com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@ namespace Unity.MLAgents.Policies
/// </summary>
internal class HeuristicPolicy : IPolicy
{
public delegate void ActionGenerator(in ActionBuffers actionBuffers);
ActionGenerator m_Heuristic;
IHeuristic m_Heuristic;
ActionBuffers m_ActionBuffers;
bool m_Done;
bool m_DecisionRequested;
Expand All @@ -24,7 +23,7 @@ internal class HeuristicPolicy : IPolicy


/// <inheritdoc />
public HeuristicPolicy(ActionGenerator heuristic, ActionSpec actionSpec)
public HeuristicPolicy(IHeuristic heuristic, ActionSpec actionSpec)
{
m_Heuristic = heuristic;
var numContinuousActions = actionSpec.NumContinuousActions;
Expand All @@ -47,7 +46,7 @@ public ref readonly ActionBuffers DecideAction()
{
if (!m_Done && m_DecisionRequested)
{
m_Heuristic.Invoke(m_ActionBuffers);
m_Heuristic.Heuristic(m_ActionBuffers);
}
m_DecisionRequested = false;
return ref m_ActionBuffers;
Expand Down
6 changes: 3 additions & 3 deletions com.unity.ml-agents/Tests/Editor/BehaviorParameterTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
namespace Unity.MLAgents.Tests
{
[TestFixture]
public class BehaviorParameterTests
public class BehaviorParameterTests : IHeuristic
{
static void DummyHeuristic(in ActionBuffers actionsOut)
public void Heuristic(in ActionBuffers actionsOut)
{
// No-op
}
Expand All @@ -23,7 +23,7 @@ public void TestNoModelInferenceOnlyThrows()

Assert.Throws<UnityAgentsException>(() =>
{
bp.GeneratePolicy(actionSpec, DummyHeuristic);
bp.GeneratePolicy(actionSpec, this);
});
}
}
Expand Down