diff --git a/com.unity.ml-agents/CHANGELOG.md b/com.unity.ml-agents/CHANGELOG.md
index 4c5aafd9a9..265f2bed2f 100755
--- a/com.unity.ml-agents/CHANGELOG.md
+++ b/com.unity.ml-agents/CHANGELOG.md
@@ -47,6 +47,8 @@ different sizes using the same model. For a summary of the interface changes, pl
depend on the previous behavior, you can explicitly set the Agent's `InferenceDevice` to `InferenceDevice.CPU`. (#5175)
- Added support for `Goal Signal` as a type of observation. Trainers can now use HyperNetworks to process `Goal Signal`. Trainers with HyperNetworks are more effective at solving multiple tasks. (#5142, #5159, #5149)
- Modified the [GridWorld environment](https://github.com/Unity-Technologies/ml-agents/blob/main/docs/Learning-Environment-Examples.md#gridworld) to use the new `Goal Signal` feature. (#5193)
+- `DecisionRequester.ShouldRequestDecision()` and `ShouldRequestAction()`methods were added. These are used to
+determine whether `Agent.RequestDecision()` and `Agent.RequestAction()` are called (respectively). (#5223)
- `RaycastPerceptionSensor` now caches its raycast results; they can be accessed via `RayPerceptionSensor.RayPerceptionOutput`. (#5222)
#### ml-agents / ml-agents-envs / gym-unity (Python)
diff --git a/com.unity.ml-agents/Runtime/DecisionRequester.cs b/com.unity.ml-agents/Runtime/DecisionRequester.cs
index fc7cc55afd..50a489010c 100644
--- a/com.unity.ml-agents/Runtime/DecisionRequester.cs
+++ b/com.unity.ml-agents/Runtime/DecisionRequester.cs
@@ -42,6 +42,14 @@ public class DecisionRequester : MonoBehaviour
[NonSerialized]
Agent m_Agent;
+ ///
+ /// Get the Agent attached to the DecisionRequester.
+ ///
+ public Agent Agent
+ {
+ get => m_Agent;
+ }
+
internal void Awake()
{
m_Agent = gameObject.GetComponent();
@@ -57,6 +65,17 @@ void OnDestroy()
}
}
+ ///
+ /// Information about Academy step used to make decisions about whether to request a decision.
+ ///
+ public struct DecisionRequestContext
+ {
+ ///
+ /// The current step count of the Academy, equivalent to Academy.StepCount.
+ ///
+ public int AcademyStepCount;
+ }
+
///
/// Method that hooks into the Academy in order inform the Agent on whether or not it should request a
/// decision, and whether or not it should take actions between decisions.
@@ -64,14 +83,40 @@ void OnDestroy()
/// The current step count of the academy.
void MakeRequests(int academyStepCount)
{
- if (academyStepCount % DecisionPeriod == 0)
+ var context = new DecisionRequestContext
+ {
+ AcademyStepCount = academyStepCount
+ };
+
+ if (ShouldRequestDecision(context))
{
m_Agent?.RequestDecision();
}
- if (TakeActionsBetweenDecisions)
+
+ if (ShouldRequestAction(context))
{
m_Agent?.RequestAction();
}
}
+
+ ///
+ /// Whether Agent.RequestDecision should be called on this update step.
+ ///
+ ///
+ ///
+ protected virtual bool ShouldRequestDecision(DecisionRequestContext context)
+ {
+ return context.AcademyStepCount % DecisionPeriod == 0;
+ }
+
+ ///
+ /// Whether Agent.RequestAction should be called on this update step.
+ ///
+ ///
+ ///
+ protected virtual bool ShouldRequestAction(DecisionRequestContext context)
+ {
+ return TakeActionsBetweenDecisions;
+ }
}
}
diff --git a/com.unity.ml-agents/Tests/Editor/PublicAPI/PublicApiValidation.cs b/com.unity.ml-agents/Tests/Editor/PublicAPI/PublicApiValidation.cs
index 3633fbf011..f3635b64a7 100644
--- a/com.unity.ml-agents/Tests/Editor/PublicAPI/PublicApiValidation.cs
+++ b/com.unity.ml-agents/Tests/Editor/PublicAPI/PublicApiValidation.cs
@@ -1,6 +1,7 @@
using System.Collections.Generic;
using Unity.MLAgents.Sensors;
using NUnit.Framework;
+using Unity.MLAgents;
using UnityEngine;
namespace Unity.MLAgentsExamples
@@ -76,5 +77,25 @@ public void CheckSetupRayPerceptionSensorComponent()
Assert.AreEqual(outputs.RayOutputs.Length, 2*sensorComponent.RaysPerDirection + 1);
}
#endif
+
+ ///
+ /// Make sure we can inherit from DecisionRequester and override some logic.
+ ///
+ class CustomDecisionRequester : DecisionRequester
+ {
+ ///
+ /// Example logic. If the killswitch flag is set, the Agent never requests a decision.
+ ///
+ public bool KillswitchEnabled;
+
+ public CustomDecisionRequester()
+ {
+ }
+
+ protected override bool ShouldRequestDecision(DecisionRequestContext context)
+ {
+ return !KillswitchEnabled && base.ShouldRequestDecision(context);
+ }
+ }
}
}