914 changes: 899 additions & 15 deletions Project/Assets/ML-Agents/Examples/GridWorld/Prefabs/Area.prefab

Large diffs are not rendered by default.

1,380 changes: 1,105 additions & 275 deletions Project/Assets/ML-Agents/Examples/GridWorld/Scenes/GridWorld.unity

Large diffs are not rendered by default.

67 changes: 63 additions & 4 deletions Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs
@@ -2,6 +2,7 @@
using UnityEngine;
using System.Linq;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;
using UnityEngine.Rendering;
using UnityEngine.Serialization;
@@ -19,6 +20,42 @@ public class GridAgent : Agent
"RenderTexture as observations.")]
public Camera renderCamera;

VectorSensorComponent m_GoalSensor;

public enum GridGoal
{
GreenPlus,
RedEx,
}

// Visual representations of the agent. Both are blue on top, but different colors on the bottom - this
// allows the user to see which corresponds to the current goal, but it's not visible to the camera.
// Only one is active at a time.
public GameObject GreenBottom;
public GameObject RedBottom;

GridGoal m_CurrentGoal;

public GridGoal CurrentGoal
{
get { return m_CurrentGoal; }
set
{
switch (value)
{
case GridGoal.GreenPlus:
GreenBottom.SetActive(true);
RedBottom.SetActive(false);
break;
case GridGoal.RedEx:
GreenBottom.SetActive(false);
RedBottom.SetActive(true);
break;
}
m_CurrentGoal = value;
}
}

[Tooltip("Selecting will turn on action masking. Note that a model trained with action " +
"masking turned on may not behave optimally when action masking is turned off.")]
public bool maskActions = true;
@@ -33,9 +70,17 @@ public class GridAgent : Agent

public override void Initialize()
{
m_GoalSensor = this.GetComponent<VectorSensorComponent>();
m_ResetParams = Academy.Instance.EnvironmentParameters;
}

public override void CollectObservations(VectorSensor sensor)
{
Array values = Enum.GetValues(typeof(GridGoal));
Contributor:

Can this happen somewhere else? It feels like abuse of CollectObservations(), since it's not touching the input VectorSensor.

Contributor Author:

VectorSensor is null here, so I do not see an issue with this. The goal signal is an observation, so it makes sense to me that it is collected in CollectObservations. Would it be better if I put this logic into a CollectGoal method with no arguments that I call from CollectObservations?

Contributor:

CollectGoal is maybe fine for the example (but let's not add it to Agent). Let me think about a better way.

One problem (which I didn't realize until now) is that we don't check for a null collectObservationsSensor during the normal update step:

    CollectObservations(collectObservationsSensor);

but we do check for null when the agent is done:

    if (collectObservationsSensor != null)
    {
        // Make sure the latest observations are being passed to training.
        collectObservationsSensor.Reset();
        using (m_CollectObservationsChecker.Start())
        {
            CollectObservations(collectObservationsSensor);
        }
    }

int goalNum = (int)CurrentGoal;
m_GoalSensor.GetSensor().AddOneHotObservation(goalNum, values.Length);
}
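
For reference, a minimal sketch of the CollectGoal refactor floated in the thread above (hypothetical, not part of this diff). It assumes the m_GoalSensor field and GridGoal enum introduced by this PR, and that the VectorSensor argument may be null, as noted above:

    // Write the goal one-hot to the dedicated goal sensor rather than to the
    // VectorSensor argument, which the thread above notes can be null here.
    void CollectGoal()
    {
        Array values = Enum.GetValues(typeof(GridGoal));
        m_GoalSensor.GetSensor().AddOneHotObservation((int)CurrentGoal, values.Length);
    }

    public override void CollectObservations(VectorSensor sensor)
    {
        CollectGoal();
    }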

public override void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
{
// Mask the necessary actions if selected by the user.
@@ -103,19 +148,31 @@ public override void OnActionReceived(ActionBuffers actionBuffers)
{
transform.position = targetPos;

if (hit.Where(col => col.gameObject.CompareTag("goal")).ToArray().Length == 1)
if (hit.Where(col => col.gameObject.CompareTag("plus")).ToArray().Length == 1)
{
SetReward(1f);
ProvideReward(GridGoal.GreenPlus);
EndEpisode();
}
else if (hit.Where(col => col.gameObject.CompareTag("pit")).ToArray().Length == 1)
else if (hit.Where(col => col.gameObject.CompareTag("ex")).ToArray().Length == 1)
{
SetReward(-1f);
ProvideReward(GridGoal.RedEx);
EndEpisode();
}
}
}

private void ProvideReward(GridGoal hitObject)
{
if (CurrentGoal == hitObject)
{
SetReward(1f);
}
else
{
SetReward(-1f);
}
}

public override void Heuristic(in ActionBuffers actionsOut)
{
var discreteActionsOut = actionsOut.DiscreteActions;
@@ -142,6 +199,8 @@ public override void Heuristic(in ActionBuffers actionsOut)
public override void OnEpisodeBegin()
{
area.AreaReset();
Array values = Enum.GetValues(typeof(GridGoal));
CurrentGoal = (GridGoal)values.GetValue(UnityEngine.Random.Range(0, values.Length));
}

public void FixedUpdate()
18 changes: 10 additions & 8 deletions Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridArea.cs
@@ -2,6 +2,7 @@
using UnityEngine;
using System.Linq;
using Unity.MLAgents;
using UnityEngine.Serialization;


public class GridArea : MonoBehaviour
@@ -15,10 +16,11 @@ public class GridArea : MonoBehaviour

Camera m_AgentCam;

public GameObject goalPref;
public GameObject pitPref;
[FormerlySerializedAs("PlusPref")] public GameObject GreenPlusPrefab;
[FormerlySerializedAs("ExPref")] public GameObject RedExPrefab;
GameObject[] m_Objects;
public int numberOfObstacles = 1;
public int numberOfPlus = 1;
public int numberOfEx = 1;

GameObject m_Plane;
GameObject m_Sn;
@@ -34,7 +36,7 @@ public void Start()
{
m_ResetParams = Academy.Instance.EnvironmentParameters;

m_Objects = new[] { goalPref, pitPref };
m_Objects = new[] { GreenPlusPrefab, RedExPrefab };

m_AgentCam = transform.Find("agentCam").GetComponent<Camera>();

@@ -55,14 +57,14 @@ void SetEnvironment()
transform.position = m_InitialPosition * (m_ResetParams.GetWithDefault("gridSize", 5f) + 1);
var playersList = new List<int>();

for (var i = 0; i < (int)m_ResetParams.GetWithDefault("numObstacles", numberOfObstacles); i++)
for (var i = 0; i < (int)m_ResetParams.GetWithDefault("numPlusGoals", numberOfPlus); i++)
{
playersList.Add(1);
playersList.Add(0);
}

for (var i = 0; i < (int)m_ResetParams.GetWithDefault("numGoals", 1f); i++)
for (var i = 0; i < (int)m_ResetParams.GetWithDefault("numExGoals", numberOfEx); i++)
{
playersList.Add(0);
playersList.Add(1);
}
players = playersList.ToArray();

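The renamed parameters can be driven from the trainer configuration; a sketch, assuming the standard environment_parameters section of an ML-Agents config file (values here are illustrative):

    environment_parameters:
      gridSize: 5
      numPlusGoals: 1
      numExGoals: 1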
3 changes: 1 addition & 2 deletions Project/Assets/ML-Agents/Examples/GridWorld/TFModels.meta
100755 → 100644

Some generated files are not rendered by default.

Binary file not shown.

This file was deleted.

Binary file not shown.

Some generated files are not rendered by default.

@@ -29,5 +29,22 @@ public void TestRenderTextureSensor()
}
}
}

[Test]
public void TestObservationType()
{
var width = 24;
var height = 16;
var camera = Camera.main;
var sensor = new CameraSensor(camera, width, height, true, "TestCameraSensor", SensorCompressionType.None);
var spec = sensor.GetObservationSpec();
Assert.AreEqual((int)spec.ObservationType, (int)ObservationType.Default);
sensor = new CameraSensor(camera, width, height, true, "TestCameraSensor", SensorCompressionType.None, ObservationType.Default);
spec = sensor.GetObservationSpec();
Assert.AreEqual((int)spec.ObservationType, (int)ObservationType.Default);
sensor = new CameraSensor(camera, width, height, true, "TestCameraSensor", SensorCompressionType.None, ObservationType.GoalSignal);
spec = sensor.GetObservationSpec();
Assert.AreEqual((int)spec.ObservationType, (int)ObservationType.GoalSignal);
}
}
}
2 changes: 1 addition & 1 deletion config/ppo/GridWorld.yaml
@@ -12,7 +12,7 @@ behaviors:
learning_rate_schedule: linear
network_settings:
normalize: false
hidden_units: 256
hidden_units: 128
num_layers: 1
vis_encode_type: simple
reward_signals:
2 changes: 2 additions & 0 deletions docs/Learning-Environment-Design-Agents.md
@@ -587,6 +587,8 @@ weights of the policy using the goal observations as input. Note that using a
HyperNetwork requires a lot of computation, so it is recommended to use a smaller
number of hidden units in the policy to alleviate this.
If set to `none` the goal signal will be considered as regular observations.
For an example on how to use a goal signal, see the
[GridWorld example](Learning-Environment-Examples.md#gridworld).
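
A sketch of where this is configured, assuming the goal_conditioning_type field under network_settings in the trainer config (the field itself is not shown in this diff):

    behaviors:
      GridWorld:
        network_settings:
          goal_conditioning_type: hyper  # or "none" to treat goals as regular observations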

#### Goal Signal Summary & Best Practices
- Attach a `VectorSensorComponent` or `CameraSensorComponent` to an agent and
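A minimal attachment sketch for the bullet above; it assumes the VectorSensorComponent API exercised elsewhere in this PR, and the property names here are assumptions:

    // Hypothetical setup: mark a vector sensor as a goal signal.
    var goalSensor = gameObject.AddComponent<VectorSensorComponent>();
    goalSensor.SensorName = "GoalSensor";                      // assumed property name
    goalSensor.ObservationSize = 2;                            // e.g. a one-hot over two goals
    goalSensor.ObservationType = ObservationType.GoalSignal;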
16 changes: 9 additions & 7 deletions docs/Learning-Environment-Examples.md
@@ -82,16 +82,16 @@ you would like to contribute environments, please see our

![GridWorld](images/gridworld.png)
Contributor:

Possible to link to this environment in the goal signal docs and the Changelog? Just in case a user wants an example of how to use these features.


- Set-up: A version of the classic grid-world task. Scene contains agent, goal,
- Set-up: A multi-goal version of the grid-world task. Scene contains agent, goal,
and obstacles.
- Goal: The agent must navigate the grid to the goal while avoiding the
obstacles.
- Goal: The agent must navigate the grid to the appropriate goal while
avoiding the obstacles.
- Agents: The environment contains nine agents with the same Behavior
Parameters.
- Agent Reward Function:
- -0.01 for every step.
- +1.0 if the agent navigates to the goal position of the grid (episode ends).
- -1.0 if the agent navigates to an obstacle (episode ends).
- +1.0 if the agent navigates to the correct goal (episode ends).
- -1.0 if the agent navigates to an incorrect goal (episode ends).
- Behavior Parameters:
- Vector Observation space: None
- Actions: 1 discrete action branch with 5 actions, corresponding to movement in
@@ -101,8 +101,8 @@ you would like to contribute environments, please see our
checkbox within the `trueAgent` GameObject). The trained model file provided
was generated with action masking turned on.
- Visual Observations: One corresponding to top-down view of GridWorld.
- Float Properties: Three, corresponding to grid size, number of obstacles, and
number of goals.
- Goal Signal: A one-hot vector corresponding to which color is the correct goal
  for the Agent.
- Float Properties: Three, corresponding to grid size, number of green goals, and
number of red goals.
- Benchmark Mean Reward: 0.8

## Push Block
Expand Down
Binary file modified docs/images/gridworld.png