Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make the Agent reset immediately after Done #3291

Merged
merged 12 commits into from
Jan 28, 2020
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ public void TestStoreInitalize()
reward = 1f,
actionMasks = new[] { false, true },
done = true,
id = 5,
episodeId = 5,
maxStepReached = true,
storedVectorActions = new[] { 0f, 1f },
};
Expand Down
47 changes: 16 additions & 31 deletions UnitySDK/Assets/ML-Agents/Editor/Tests/MLAgentsEditModeTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,6 @@ public AgentInfo _Info
}
}

public bool IsDone()
{
return (bool)typeof(Agent).GetField("m_Done", BindingFlags.Instance | BindingFlags.NonPublic).GetValue(this);
}
public int initializeAgentCalls;
public int collectObservationsCalls;
public int agentActionCalls;
Expand Down Expand Up @@ -191,8 +187,6 @@ public void TestAgent()
agentGo2.AddComponent<TestAgent>();
var agent2 = agentGo2.GetComponent<TestAgent>();

Assert.AreEqual(false, agent1.IsDone());
Assert.AreEqual(false, agent2.IsDone());
Assert.AreEqual(0, agent1.agentResetCalls);
Assert.AreEqual(0, agent2.agentResetCalls);
Assert.AreEqual(0, agent1.initializeAgentCalls);
Expand All @@ -206,8 +200,6 @@ public void TestAgent()
agentEnableMethod?.Invoke(agent2, new object[] { });
agentEnableMethod?.Invoke(agent1, new object[] { });

Assert.AreEqual(false, agent1.IsDone());
Assert.AreEqual(false, agent2.IsDone());
// agent1 was not enabled when the academy started
// The agents have been initialized
Assert.AreEqual(0, agent1.agentResetCalls);
Expand Down Expand Up @@ -422,18 +414,14 @@ public void TestAgent()
if (i % 11 == 5)
{
agent1.Done();
numberAgent1Reset += 1;
}
// Resetting agent 2 regularly
if (i % 13 == 3)
{
if (!(agent2.IsDone()))
{
// If the agent was already reset before the request decision
// We should not reset again
agent2.Done();
numberAgent2Reset += 1;
agent2StepSinceReset = 0;
}
agent2.Done();
numberAgent2Reset += 1;
agent2StepSinceReset = 0;
}
// Request a decision for agent 2 regularly
if (i % 3 == 2)
Expand All @@ -445,16 +433,9 @@ public void TestAgent()
// Request an action without decision regularly
agent2.RequestAction();
}
if (agent1.IsDone())
{
numberAgent1Reset += 1;
}

acaStepsSinceReset += 1;
agent2StepSinceReset += 1;
//Agent 1 is only initialized at step 2
if (i < 2)
{ }
aca.EnvironmentStep();
}
}
Expand Down Expand Up @@ -500,19 +481,23 @@ public void TestCumulativeReward()
var j = 0;
for (var i = 0; i < 500; i++)
{
if (i % 21 == 0)
{
j = 0;
}
else
{
j++;
}
agent2.RequestAction();
Assert.LessOrEqual(Mathf.Abs(j * 0.1f + j * 10f - agent1.GetCumulativeReward()), 0.05f);
Assert.LessOrEqual(Mathf.Abs(j * 10.1f - agent1.GetCumulativeReward()), 0.05f);
Assert.LessOrEqual(Mathf.Abs(i * 0.1f - agent2.GetCumulativeReward()), 0.05f);


aca.EnvironmentStep();
agent1.AddReward(10f);
aca.EnvironmentStep();



if ((i % 21 == 0) && (i > 0))
{
j = 0;
}
j++;
}
}
}
Expand Down
10 changes: 0 additions & 10 deletions UnitySDK/Assets/ML-Agents/Scripts/Academy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,6 @@ public bool IsCommunicatorOn
// in addition to aligning on the step count of the global episode.
public event System.Action<int> AgentSetStatus;

// Signals to all the agents at each environment step so they can reset
// if their flag has been set to done (assuming the agent has requested a
// decision).
public event System.Action AgentResetIfDone;

