Skip to content

Commit

Permalink
Develop observation collector (#3352)
Browse files Browse the repository at this point in the history
* Add the VectorSensor to the CollectObservation call

* Example of API change for BalanceBall

* Modified the Examples

* Changes to the migrating doc

* Editing the docs

* Update docs/Learning-Environment-Design-Agents.md

Co-Authored-By: Chris Elion <chris.elion@unity3d.com>

* Update docs/Migrating.md

Co-Authored-By: Chris Elion <chris.elion@unity3d.com>

* Update docs/Migrating.md

Co-Authored-By: Chris Elion <chris.elion@unity3d.com>

* Update docs/Getting-Started-with-Balance-Ball.md

Co-Authored-By: Chris Elion <chris.elion@unity3d.com>

* addressing comments

* Removed the MLAgents.Sensor namespace

* Removing the MLAgents.Sensor namespace from the tests

* Editing the migrating docs

Co-authored-by: Chris Elion <celion@gmail.com>
  • Loading branch information
vincentpierre and chriselion authored Feb 7, 2020
1 parent ba10540 commit 6551974
Show file tree
Hide file tree
Showing 60 changed files with 162 additions and 255 deletions.
10 changes: 5 additions & 5 deletions Project/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ public override void InitializeAgent()
SetResetParameters();
}

public override void CollectObservations()
public override void CollectObservations(VectorSensor sensor)
{
AddVectorObs(gameObject.transform.rotation.z);
AddVectorObs(gameObject.transform.rotation.x);
AddVectorObs(ball.transform.position - gameObject.transform.position);
AddVectorObs(m_BallRb.velocity);
sensor.AddObservation(gameObject.transform.rotation.z);
sensor.AddObservation(gameObject.transform.rotation.x);
sensor.AddObservation(ball.transform.position - gameObject.transform.position);
sensor.AddObservation(m_BallRb.velocity);
}

public override void AgentAction(float[] vectorAction)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ public override void InitializeAgent()
SetResetParameters();
}

public override void CollectObservations()
public override void CollectObservations(VectorSensor sensor)
{
AddVectorObs(gameObject.transform.rotation.z);
AddVectorObs(gameObject.transform.rotation.x);
AddVectorObs((ball.transform.position - gameObject.transform.position));
sensor.AddObservation(gameObject.transform.rotation.z);
sensor.AddObservation(gameObject.transform.rotation.x);
sensor.AddObservation((ball.transform.position - gameObject.transform.position));
}

public override void AgentAction(float[] vectorAction)
Expand Down
4 changes: 2 additions & 2 deletions Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ public override void InitializeAgent()
{
}

public override void CollectObservations()
public override void CollectObservations(VectorSensor sensor)
{
AddVectorObs(m_Position, 20);
sensor.AddOneHotObservation(m_Position, 20);
}

public override void AgentAction(float[] vectorAction)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ public override void InitializeAgent()
SetResetParameters();
}

public override void CollectObservations()
public override void CollectObservations(VectorSensor sensor)
{
AddVectorObs(gameObject.transform.localPosition);
AddVectorObs(target.transform.localPosition);
sensor.AddObservation(gameObject.transform.localPosition);
sensor.AddObservation(target.transform.localPosition);
}

public override void AgentAction(float[] vectorAction)
Expand Down
30 changes: 15 additions & 15 deletions Project/Assets/ML-Agents/Examples/Crawler/Scripts/CrawlerAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -72,29 +72,29 @@ public override void InitializeAgent()
/// <summary>
/// Add relevant information on each body part to observations.
/// </summary>
public void CollectObservationBodyPart(BodyPart bp)
public void CollectObservationBodyPart(BodyPart bp, VectorSensor sensor)
{
var rb = bp.rb;
AddVectorObs(bp.groundContact.touchingGround ? 1 : 0); // Whether the bp touching the ground
sensor.AddObservation(bp.groundContact.touchingGround ? 1 : 0); // Whether the bp touching the ground

var velocityRelativeToLookRotationToTarget = m_TargetDirMatrix.inverse.MultiplyVector(rb.velocity);
AddVectorObs(velocityRelativeToLookRotationToTarget);
sensor.AddObservation(velocityRelativeToLookRotationToTarget);

var angularVelocityRelativeToLookRotationToTarget = m_TargetDirMatrix.inverse.MultiplyVector(rb.angularVelocity);
AddVectorObs(angularVelocityRelativeToLookRotationToTarget);
sensor.AddObservation(angularVelocityRelativeToLookRotationToTarget);

if (bp.rb.transform != body)
{
var localPosRelToBody = body.InverseTransformPoint(rb.position);
AddVectorObs(localPosRelToBody);
AddVectorObs(bp.currentXNormalizedRot); // Current x rot
AddVectorObs(bp.currentYNormalizedRot); // Current y rot
AddVectorObs(bp.currentZNormalizedRot); // Current z rot
AddVectorObs(bp.currentStrength / m_JdController.maxJointForceLimit);
sensor.AddObservation(localPosRelToBody);
sensor.AddObservation(bp.currentXNormalizedRot); // Current x rot
sensor.AddObservation(bp.currentYNormalizedRot); // Current y rot
sensor.AddObservation(bp.currentZNormalizedRot); // Current z rot
sensor.AddObservation(bp.currentStrength / m_JdController.maxJointForceLimit);
}
}

