diff --git a/com.unity.ml-agents/CHANGELOG.md b/com.unity.ml-agents/CHANGELOG.md index 4c5aafd9a9..265f2bed2f 100755 --- a/com.unity.ml-agents/CHANGELOG.md +++ b/com.unity.ml-agents/CHANGELOG.md @@ -47,6 +47,8 @@ different sizes using the same model. For a summary of the interface changes, pl depend on the previous behavior, you can explicitly set the Agent's `InferenceDevice` to `InferenceDevice.CPU`. (#5175) - Added support for `Goal Signal` as a type of observation. Trainers can now use HyperNetworks to process `Goal Signal`. Trainers with HyperNetworks are more effective at solving multiple tasks. (#5142, #5159, #5149) - Modified the [GridWorld environment](https://github.com/Unity-Technologies/ml-agents/blob/main/docs/Learning-Environment-Examples.md#gridworld) to use the new `Goal Signal` feature. (#5193) +- `DecisionRequester.ShouldRequestDecision()` and `ShouldRequestAction()`methods were added. These are used to +determine whether `Agent.RequestDecision()` and `Agent.RequestAction()` are called (respectively). (#5223) - `RaycastPerceptionSensor` now caches its raycast results; they can be accessed via `RayPerceptionSensor.RayPerceptionOutput`. (#5222) #### ml-agents / ml-agents-envs / gym-unity (Python) diff --git a/com.unity.ml-agents/Runtime/DecisionRequester.cs b/com.unity.ml-agents/Runtime/DecisionRequester.cs index fc7cc55afd..50a489010c 100644 --- a/com.unity.ml-agents/Runtime/DecisionRequester.cs +++ b/com.unity.ml-agents/Runtime/DecisionRequester.cs @@ -42,6 +42,14 @@ public class DecisionRequester : MonoBehaviour [NonSerialized] Agent m_Agent; + /// + /// Get the Agent attached to the DecisionRequester. + /// + public Agent Agent + { + get => m_Agent; + } + internal void Awake() { m_Agent = gameObject.GetComponent(); @@ -57,6 +65,17 @@ void OnDestroy() } } + /// + /// Information about Academy step used to make decisions about whether to request a decision. + /// + public struct DecisionRequestContext + { + /// + /// The current step count of the Academy, equivalent to Academy.StepCount. + /// + public int AcademyStepCount; + } + /// /// Method that hooks into the Academy in order inform the Agent on whether or not it should request a /// decision, and whether or not it should take actions between decisions. @@ -64,14 +83,40 @@ void OnDestroy() /// The current step count of the academy. void MakeRequests(int academyStepCount) { - if (academyStepCount % DecisionPeriod == 0) + var context = new DecisionRequestContext + { + AcademyStepCount = academyStepCount + }; + + if (ShouldRequestDecision(context)) { m_Agent?.RequestDecision(); } - if (TakeActionsBetweenDecisions) + + if (ShouldRequestAction(context)) { m_Agent?.RequestAction(); } } + + /// + /// Whether Agent.RequestDecision should be called on this update step. + /// + /// + /// + protected virtual bool ShouldRequestDecision(DecisionRequestContext context) + { + return context.AcademyStepCount % DecisionPeriod == 0; + } + + /// + /// Whether Agent.RequestAction should be called on this update step. + /// + /// + /// + protected virtual bool ShouldRequestAction(DecisionRequestContext context) + { + return TakeActionsBetweenDecisions; + } } } diff --git a/com.unity.ml-agents/Tests/Editor/PublicAPI/PublicApiValidation.cs b/com.unity.ml-agents/Tests/Editor/PublicAPI/PublicApiValidation.cs index 3633fbf011..f3635b64a7 100644 --- a/com.unity.ml-agents/Tests/Editor/PublicAPI/PublicApiValidation.cs +++ b/com.unity.ml-agents/Tests/Editor/PublicAPI/PublicApiValidation.cs @@ -1,6 +1,7 @@ using System.Collections.Generic; using Unity.MLAgents.Sensors; using NUnit.Framework; +using Unity.MLAgents; using UnityEngine; namespace Unity.MLAgentsExamples @@ -76,5 +77,25 @@ public void CheckSetupRayPerceptionSensorComponent() Assert.AreEqual(outputs.RayOutputs.Length, 2*sensorComponent.RaysPerDirection + 1); } #endif + + /// + /// Make sure we can inherit from DecisionRequester and override some logic. + /// + class CustomDecisionRequester : DecisionRequester + { + /// + /// Example logic. If the killswitch flag is set, the Agent never requests a decision. + /// + public bool KillswitchEnabled; + + public CustomDecisionRequester() + { + } + + protected override bool ShouldRequestDecision(DecisionRequestContext context) + { + return !KillswitchEnabled && base.ShouldRequestDecision(context); + } + } } }