Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions com.unity.ml-agents/Runtime/Policies/BarracudaPolicy.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Unity.Barracuda;
using System.Collections.Generic;
using System.Diagnostics;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Inference;
using Unity.MLAgents.Sensors;
Expand Down Expand Up @@ -59,7 +60,14 @@ internal class BarracudaPolicy : IPolicy
/// </summary>
private bool m_AnalyticsSent;

/// <inheritdoc />
/// <summary>
/// Instantiate a BarracudaPolicy with the necessary objects for it to run.
/// </summary>
/// <param name="actionSpec">The action spec of the behavior.</param>
/// <param name="actuators">The actuators used for this behavior.</param>
/// <param name="model">The Neural Network to use.</param>
/// <param name="inferenceDevice">Which device Barracuda will run on.</param>
/// <param name="behaviorName">The name of the behavior.</param>
public BarracudaPolicy(
ActionSpec actionSpec,
IList<IActuator> actuators,
Expand All @@ -77,6 +85,14 @@ string behaviorName

/// <inheritdoc />
public void RequestDecision(AgentInfo info, List<ISensor> sensors)
{
SendAnalytics(sensors);
m_AgentId = info.episodeId;
m_ModelRunner?.PutObservations(info, sensors);
}

[Conditional("MLA_UNITY_ANALYTICS_MODULE")]
void SendAnalytics(IList<ISensor> sensors)
{
if (!m_AnalyticsSent)
{
Expand All @@ -90,8 +106,6 @@ public void RequestDecision(AgentInfo info, List<ISensor> sensors)
m_Actuators
);
}
m_AgentId = info.episodeId;
m_ModelRunner?.PutObservations(info, sensors);
}

/// <inheritdoc />
Expand Down
14 changes: 10 additions & 4 deletions com.unity.ml-agents/Runtime/Policies/RemotePolicy.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@

using System.Collections.Generic;
using System.Diagnostics;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Analytics;
Expand All @@ -17,7 +17,7 @@ internal class RemotePolicy : IPolicy
string m_FullyQualifiedBehaviorName;
ActionSpec m_ActionSpec;
ActionBuffers m_LastActionBuffer;
private bool m_AnalyticsSent = false;
bool m_AnalyticsSent;

internal ICommunicator m_Communicator;

Expand All @@ -41,6 +41,14 @@ public RemotePolicy(

/// <inheritdoc />
public void RequestDecision(AgentInfo info, List<ISensor> sensors)
{
SendAnalytics(sensors);
m_AgentId = info.episodeId;
m_Communicator?.PutObservations(m_FullyQualifiedBehaviorName, info, sensors);
}

[Conditional("MLA_UNITY_ANALYTICS_MODULE")]
void SendAnalytics(IList<ISensor> sensors)
{
if (!m_AnalyticsSent)
{
Expand All @@ -52,8 +60,6 @@ public void RequestDecision(AgentInfo info, List<ISensor> sensors)
m_Actuators
);
}
m_AgentId = info.episodeId;
m_Communicator?.PutObservations(m_FullyQualifiedBehaviorName, info, sensors);
}

/// <inheritdoc />
Expand Down