Skip to content

Commit

Permalink
Bring back root reference in grid sensor (#5300)
Browse files Browse the repository at this point in the history
  • Loading branch information
dongruoping committed Apr 23, 2021
1 parent 2fb3825 commit 1c148b5
Show file tree
Hide file tree
Showing 7 changed files with 99 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2261,6 +2261,16 @@ PrefabInstance:
m_Modification:
m_TransformParent: {fileID: 8188317207052398481}
m_Modifications:
- target: {fileID: 1548337883655231979, guid: ac01d0f42c5e1463e943632a60d99967,
type: 3}
propertyPath: m_RootReference
value:
objectReference: {fileID: 8190299122290044757}
- target: {fileID: 1548337883655231979, guid: ac01d0f42c5e1463e943632a60d99967,
type: 3}
propertyPath: m_AgentGameObject
value:
objectReference: {fileID: 8190299122290044757}
- target: {fileID: 2598450485826216109, guid: ac01d0f42c5e1463e943632a60d99967,
type: 3}
propertyPath: m_Model
Expand Down Expand Up @@ -2336,7 +2346,7 @@ MonoBehaviour:
type: 3}
m_PrefabInstance: {fileID: 6067781793364901444}
m_PrefabAsset: {fileID: 0}
m_GameObject: {fileID: 0}
m_GameObject: {fileID: 8190299122290044757}
m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: 11500000, guid: d94a85eca2e074578943301959c555ba, type: 3}
Expand All @@ -2348,13 +2358,29 @@ Transform:
type: 3}
m_PrefabInstance: {fileID: 6067781793364901444}
m_PrefabAsset: {fileID: 0}
--- !u!1 &8190299122290044757 stripped
GameObject:
m_CorrespondingSourceObject: {fileID: 2710286047221272849, guid: ac01d0f42c5e1463e943632a60d99967,
type: 3}
m_PrefabInstance: {fileID: 6067781793364901444}
m_PrefabAsset: {fileID: 0}
--- !u!1001 &6565363751102736699
PrefabInstance:
m_ObjectHideFlags: 0
serializedVersion: 2
m_Modification:
m_TransformParent: {fileID: 8188317207052398481}
m_Modifications:
- target: {fileID: 1548337883655231979, guid: ac01d0f42c5e1463e943632a60d99967,
type: 3}
propertyPath: m_RootReference
value:
objectReference: {fileID: 9115291448867436586}
- target: {fileID: 1548337883655231979, guid: ac01d0f42c5e1463e943632a60d99967,
type: 3}
propertyPath: m_AgentGameObject
value:
objectReference: {fileID: 9115291448867436586}
- target: {fileID: 2598450485826216109, guid: ac01d0f42c5e1463e943632a60d99967,
type: 3}
propertyPath: m_Model
Expand Down Expand Up @@ -2435,7 +2461,7 @@ MonoBehaviour:
type: 3}
m_PrefabInstance: {fileID: 6565363751102736699}
m_PrefabAsset: {fileID: 0}
m_GameObject: {fileID: 0}
m_GameObject: {fileID: 9115291448867436586}
m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: 11500000, guid: d94a85eca2e074578943301959c555ba, type: 3}
Expand All @@ -2447,13 +2473,29 @@ Transform:
type: 3}
m_PrefabInstance: {fileID: 6565363751102736699}
m_PrefabAsset: {fileID: 0}
--- !u!1 &9115291448867436586 stripped
GameObject:
m_CorrespondingSourceObject: {fileID: 2710286047221272849, guid: ac01d0f42c5e1463e943632a60d99967,
type: 3}
m_PrefabInstance: {fileID: 6565363751102736699}
m_PrefabAsset: {fileID: 0}
--- !u!1001 &6716844123244810954
PrefabInstance:
m_ObjectHideFlags: 0
serializedVersion: 2
m_Modification:
m_TransformParent: {fileID: 8188317207052398481}
m_Modifications:
- target: {fileID: 1548337883655231979, guid: ac01d0f42c5e1463e943632a60d99967,
type: 3}
propertyPath: m_RootReference
value:
objectReference: {fileID: 8695281997955662811}
- target: {fileID: 1548337883655231979, guid: ac01d0f42c5e1463e943632a60d99967,
type: 3}
propertyPath: m_AgentGameObject
value:
objectReference: {fileID: 8695281997955662811}
- target: {fileID: 2598450485826216109, guid: ac01d0f42c5e1463e943632a60d99967,
type: 3}
propertyPath: m_Model
Expand Down Expand Up @@ -2534,7 +2576,7 @@ MonoBehaviour:
type: 3}
m_PrefabInstance: {fileID: 6716844123244810954}
m_PrefabAsset: {fileID: 0}
m_GameObject: {fileID: 0}
m_GameObject: {fileID: 8695281997955662811}
m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: 11500000, guid: d94a85eca2e074578943301959c555ba, type: 3}
Expand All @@ -2546,3 +2588,9 @@ Transform:
type: 3}
m_PrefabInstance: {fileID: 6716844123244810954}
m_PrefabAsset: {fileID: 0}
--- !u!1 &8695281997955662811 stripped
GameObject:
m_CorrespondingSourceObject: {fileID: 2710286047221272849, guid: ac01d0f42c5e1463e943632a60d99967,
type: 3}
m_PrefabInstance: {fileID: 6716844123244810954}
m_PrefabAsset: {fileID: 0}
Original file line number Diff line number Diff line change
Expand Up @@ -816,6 +816,11 @@ PrefabInstance:
m_Modification:
m_TransformParent: {fileID: 0}
m_Modifications:
- target: {fileID: 4704531522807670703, guid: f5bbed44a6ea747a687fbbb738eb1730,
type: 3}
propertyPath: m_ShowGizmos
value: 0
objectReference: {fileID: 0}
- target: {fileID: 8188317207052398481, guid: f5bbed44a6ea747a687fbbb738eb1730,
type: 3}
propertyPath: m_RootOrder
Expand Down
1 change: 1 addition & 0 deletions com.unity.ml-agents/Editor/GridSensorComponentEditor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ public override void OnInspectorGUI()
gridSize.vector3IntValue = new Vector3Int(newGridSize.x, 1, newGridSize.z);
}
EditorGUI.EndDisabledGroup();
EditorGUILayout.PropertyField(so.FindProperty(nameof(GridSensorComponent.m_AgentGameObject)), true);
EditorGUILayout.PropertyField(so.FindProperty(nameof(GridSensorComponent.m_RotateWithAgent)), true);