public override void CollectObservations()
public override void CollectObservations(VectorSensor sensor)
{
m_JdController.GetCurrentJointForces();

Expand All @@ -106,21 +106,21 @@ public override void CollectObservations()
RaycastHit hit;
if (Physics.Raycast(body.position, Vector3.down, out hit, 10.0f))
{
AddVectorObs(hit.distance);
sensor.AddObservation(hit.distance);
}
else
AddVectorObs(10.0f);
sensor.AddObservation(10.0f);

// Forward & up to help with orientation
var bodyForwardRelativeToLookRotationToTarget = m_TargetDirMatrix.inverse.MultiplyVector(body.forward);
AddVectorObs(bodyForwardRelativeToLookRotationToTarget);
sensor.AddObservation(bodyForwardRelativeToLookRotationToTarget);

var bodyUpRelativeToLookRotationToTarget = m_TargetDirMatrix.inverse.MultiplyVector(body.up);
AddVectorObs(bodyUpRelativeToLookRotationToTarget);
sensor.AddObservation(bodyUpRelativeToLookRotationToTarget);

foreach (var bodyPart in m_JdController.bodyPartsDict.Values)
{
CollectObservationBodyPart(bodyPart);
CollectObservationBodyPart(bodyPart, sensor);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,15 @@ public override void InitializeAgent()
SetResetParameters();
}

public override void CollectObservations()
public override void CollectObservations(VectorSensor sensor)
{
if (useVectorObs)
{
var localVelocity = transform.InverseTransformDirection(m_AgentRb.velocity);
AddVectorObs(localVelocity.x);
AddVectorObs(localVelocity.z);
AddVectorObs(System.Convert.ToInt32(m_Frozen));
AddVectorObs(System.Convert.ToInt32(m_Shoot));
sensor.AddObservation(localVelocity.x);
sensor.AddObservation(localVelocity.z);
sensor.AddObservation(System.Convert.ToInt32(m_Frozen));
sensor.AddObservation(System.Convert.ToInt32(m_Shoot));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public override void InitializeAgent()
{
}

public override void CollectObservations()
public override void CollectObservations(VectorSensor sensor)
{
// There are no numeric observations to collect as this environment uses visual
// observations.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@ public override void InitializeAgent()
m_GroundMaterial = m_GroundRenderer.material;
}

public override void CollectObservations()
public override void CollectObservations(VectorSensor sensor)
{
if (useVectorObs)
{
AddVectorObs(GetStepCount() / (float)maxStep);
sensor.AddObservation(GetStepCount() / (float)maxStep);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ public override void InitializeAgent()
m_SwitchLogic = areaSwitch.GetComponent<PyramidSwitch>();
}

public override void CollectObservations()
public override void CollectObservations(VectorSensor sensor)
{
if (useVectorObs)
{
AddVectorObs(m_SwitchLogic.GetState());
AddVectorObs(transform.InverseTransformDirection(m_AgentRb.velocity));
sensor.AddObservation(m_SwitchLogic.GetState());
sensor.AddObservation(transform.InverseTransformDirection(m_AgentRb.velocity));
}
}

Expand Down
24 changes: 12 additions & 12 deletions Project/Assets/ML-Agents/Examples/Reacher/Scripts/ReacherAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,22 +35,22 @@ public override void InitializeAgent()
/// We collect the normalized rotations, angularal velocities, and velocities of both
/// limbs of the reacher as well as the relative position of the target and hand.
/// </summary>
public override void CollectObservations()
public override void CollectObservations(VectorSensor sensor)
{
AddVectorObs(pendulumA.transform.localPosition);
AddVectorObs(pendulumA.transform.rotation);
AddVectorObs(m_RbA.angularVelocity);
AddVectorObs(m_RbA.velocity);
sensor.AddObservation(pendulumA.transform.localPosition);
sensor.AddObservation(pendulumA.transform.rotation);
sensor.AddObservation(m_RbA.angularVelocity);
sensor.AddObservation(m_RbA.velocity);

AddVectorObs(pendulumB.transform.localPosition);
AddVectorObs(pendulumB.transform.rotation);
AddVectorObs(m_RbB.angularVelocity);
AddVectorObs(m_RbB.velocity);
sensor.AddObservation(pendulumB.transform.localPosition);
sensor.AddObservation(pendulumB.transform.rotation);
sensor.AddObservation(m_RbB.angularVelocity);
sensor.AddObservation(m_RbB.velocity);

AddVectorObs(goal.transform.localPosition);
AddVectorObs(hand.transform.localPosition);
sensor.AddObservation(goal.transform.localPosition);
sensor.AddObservation(hand.transform.localPosition);

AddVectorObs(m_GoalSpeed);
sensor.AddObservation(m_GoalSpeed);
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

public class TemplateAgent : Agent
{
public override void CollectObservations()
public override void CollectObservations(VectorSensor sensor)
{
}

Expand Down
20 changes: 10 additions & 10 deletions Project/Assets/ML-Agents/Examples/Tennis/Scripts/TennisAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,19 +43,19 @@ public override void InitializeAgent()
SetResetParameters();
}

public override void CollectObservations()
public override void CollectObservations(VectorSensor sensor)
{
AddVectorObs(m_InvertMult * (transform.position.x - myArea.transform.position.x));
AddVectorObs(transform.position.y - myArea.transform.position.y);
AddVectorObs(m_InvertMult * m_AgentRb.velocity.x);
AddVectorObs(m_AgentRb.velocity.y);
sensor.AddObservation(m_InvertMult * (transform.position.x - myArea.transform.position.x));
sensor.AddObservation(transform.position.y - myArea.transform.position.y);
sensor.AddObservation(m_InvertMult * m_AgentRb.velocity.x);
sensor.AddObservation(m_AgentRb.velocity.y);

AddVectorObs(m_InvertMult * (ball.transform.position.x - myArea.transform.position.x));
AddVectorObs(ball.transform.position.y - myArea.transform.position.y);
AddVectorObs(m_InvertMult * m_BallRb.velocity.x);
AddVectorObs(m_BallRb.velocity.y);
sensor.AddObservation(m_InvertMult * (ball.transform.position.x - myArea.transform.position.x));
sensor.AddObservation(ball.transform.position.y - myArea.transform.position.y);
sensor.AddObservation(m_InvertMult * m_BallRb.velocity.x);
sensor.AddObservation(m_BallRb.velocity.y);

AddVectorObs(m_InvertMult * gameObject.transform.rotation.z);
sensor.AddObservation(m_InvertMult * gameObject.transform.rotation.z);
}

public override void AgentAction(float[] vectorAction)
Expand Down
30 changes: 15 additions & 15 deletions Project/Assets/ML-Agents/Examples/Walker/Scripts/WalkerAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -66,40 +66,40 @@ public override void InitializeAgent()
/// <summary>
/// Add relevant information on each body part to observations.
/// </summary>
public void CollectObservationBodyPart(BodyPart bp)
public void CollectObservationBodyPart(BodyPart bp, VectorSensor sensor)
{
var rb = bp.rb;
AddVectorObs(bp.groundContact.touchingGround ? 1 : 0); // Is this bp touching the ground
AddVectorObs(rb.velocity);
AddVectorObs(rb.angularVelocity);
sensor.AddObservation(bp.groundContact.touchingGround ? 1 : 0); // Is this bp touching the ground
sensor.AddObservation(rb.velocity);
sensor.AddObservation(rb.angularVelocity);
var localPosRelToHips = hips.InverseTransformPoint(rb.position);
AddVectorObs(localPosRelToHips);
sensor.AddObservation(localPosRelToHips);

if (bp.rb.transform != hips && bp.rb.transform != handL && bp.rb.transform != handR &&
bp.rb.transform != footL && bp.rb.transform != footR && bp.rb.transform != head)
{
AddVectorObs(bp.currentXNormalizedRot);
AddVectorObs(bp.currentYNormalizedRot);
AddVectorObs(bp.currentZNormalizedRot);
AddVectorObs(bp.currentStrength / m_JdController.maxJointForceLimit);
sensor.AddObservation(bp.currentXNormalizedRot);
sensor.AddObservation(bp.currentYNormalizedRot);
sensor.AddObservation(bp.currentZNormalizedRot);
sensor.AddObservation(bp.currentStrength / m_JdController.maxJointForceLimit);
}
}

/// <summary>
/// Loop over body parts to add them to observation.
/// </summary>
public override void CollectObservations()
public override void CollectObservations(VectorSensor sensor)
{
m_JdController.GetCurrentJointForces();

AddVectorObs(m_DirToTarget.normalized);
AddVectorObs(m_JdController.bodyPartsDict[hips].rb.position);
AddVectorObs(hips.forward);
AddVectorObs(hips.up);
sensor.AddObservation(m_DirToTarget.normalized);
sensor.AddObservation(m_JdController.bodyPartsDict[hips].rb.position);
sensor.AddObservation(hips.forward);
sensor.AddObservation(hips.up);

foreach (var bodyPart in m_JdController.bodyPartsDict.Values)
{
CollectObservationBodyPart(bodyPart);
CollectObservationBodyPart(bodyPart, sensor);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,12 +132,12 @@ void MoveTowards(
}
}

public override void CollectObservations()
public override void CollectObservations(VectorSensor sensor)
{
var agentPos = m_AgentRb.position - ground.transform.position;

AddVectorObs(agentPos / 20f);
AddVectorObs(DoGroundCheck(true) ? 1 : 0);
sensor.AddObservation(agentPos / 20f);
sensor.AddObservation(DoGroundCheck(true) ? 1 : 0);
}

/// <summary>
Expand Down
1 change: 0 additions & 1 deletion com.unity.ml-agents/Editor/BehaviorParametersEditor.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
using UnityEngine;
using UnityEditor;
using Barracuda;
using MLAgents.Sensor;

namespace MLAgents
{
Expand Down
Loading

0 comments on commit 6551974

Please sign in to comment.