[MLA-1159] Add virtual methods to DecisionRequester #5223

Merged: 5 commits, Apr 6, 2021
2 changes: 2 additions & 0 deletions com.unity.ml-agents/CHANGELOG.md
@@ -47,6 +47,8 @@ different sizes using the same model. For a summary of the interface changes, pl
depend on the previous behavior, you can explicitly set the Agent's `InferenceDevice` to `InferenceDevice.CPU`. (#5175)
- Added support for `Goal Signal` as a type of observation. Trainers can now use HyperNetworks to process `Goal Signal`. Trainers with HyperNetworks are more effective at solving multiple tasks. (#5142, #5159, #5149)
- Modified the [GridWorld environment](https://github.com/Unity-Technologies/ml-agents/blob/main/docs/Learning-Environment-Examples.md#gridworld) to use the new `Goal Signal` feature. (#5193)
- `DecisionRequester.ShouldRequestDecision()` and `ShouldRequestAction()` methods were added. These are used to
determine whether `Agent.RequestDecision()` and `Agent.RequestAction()` are called (respectively). (#5223)
- `RaycastPerceptionSensor` now caches its raycast results; they can be accessed via `RayPerceptionSensor.RayPerceptionOutput`. (#5222)

#### ml-agents / ml-agents-envs / gym-unity (Python)
49 changes: 47 additions & 2 deletions com.unity.ml-agents/Runtime/DecisionRequester.cs
@@ -42,6 +42,14 @@ public class DecisionRequester : MonoBehaviour
[NonSerialized]
Agent m_Agent;

/// <summary>
/// Get the Agent attached to the DecisionRequester.
/// </summary>
public Agent Agent
{
get => m_Agent;
}

internal void Awake()
{
m_Agent = gameObject.GetComponent<Agent>();
@@ -57,21 +65,58 @@ void OnDestroy()
}
}

/// <summary>
/// Information about Academy step used to make decisions about whether to request a decision.
/// </summary>
public struct DecisionRequestContext
{
/// <summary>
/// The current step count of the Academy, equivalent to Academy.StepCount.
/// </summary>
public int AcademyStepCount;
}

/// <summary>
/// Method that hooks into the Academy in order to inform the Agent on whether or not it should request a
/// decision, and whether or not it should take actions between decisions.
/// </summary>
/// <param name="academyStepCount">The current step count of the academy.</param>
void MakeRequests(int academyStepCount)
{
- if (academyStepCount % DecisionPeriod == 0)
+ var context = new DecisionRequestContext
+ {
+ AcademyStepCount = academyStepCount
+ };
+
+ if (ShouldRequestDecision(context))
{
m_Agent?.RequestDecision();
}
- if (TakeActionsBetweenDecisions)
+
+ if (ShouldRequestAction(context))
{
m_Agent?.RequestAction();
}
}

/// <summary>
/// Whether Agent.RequestDecision should be called on this update step.
/// </summary>
/// <param name="context"></param>
/// <returns></returns>
protected virtual bool ShouldRequestDecision(DecisionRequestContext context)
[Review comment, Contributor Author] The DecisionRequestContext is maybe a little over-engineered, but it'll let us extend in the future if needed without changing the signature.
{
return context.AcademyStepCount % DecisionPeriod == 0;
}

/// <summary>
/// Whether Agent.RequestAction should be called on this update step.
/// </summary>
/// <param name="context"></param>
/// <returns></returns>
protected virtual bool ShouldRequestAction(DecisionRequestContext context)
{
return TakeActionsBetweenDecisions;
}
}
}
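For readers who want to see how the two hooks interact without opening Unity, here is a minimal sketch using stand-in types rather than the ML-Agents classes. `RequesterModel`, `EvenStepActionRequester`, and `RequesterDemo.Simulate` are hypothetical names introduced for illustration; only the `DecisionRequestContext` field and the two hook signatures are taken from the diff above. Each step is logged as "D" (a decision would be requested), "A" (an action would be requested), both, or neither.

```csharp
using System.Collections.Generic;

// Stand-in for DecisionRequester.DecisionRequestContext (same single field as the diff).
public struct DecisionRequestContext
{
    public int AcademyStepCount;
}

// Hypothetical, Unity-free model of the request logic in DecisionRequester.
public class RequesterModel
{
    public int DecisionPeriod = 5;
    public bool TakeActionsBetweenDecisions = true;

    // Mirrors MakeRequests: build the context, then consult the two hooks.
    public string MakeRequests(int academyStepCount)
    {
        var context = new DecisionRequestContext { AcademyStepCount = academyStepCount };
        var result = "";
        if (ShouldRequestDecision(context)) result += "D";
        if (ShouldRequestAction(context)) result += "A";
        return result;
    }

    protected virtual bool ShouldRequestDecision(DecisionRequestContext context)
    {
        return context.AcademyStepCount % DecisionPeriod == 0;
    }

    protected virtual bool ShouldRequestAction(DecisionRequestContext context)
    {
        return TakeActionsBetweenDecisions;
    }
}

// Hypothetical subclass: keep the default rule, but only act on even steps.
public class EvenStepActionRequester : RequesterModel
{
    protected override bool ShouldRequestAction(DecisionRequestContext context)
    {
        return base.ShouldRequestAction(context) && context.AcademyStepCount % 2 == 0;
    }
}

public static class RequesterDemo
{
    // Runs the requester for `steps` simulated Academy steps and logs each outcome.
    public static List<string> Simulate(RequesterModel requester, int steps)
    {
        var log = new List<string>();
        for (var step = 0; step < steps; step++)
        {
            log.Add(requester.MakeRequests(step));
        }
        return log;
    }
}
```

Calling `RequesterDemo.Simulate(new EvenStepActionRequester { DecisionPeriod = 3 }, 4)` exercises both hooks independently: the decision rule is inherited, while the action rule is narrowed by the override. Because the hooks take a struct rather than a bare `int`, a field added to the context later would not change either override's signature.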
21 changes: 21 additions & 0 deletions com.unity.ml-agents/Tests/Editor/PublicAPI/PublicApiValidation.cs
@@ -1,6 +1,7 @@
using System.Collections.Generic;
using Unity.MLAgents.Sensors;
using NUnit.Framework;
using Unity.MLAgents;
using UnityEngine;

namespace Unity.MLAgentsExamples
@@ -76,5 +77,25 @@ public void CheckSetupRayPerceptionSensorComponent()
Assert.AreEqual(outputs.RayOutputs.Length, 2*sensorComponent.RaysPerDirection + 1);
}
#endif

/// <summary>
/// Make sure we can inherit from DecisionRequester and override some logic.
/// </summary>
class CustomDecisionRequester : DecisionRequester
{
/// <summary>
/// Example logic. If the killswitch flag is set, the Agent never requests a decision.
/// </summary>
public bool KillswitchEnabled;

public CustomDecisionRequester()
{
}

protected override bool ShouldRequestDecision(DecisionRequestContext context)
{
return !KillswitchEnabled && base.ShouldRequestDecision(context);
}
}
}
}
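The test class above only checks that the override compiles. As a rough illustration of the behavior it implies, the sketch below steps a kill-switch requester through a few simulated Academy steps. It assumes Unity-free stand-ins: `RequesterStandIn`, `KillswitchRequester`, and `KillswitchDemo` are hypothetical names, and only the override body matches `CustomDecisionRequester`.

```csharp
using System.Collections.Generic;

// Unity-free stand-ins; only the hook signature mirrors the diff.
public struct DecisionRequestContext
{
    public int AcademyStepCount;
}

public class RequesterStandIn
{
    public int DecisionPeriod = 5;

    // Returns true when a decision would be requested at this step.
    public bool Step(int academyStepCount)
    {
        return ShouldRequestDecision(
            new DecisionRequestContext { AcademyStepCount = academyStepCount });
    }

    protected virtual bool ShouldRequestDecision(DecisionRequestContext context)
    {
        return context.AcademyStepCount % DecisionPeriod == 0;
    }
}

// Same override logic as CustomDecisionRequester in the test above.
public class KillswitchRequester : RequesterStandIn
{
    public bool KillswitchEnabled;

    protected override bool ShouldRequestDecision(DecisionRequestContext context)
    {
        return !KillswitchEnabled && base.ShouldRequestDecision(context);
    }
}

public static class KillswitchDemo
{
    // Collects the steps at which a decision would be requested,
    // switching the kill switch on when `killAtStep` is reached.
    public static List<int> DecisionSteps(int steps, int killAtStep)
    {
        var requester = new KillswitchRequester { DecisionPeriod = 2 };
        var decided = new List<int>();
        for (var step = 0; step < steps; step++)
        {
            if (step == killAtStep) requester.KillswitchEnabled = true;
            if (requester.Step(step)) decided.Add(step);
        }
        return decided;
    }
}
```

With a decision period of 2, `KillswitchDemo.DecisionSteps(8, 5)` collects steps 0, 2, and 4; once the flag is set at step 5, the base rule is never consulted again, so step 6 is skipped.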