EditorGUI.BeginDisabledGroup(!EditorUtilities.CanUpdateModelProperties());
Expand Down
21 changes: 12 additions & 9 deletions com.unity.ml-agents/Runtime/Sensors/BoxOverlapChecker.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ internal class BoxOverlapChecker
Vector3Int m_GridSize;
bool m_RotateWithAgent;
LayerMask m_ColliderMask;
GameObject m_RootReference;
GameObject m_CenterObject;
GameObject m_AgentGameObject;
string[] m_DetectableTags;
int m_InitialColliderBufferSize;
int m_MaxColliderBufferSize;
Expand All @@ -32,7 +33,8 @@ public BoxOverlapChecker(
Vector3Int gridSize,
bool rotateWithAgent,
LayerMask colliderMask,
GameObject rootReference,
GameObject centerObject,
GameObject agentGameObject,
string[] detectableTags,
int initialColliderBufferSize,
int maxColliderBufferSize)
Expand All @@ -41,7 +43,8 @@ public BoxOverlapChecker(
m_GridSize = gridSize;
m_RotateWithAgent = rotateWithAgent;
m_ColliderMask = colliderMask;
m_RootReference = rootReference;
m_CenterObject = centerObject;
m_AgentGameObject = agentGameObject;
m_DetectableTags = detectableTags;
m_InitialColliderBufferSize = initialColliderBufferSize;
m_MaxColliderBufferSize = maxColliderBufferSize;
Expand Down Expand Up @@ -95,17 +98,17 @@ internal Vector3 GetCellGlobalPosition(int cellIndex)
{
if (m_RotateWithAgent)
{
return m_RootReference.transform.TransformPoint(m_CellLocalPositions[cellIndex]);
return m_CenterObject.transform.TransformPoint(m_CellLocalPositions[cellIndex]);
}
else
{
return m_CellLocalPositions[cellIndex] + m_RootReference.transform.position;
return m_CellLocalPositions[cellIndex] + m_CenterObject.transform.position;
}
}

internal Quaternion GetGridRotation()
{
return m_RotateWithAgent ? m_RootReference.transform.rotation : Quaternion.identity;
return m_RotateWithAgent ? m_CenterObject.transform.rotation : Quaternion.identity;
}

/// <summary>
Expand Down Expand Up @@ -191,13 +194,13 @@ void ParseCollidersClosest(Collider[] foundColliders, int numFound, int cellInde
var currentColliderGo = foundColliders[i].gameObject;

// Continue if the current collider go is the root reference
if (ReferenceEquals(currentColliderGo, m_RootReference))
if (ReferenceEquals(currentColliderGo, m_AgentGameObject))
{
continue;
}

var closestColliderPoint = foundColliders[i].ClosestPointOnBounds(cellCenter);
var currentDistanceSquared = (closestColliderPoint - m_RootReference.transform.position).sqrMagnitude;
var currentDistanceSquared = (closestColliderPoint - m_CenterObject.transform.position).sqrMagnitude;

if (currentDistanceSquared >= minDistanceSquared)
{
Expand Down Expand Up @@ -235,7 +238,7 @@ void ParseCollidersAll(Collider[] foundColliders, int numFound, int cellIndex, V
for (int i = 0; i < numFound; i++)
{
var currentColliderGo = foundColliders[i].gameObject;
if (!ReferenceEquals(currentColliderGo, m_RootReference))
if (!ReferenceEquals(currentColliderGo, m_AgentGameObject))
{
detectedAction.Invoke(currentColliderGo, cellIndex);
}
Expand Down
13 changes: 13 additions & 0 deletions com.unity.ml-agents/Runtime/Sensors/GridSensorComponent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,18 @@ public bool RotateWithAgent
set { m_RotateWithAgent = value; }
}

[HideInInspector, SerializeField]
internal GameObject m_AgentGameObject;
/// <summary>
/// The reference of the root of the agent. This is used to disambiguate objects with
/// the same tag as the agent. Defaults to current GameObject.
/// </summary>
public GameObject AgentGameObject
{
get { return (m_AgentGameObject == null ? gameObject : m_AgentGameObject); }
set { m_AgentGameObject = value; }
}

[HideInInspector, SerializeField]
internal string[] m_DetectableTags;
/// <summary>
Expand Down Expand Up @@ -191,6 +203,7 @@ public override ISensor[] CreateSensors()
m_RotateWithAgent,
m_ColliderMask,
gameObject,
AgentGameObject,
m_DetectableTags,
m_InitialColliderBufferSize,
m_MaxColliderBufferSize
Expand Down
28 changes: 17 additions & 11 deletions com.unity.ml-agents/Tests/Runtime/Sensor/BoxOverlapCheckerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ public TestBoxOverlapChecker(
Vector3Int gridSize,
bool rotateWithAgent,
LayerMask colliderMask,
GameObject rootReference,
GameObject centerObject,
GameObject agentGameObject,
string[] detectableTags,
int initialColliderBufferSize,
int maxColliderBufferSize
Expand All @@ -23,7 +24,8 @@ int maxColliderBufferSize
gridSize,
rotateWithAgent,
colliderMask,
rootReference,
centerObject,
agentGameObject,
detectableTags,
initialColliderBufferSize,
maxColliderBufferSize)
Expand Down Expand Up @@ -53,7 +55,8 @@ public static TestBoxOverlapChecker CreateChecker(
int gridSizeX = 10,
int gridSizeZ = 10,
bool rotateWithAgent = true,
GameObject rootReference = null,
GameObject centerObject = null,
GameObject agentGameObject = null,
string[] detectableTags = null,
int initialColliderBufferSize = 4,
int maxColliderBufferSize = 500)
Expand All @@ -63,7 +66,8 @@ public static TestBoxOverlapChecker CreateChecker(
new Vector3Int(gridSizeX, 1, gridSizeZ),
rotateWithAgent,
LayerMask.GetMask("Default"),
rootReference,
centerObject,
agentGameObject,
detectableTags,
initialColliderBufferSize,
maxColliderBufferSize);
Expand All @@ -77,7 +81,7 @@ public void TestCellLocalPosition()
{
var testGo = new GameObject("test");
testGo.transform.position = Vector3.zero;
var boxOverlapSquare = TestBoxOverlapChecker.CreateChecker(gridSizeX: 10, gridSizeZ: 10, rotateWithAgent: false, rootReference: testGo);
var boxOverlapSquare = TestBoxOverlapChecker.CreateChecker(gridSizeX: 10, gridSizeZ: 10, rotateWithAgent: false, agentGameObject: testGo);

var localPos = boxOverlapSquare.CellLocalPositions;
Assert.AreEqual(new Vector3(-4.5f, 0, -4.5f), localPos[0]);
Expand All @@ -88,7 +92,7 @@ public void TestCellLocalPosition()

var testGo2 = new GameObject("test");
testGo2.transform.position = new Vector3(3.5f, 8f, 17f); // random, should have no effect on local positions
var boxOverlapRect = TestBoxOverlapChecker.CreateChecker(gridSizeX: 5, gridSizeZ: 15, rotateWithAgent: true, rootReference: testGo);
var boxOverlapRect = TestBoxOverlapChecker.CreateChecker(gridSizeX: 5, gridSizeZ: 15, rotateWithAgent: true, agentGameObject: testGo);

localPos = boxOverlapRect.CellLocalPositions;
Assert.AreEqual(new Vector3(-2f, 0, -7f), localPos[0]);
Expand All @@ -104,7 +108,7 @@ public void TestCellGlobalPositionNoRotate()
var testGo = new GameObject("test");
var position = new Vector3(3.5f, 8f, 17f);
testGo.transform.position = position;
var boxOverlap = TestBoxOverlapChecker.CreateChecker(gridSizeX: 10, gridSizeZ: 10, rotateWithAgent: false, rootReference: testGo);
var boxOverlap = TestBoxOverlapChecker.CreateChecker(gridSizeX: 10, gridSizeZ: 10, rotateWithAgent: false, agentGameObject: testGo, centerObject: testGo);

Assert.AreEqual(new Vector3(-4.5f, 0, -4.5f) + position, boxOverlap.GetCellGlobalPosition(0));
Assert.AreEqual(new Vector3(-4.5f, 0, 4.5f) + position, boxOverlap.GetCellGlobalPosition(9));
Expand All @@ -126,7 +130,7 @@ public void TestCellGlobalPositionRotate()
var testGo = new GameObject("test");
var position = new Vector3(15f, 6f, 13f);
testGo.transform.position = position;
var boxOverlap = TestBoxOverlapChecker.CreateChecker(gridSizeX: 5, gridSizeZ: 15, rotateWithAgent: true, rootReference: testGo);
var boxOverlap = TestBoxOverlapChecker.CreateChecker(gridSizeX: 5, gridSizeZ: 15, rotateWithAgent: true, agentGameObject: testGo, centerObject: testGo);

Assert.AreEqual(new Vector3(-2f, 0, -7f) + position, boxOverlap.GetCellGlobalPosition(0));
Assert.AreEqual(new Vector3(-2f, 0, 7f) + position, boxOverlap.GetCellGlobalPosition(14));
Expand All @@ -150,7 +154,7 @@ public void TestBufferResize()
var testGo = new GameObject("test");
testGo.transform.position = Vector3.zero;
testObjects.Add(testGo);
var boxOverlap = TestBoxOverlapChecker.CreateChecker(rootReference: testGo, initialColliderBufferSize: 2, maxColliderBufferSize: 5);
var boxOverlap = TestBoxOverlapChecker.CreateChecker(agentGameObject: testGo, centerObject: testGo, initialColliderBufferSize: 2, maxColliderBufferSize: 5);
boxOverlap.Update();
Assert.AreEqual(2, boxOverlap.ColliderBuffer.Length);

Expand Down Expand Up @@ -193,7 +197,8 @@ public void TestParseCollidersClosest()
cellScaleZ: 10f,
gridSizeX: 2,
gridSizeZ: 2,
rootReference: testGo,
agentGameObject: testGo,
centerObject: testGo,
detectableTags: new [] { tag1 });
var helper = new VerifyParseCollidersHelper();
boxOverlap.GridOverlapDetectedClosest += helper.DetectedAction;
Expand Down Expand Up @@ -229,7 +234,8 @@ public void TestParseCollidersAll()
cellScaleZ: 10f,
gridSizeX: 2,
gridSizeZ: 2,
rootReference: testGo,
agentGameObject: testGo,
centerObject: testGo,
detectableTags: new [] { tag1 });
var helper = new VerifyParseCollidersHelper();
boxOverlap.GridOverlapDetectedAll += helper.DetectedAction;
Expand Down
1 change: 0 additions & 1 deletion docs/Migrating.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,6 @@ will not work in newer version. Some errors might show up when loading the old s
You'll need to remove the old sensor and create a new GridSensor.
* These parameters names have changed but still refer to the same concept in the sensor: `GridNumSide` -> `GridSize`,
`RotateToAgent` -> `RotateWithAgent`, `ObserveMask` -> `ColliderMask`, `DetectableObjects` -> `DetectableTags`
* `RootReference` is removed and the sensor component's GameObject will always be ignored for hit results.
* `DepthType` (`ChanelBase`/`ChannelHot`) option and `ChannelDepth` are removed. Now the default is
one-hot encoding for detected tag. If you were using original GridSensor without overriding any method,
switching to new GridSensor will produce similar effect for training although the actual observations
Expand Down

0 comments on commit 1c148b5

Please sign in to comment.