// Signals to all the agents at each environment step so they can send
// their state to their Policy if they have requested a decision.
public event System.Action AgentSendState;
Expand Down Expand Up @@ -314,7 +309,6 @@ void ResetActions()
DecideAction = () => { };
DestroyAction = () => { };
AgentSetStatus = i => { };
AgentResetIfDone = () => { };
AgentSendState = () => { };
AgentAct = () => { };
AgentForceReset = () => { };
Expand Down Expand Up @@ -392,10 +386,6 @@ public void EnvironmentStep()

AgentSetStatus?.Invoke(m_StepCount);

using (TimerStack.Instance.Scoped("AgentResetIfDone"))
{
AgentResetIfDone?.Invoke();
}

using (TimerStack.Instance.Scoped("AgentSendState"))
{
Expand Down
92 changes: 29 additions & 63 deletions UnitySDK/Assets/ML-Agents/Scripts/Agent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public struct AgentInfo
/// Unique identifier each agent receives at initialization. It is used
/// to separate between different agents in the environment.
/// </summary>
public int id;
public int episodeId;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update this comment too.

}

/// <summary>
Expand Down Expand Up @@ -148,13 +148,6 @@ public abstract class Agent : MonoBehaviour
/// Whether or not the agent requests a decision.
bool m_RequestDecision;

/// Whether or not the agent has completed the episode. This may be due
/// to either reaching a success or fail state, or reaching the maximum
/// number of steps (i.e. timing out).
bool m_Done;

/// Whether or not the agent reached the maximum number of steps.
bool m_MaxStepReached;

/// Keeps track of the number of steps taken by the agent in this episode.
/// Note that this value is different for each agent, and may not overlap
Expand All @@ -164,7 +157,7 @@ public abstract class Agent : MonoBehaviour

/// Unique identifier each agent receives at initialization. It is used
/// to separate between different agents in the environment.
int m_Id;
int m_EpisodeId;
vincentpierre marked this conversation as resolved.
Show resolved Hide resolved

/// Keeps track of the actions that are masked at each step.
ActionMasker m_ActionMasker;
Expand All @@ -190,7 +183,7 @@ public abstract class Agent : MonoBehaviour
/// becomes enabled or active.
void OnEnable()
{
m_Id = gameObject.GetInstanceID();
m_EpisodeId = EpisodeIdCounter.GetEpisodeId();
OnEnableHelper();

m_Recorder = GetComponent<DemonstrationRecorder>();
Expand All @@ -204,7 +197,6 @@ void OnEnableHelper()
m_Action = new AgentAction();
sensors = new List<ISensor>();

Academy.Instance.AgentResetIfDone += ResetIfDone;
Academy.Instance.AgentSendState += SendInfo;
Academy.Instance.DecideAction += DecideAction;
Academy.Instance.AgentAct += AgentStep;
Expand All @@ -224,7 +216,6 @@ void OnDisable()
// We don't want to even try, because this will lazily create a new Academy!
if (Academy.IsInitialized)
{
Academy.Instance.AgentResetIfDone -= ResetIfDone;
Academy.Instance.AgentSendState -= SendInfo;
Academy.Instance.DecideAction -= DecideAction;
Academy.Instance.AgentAct -= AgentStep;
Expand All @@ -234,12 +225,14 @@ void OnDisable()
m_Brain?.Dispose();
}

void NotifyAgentDone()
void NotifyAgentDone(bool maxStepReached = false)
{
m_Info.done = true;
m_Info.maxStepReached = maxStepReached;
// Request the last decision with no callbacks
// We request a decision so Python knows the Agent is disabled
m_Brain?.RequestDecision(m_Info, sensors, (a) => { });
m_EpisodeId = EpisodeIdCounter.GetEpisodeId();
vincentpierre marked this conversation as resolved.
Show resolved Hide resolved
}

/// <summary>
Expand Down Expand Up @@ -322,7 +315,12 @@ public float GetCumulativeReward()
/// </summary>
public void Done()
{
m_Done = true;
NotifyAgentDone();
_AgentReset();
m_RequestAction = false;
m_RequestDecision = false;
m_Reward = 0f;
m_CumulativeReward = 0f;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Feels like this could be moved into NotifyAgentDone() (or maybe combine Done and NotifyAgentDone, unless you don't want to the user to set maxStepReached)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved some things around

}

/// <summary>
Expand All @@ -342,28 +340,6 @@ public void RequestAction()
m_RequestAction = true;
}

/// <summary>
/// Indicates if the agent has reached his maximum number of steps.
/// </summary>
/// <returns>
/// <c>true</c>, if max step reached was reached, <c>false</c> otherwise.
/// </returns>
public bool IsMaxStepReached()
{
return m_MaxStepReached;
}

/// <summary>
/// Indicates if the agent is done
/// </summary>
/// <returns>
/// <c>true</c>, if the agent is done, <c>false</c> otherwise.
/// </returns>
public bool IsDone()
{
return m_Done;
}

/// Helper function that resets all the data structures associated with
/// the agent. Typically used when the agent is being initialized or reset
/// at the end of an episode.
Expand Down Expand Up @@ -489,9 +465,9 @@ void SendInfoToBrain()
m_Info.actionMasks = m_ActionMasker.GetMask();

m_Info.reward = m_Reward;
m_Info.done = m_Done;
m_Info.maxStepReached = m_MaxStepReached;
m_Info.id = m_Id;
m_Info.done = false;
m_Info.maxStepReached = false;
m_Info.episodeId = m_EpisodeId;

m_Brain.RequestDecision(m_Info, sensors, UpdateAgentAction);

Expand Down Expand Up @@ -742,51 +718,41 @@ protected float ScaleAction(float rawAction, float min, float max)
}


