Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLA-1880] Raycast sensor interface improvements #5222

Merged
merged 6 commits into from
Apr 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions com.unity.ml-agents/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@ sizes and will need to be retrained. (#5181)
- Make com.unity.modules.physics and com.unity.modules.physics2d optional dependencies. (#5112)
- The default `InferenceDevice` is now `InferenceDevice.Default`, which is equivalent to `InferenceDevice.Burst`. If you
depend on the previous behavior, you can explicitly set the Agent's `InferenceDevice` to `InferenceDevice.CPU`. (#5175)
- Added support for `Goal Signal` as a type of observation. Trainers can now use HyperNetworks to process `Goal Signal`. Trainers with HyperNetworks are more effective at solving multiple tasks. (#5142, #5159, #5149)
- Modified the [GridWorld environment](https://github.com/Unity-Technologies/ml-agents/blob/main/docs/Learning-Environment-Examples.md#gridworld) to use the new `Goal Signal` feature. (#5193)
- Added support for `Goal Signal` as a type of observation. Trainers can now use HyperNetworks to process `Goal Signal`. Trainers with HyperNetworks are more effective at solving multiple tasks. (#5142, #5159, #5149)
- Modified the [GridWorld environment](https://github.com/Unity-Technologies/ml-agents/blob/main/docs/Learning-Environment-Examples.md#gridworld) to use the new `Goal Signal` feature. (#5193)
- `RaycastPerceptionSensor` now caches its raycast results; they can be accessed via `RayPerceptionSensor.RayPerceptionOutput`. (#5222)

#### ml-agents / ml-agents-envs / gym-unity (Python)

Expand Down
141 changes: 72 additions & 69 deletions com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,39 @@ public struct RayOutput
/// </summary>
public GameObject HitGameObject;

/// <summary>
/// Start position of the ray in world space.
/// </summary>
public Vector3 StartPositionWorld;

/// <summary>
/// End position of the ray in world space.
/// </summary>
public Vector3 EndPositionWorld;

/// <summary>
/// The scaled length of the ray.
/// </summary>
/// <remarks>
/// If there is non-(1,1,1) scale, |EndPositionWorld - StartPositionWorld| will be different from
/// the input rayLength.
/// </remarks>
public float ScaledRayLength
{
get
{
var rayDirection = EndPositionWorld - StartPositionWorld;
return rayDirection.magnitude;
}
}

/// <summary>
/// The scaled size of the cast.
/// </summary>
/// <remarks>
/// If there is non-(1,1,1) scale, the cast radius will be also be scaled.
/// </remarks>
public float ScaledCastRadius;

/// <summary>
/// Writes the ray output information to a subset of the float array. Each element in the rayAngles array
Expand Down Expand Up @@ -200,37 +233,6 @@ public void ToFloatArray(int numDetectableTags, int rayIndex, float[] buffer)
public RayOutput[] RayOutputs;
}

/// <summary>
/// Debug information for the raycast hits. This is used by the RayPerceptionSensorComponent.
/// </summary>
internal class DebugDisplayInfo
{
public struct RayInfo
{
public Vector3 worldStart;
public Vector3 worldEnd;
public float castRadius;
public RayPerceptionOutput.RayOutput rayOutput;
}

public void Reset()
{
m_Frame = Time.frameCount;
}

/// <summary>
/// "Age" of the results in number of frames. This is used to adjust the alpha when drawing.
/// </summary>
public int age
{
get { return Time.frameCount - m_Frame; }
}

public RayInfo[] rayInfos;

int m_Frame;
}

/// <summary>
/// A sensor implementation that supports ray cast-based observations.
/// </summary>
Expand All @@ -241,12 +243,16 @@ public class RayPerceptionSensor : ISensor, IBuiltInSensor
string m_Name;

RayPerceptionInput m_RayPerceptionInput;
RayPerceptionOutput m_RayPerceptionOutput;

DebugDisplayInfo m_DebugDisplayInfo;
/// <summary>
/// Time.frameCount at the last time Update() was called. This is only used for display in gizmos.
/// </summary>
int m_DebugLastFrameCount;

internal DebugDisplayInfo debugDisplayInfo
internal int DebugLastFrameCount
{
get { return m_DebugDisplayInfo; }
get { return m_DebugLastFrameCount; }
}

/// <summary>
Expand All @@ -261,10 +267,16 @@ public RayPerceptionSensor(string name, RayPerceptionInput rayInput)

SetNumObservations(rayInput.OutputSize());

if (Application.isEditor)
{
m_DebugDisplayInfo = new DebugDisplayInfo();
}
m_DebugLastFrameCount = Time.frameCount;
m_RayPerceptionOutput = new RayPerceptionOutput();
}

/// <summary>
/// The most recent raycast results.
/// </summary>
public RayPerceptionOutput RayPerceptionOutput
{
get { return m_RayPerceptionOutput; }
}

void SetNumObservations(int numObservations)
Expand Down Expand Up @@ -301,33 +313,15 @@ public int Write(ObservationWriter writer)
using (TimerStack.Instance.Scoped("RayPerceptionSensor.Perceive"))
{
Array.Clear(m_Observations, 0, m_Observations.Length);

var numRays = m_RayPerceptionInput.Angles.Count;
var numDetectableTags = m_RayPerceptionInput.DetectableTags.Count;

if (m_DebugDisplayInfo != null)
{
// Reset the age information, and resize the buffer if needed.
m_DebugDisplayInfo.Reset();
if (m_DebugDisplayInfo.rayInfos == null || m_DebugDisplayInfo.rayInfos.Length != numRays)
{
m_DebugDisplayInfo.rayInfos = new DebugDisplayInfo.RayInfo[numRays];
}
}

// For each ray, do the casting, and write the information to the observation buffer
// For each ray, write the information to the observation buffer
for (var rayIndex = 0; rayIndex < numRays; rayIndex++)
{
DebugDisplayInfo.RayInfo debugRay;
var rayOutput = PerceiveSingleRay(m_RayPerceptionInput, rayIndex, out debugRay);

if (m_DebugDisplayInfo != null)
{
m_DebugDisplayInfo.rayInfos[rayIndex] = debugRay;
}

rayOutput.ToFloatArray(numDetectableTags, rayIndex, m_Observations);
m_RayPerceptionOutput.RayOutputs[rayIndex].ToFloatArray(numDetectableTags, rayIndex, m_Observations);
}

// Finally, add the observations to the ObservationWriter
writer.AddList(m_Observations);
}
Expand All @@ -337,6 +331,19 @@ public int Write(ObservationWriter writer)
/// <inheritdoc/>
public void Update()
{
m_DebugLastFrameCount = Time.frameCount;
var numRays = m_RayPerceptionInput.Angles.Count;

if (m_RayPerceptionOutput.RayOutputs == null || m_RayPerceptionOutput.RayOutputs.Length != numRays)
{
m_RayPerceptionOutput.RayOutputs = new RayPerceptionOutput.RayOutput[numRays];
}

// For each ray, do the casting and save the results.
for (var rayIndex = 0; rayIndex < numRays; rayIndex++)
{
m_RayPerceptionOutput.RayOutputs[rayIndex] = PerceiveSingleRay(m_RayPerceptionInput, rayIndex);
}
}

/// <inheritdoc/>
Expand Down Expand Up @@ -384,8 +391,7 @@ public static RayPerceptionOutput Perceive(RayPerceptionInput input)

for (var rayIndex = 0; rayIndex < input.Angles.Count; rayIndex++)
{
DebugDisplayInfo.RayInfo debugRay;
output.RayOutputs[rayIndex] = PerceiveSingleRay(input, rayIndex, out debugRay);
output.RayOutputs[rayIndex] = PerceiveSingleRay(input, rayIndex);
}

return output;
Expand All @@ -396,12 +402,10 @@ public static RayPerceptionOutput Perceive(RayPerceptionInput input)
/// </summary>
/// <param name="input"></param>
/// <param name="rayIndex"></param>
/// <param name="debugRayOut"></param>
/// <returns></returns>
internal static RayPerceptionOutput.RayOutput PerceiveSingleRay(
RayPerceptionInput input,
int rayIndex,
out DebugDisplayInfo.RayInfo debugRayOut
int rayIndex
)
{
var unscaledRayLength = input.RayLength;
Expand Down Expand Up @@ -473,7 +477,10 @@ out DebugDisplayInfo.RayInfo debugRayOut
HitFraction = hitFraction,
HitTaggedObject = false,
HitTagIndex = -1,
HitGameObject = hitObject
HitGameObject = hitObject,
StartPositionWorld = startPositionWorld,
EndPositionWorld = endPositionWorld,
ScaledCastRadius = scaledCastRadius
};

if (castHit)
Expand Down Expand Up @@ -505,10 +512,6 @@ out DebugDisplayInfo.RayInfo debugRayOut
}
}

debugRayOut.worldStart = startPositionWorld;
debugRayOut.worldEnd = endPositionWorld;
debugRayOut.rayOutput = rayOutput;
debugRayOut.castRadius = scaledCastRadius;

return rayOutput;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -252,16 +252,26 @@ internal void UpdateSensor()
}
}

internal int SensorObservationAge()
{
if (m_RaySensor != null)
{
return Time.frameCount - m_RaySensor.DebugLastFrameCount;
}

return 0;
}

void OnDrawGizmosSelected()
{
if (m_RaySensor?.debugDisplayInfo?.rayInfos != null)
if (m_RaySensor?.RayPerceptionOutput?.RayOutputs != null)
{
// If we have cached debug info from the sensor, draw that.
// Draw "old" observations in a lighter color.
// Since the agent may not step every frame, this helps de-emphasize "stale" hit information.
var alpha = Mathf.Pow(.5f, m_RaySensor.debugDisplayInfo.age);
var alpha = Mathf.Pow(.5f, SensorObservationAge());

foreach (var rayInfo in m_RaySensor.debugDisplayInfo.rayInfos)
foreach (var rayInfo in m_RaySensor.RayPerceptionOutput.RayOutputs)
{
DrawRaycastGizmos(rayInfo, alpha);
}
Expand All @@ -276,34 +286,33 @@ void OnDrawGizmosSelected()
rayInput.DetectableTags = null;
for (var rayIndex = 0; rayIndex < rayInput.Angles.Count; rayIndex++)
{
DebugDisplayInfo.RayInfo debugRay;
RayPerceptionSensor.PerceiveSingleRay(rayInput, rayIndex, out debugRay);
DrawRaycastGizmos(debugRay);
var rayOutput = RayPerceptionSensor.PerceiveSingleRay(rayInput, rayIndex);
DrawRaycastGizmos(rayOutput);
}
}
}

/// <summary>
/// Draw the debug information from the sensor (if available).
/// </summary>
void DrawRaycastGizmos(DebugDisplayInfo.RayInfo rayInfo, float alpha = 1.0f)
void DrawRaycastGizmos(RayPerceptionOutput.RayOutput rayOutput, float alpha = 1.0f)
{
var startPositionWorld = rayInfo.worldStart;
var endPositionWorld = rayInfo.worldEnd;
var startPositionWorld = rayOutput.StartPositionWorld;
var endPositionWorld = rayOutput.EndPositionWorld;
var rayDirection = endPositionWorld - startPositionWorld;
rayDirection *= rayInfo.rayOutput.HitFraction;
rayDirection *= rayOutput.HitFraction;

// hit fraction ^2 will shift "far" hits closer to the hit color
var lerpT = rayInfo.rayOutput.HitFraction * rayInfo.rayOutput.HitFraction;
var lerpT = rayOutput.HitFraction * rayOutput.HitFraction;
var color = Color.Lerp(rayHitColor, rayMissColor, lerpT);
color.a *= alpha;
Gizmos.color = color;
Gizmos.DrawRay(startPositionWorld, rayDirection);

// Draw the hit point as a sphere. If using rays to cast (0 radius), use a small sphere.
if (rayInfo.rayOutput.HasHit)
if (rayOutput.HasHit)
{
var hitRadius = Mathf.Max(rayInfo.castRadius, .05f);
var hitRadius = Mathf.Max(rayOutput.ScaledCastRadius, .05f);
Gizmos.DrawWireSphere(startPositionWorld + rayDirection, hitRadius);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ public void CheckSetupRayPerceptionSensorComponent()
sensorComponent.ObservationStacks = 2;

sensorComponent.CreateSensors();

var sensor = sensorComponent.RaySensor;
sensor.Update();
var outputs = sensor.RayPerceptionOutput;
Assert.AreEqual(outputs.RayOutputs.Length, 2*sensorComponent.RaysPerDirection + 1);
}
#endif
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ public void TestRaycasts()
{
perception.SphereCastRadius = castRadius;
var sensor = perception.CreateSensors()[0];
sensor.Update();

var expectedObs = (2 * perception.RaysPerDirection + 1) * (perception.DetectableTags.Count + 2);
Assert.AreEqual(sensor.GetObservationSpec().Shape[0], expectedObs);
Expand Down Expand Up @@ -166,6 +167,7 @@ public void TestRaycastMiss()
perception.DetectableTags.Add(k_SphereTag);

var sensor = perception.CreateSensors()[0];
sensor.Update();
var expectedObs = (2 * perception.RaysPerDirection + 1) * (perception.DetectableTags.Count + 2);
Assert.AreEqual(sensor.GetObservationSpec().Shape[0], expectedObs);
var outputBuffer = new float[expectedObs];
Expand Down Expand Up @@ -214,6 +216,7 @@ public void TestRayFilter()
perception.RayLayerMask = layerMask;

var sensor = perception.CreateSensors()[0];
sensor.Update();
var expectedObs = (2 * perception.RaysPerDirection + 1) * (perception.DetectableTags.Count + 2);
Assert.AreEqual(sensor.GetObservationSpec().Shape[0], expectedObs);
var outputBuffer = new float[expectedObs];
Expand Down Expand Up @@ -260,6 +263,7 @@ public void TestRaycastsScaled()
{
perception.SphereCastRadius = castRadius;
var sensor = perception.CreateSensors()[0];
sensor.Update();

var expectedObs = (2 * perception.RaysPerDirection + 1) * (perception.DetectableTags.Count + 2);
Assert.AreEqual(sensor.GetObservationSpec().Shape[0], expectedObs);
Expand Down Expand Up @@ -309,6 +313,7 @@ public void TestRayZeroLength()
// Set the layer mask to either the default, or one that ignores the close cube's layer

var sensor = perception.CreateSensors()[0];
sensor.Update();
var expectedObs = (2 * perception.RaysPerDirection + 1) * (perception.DetectableTags.Count + 2);
Assert.AreEqual(sensor.GetObservationSpec().Shape[0], expectedObs);
var outputBuffer = new float[expectedObs];
Expand Down