/// Signals the agent that it must reset if its done flag is set to true.
void ResetIfDone()
{
if (m_Done)
{
_AgentReset();
}
}

/// <summary>
/// Signals the agent that it must sent its decision to the brain.
/// </summary>
void SendInfo()
{
// If the Agent is done, it has just reset and thus requires a new decision
if (m_RequestDecision || m_Done)
if (m_RequestDecision)
{
SendInfoToBrain();
m_Reward = 0f;
if (m_Done)
{
m_CumulativeReward = 0f;
}
m_Done = false;
m_MaxStepReached = false;
m_RequestDecision = false;
}
}

/// Used by the brain to make the agent perform a step.
void AgentStep()
{
if ((m_RequestAction) && (m_Brain != null))
if ((m_StepCount >= maxStep - 1) && (maxStep > 0))
{
NotifyAgentDone(true);
_AgentReset();
m_RequestAction = false;
AgentAction(m_Action.vectorActions);
m_RequestDecision = false;
m_Reward = 0f;
m_CumulativeReward = 0f;
}

if ((m_StepCount >= maxStep) && (maxStep > 0))
else
{
m_MaxStepReached = true;
Done();
m_StepCount += 1;
}
if ((m_RequestAction) && (m_Brain != null))
{
m_RequestAction = false;
AgentAction(m_Action.vectorActions);
}

m_StepCount += 1;
}

void DecideAction()
Expand Down
11 changes: 11 additions & 0 deletions UnitySDK/Assets/ML-Agents/Scripts/EpisodeIdCounter.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
namespace MLAgents
{
public static class EpisodeIdCounter
{
private static int Counter;
public static int GetEpisodeId()
{
return Counter++;
}
}
}
11 changes: 11 additions & 0 deletions UnitySDK/Assets/ML-Agents/Scripts/EpisodeIdCounter.cs.meta

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion UnitySDK/Assets/ML-Agents/Scripts/Grpc/GrpcExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public static AgentInfoProto ToAgentInfoProto(this AgentInfo ai)
Reward = ai.reward,
MaxStepReached = ai.maxStepReached,
Done = ai.done,
Id = ai.id,
Id = ai.episodeId,
};

if (ai.actionMasks != null)
Expand Down
2 changes: 1 addition & 1 deletion UnitySDK/Assets/ML-Agents/Scripts/Grpc/RpcCommunicator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ public void PutObservations(string brainKey, AgentInfo info, List<ISensor> senso
{
m_ActionCallbacks[brainKey] = new List<IdCallbackPair>();
}
m_ActionCallbacks[brainKey].Add(new IdCallbackPair { AgentId = info.id, Callback = action });
m_ActionCallbacks[brainKey].Add(new IdCallbackPair { AgentId = info.episodeId, Callback = action });
}

/// <summary>
Expand Down
Loading