Skip to content

Commit

Permalink
Fix width and height in visual observations (#2919)
Browse files Browse the repository at this point in the history
* swap h/w in sensor

* change texture to non-square, retrain model

* get dimensions from RenderTexture
  • Loading branch information
Chris Elion authored Nov 18, 2019
1 parent 7e504b7 commit 7ed6b3a
Show file tree
Hide file tree
Showing 9 changed files with 22 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1005,7 +1005,7 @@ MonoBehaviour:
camera: {fileID: 20743940359151984}
sensorName: CameraSensor
width: 84
height: 84
height: 64
grayscale: 0
--- !u!114 &114935253044749092
MonoBehaviour:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ RenderTexture:
m_ForcedFallbackFormat: 4
m_DownscaleFallback: 0
m_Width: 84
m_Height: 84
m_Height: 64
m_AntiAliasing: 1
m_DepthFormat: 1
m_ColorFormat: 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ MonoBehaviour:
renderTexture: {fileID: 8400000, guid: 114608d5384404f89bff4b6f88432958, type: 2}
sensorName: RenderTextureSensor
width: 84
height: 84
height: 64
grayscale: 0
--- !u!1 &260425459
GameObject:
Expand Down Expand Up @@ -1584,7 +1584,7 @@ RectTransform:
m_AnchorMin: {x: 0.5, y: 0.5}
m_AnchorMax: {x: 0.5, y: 0.5}
m_AnchoredPosition: {x: -369.5, y: -197}
m_SizeDelta: {x: 200, y: 200}
m_SizeDelta: {x: 200, y: 152}
m_Pivot: {x: 0.5, y: 0.5}
--- !u!114 &1305247361
MonoBehaviour:
Expand Down
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -369,8 +369,8 @@ static string CheckVisualObsShape(
TensorProxy tensorProxy, SensorComponent sensorComponent)
{
var shape = sensorComponent.GetObservationShape();
var widthBp = shape[0];
var heightBp = shape[1];
var heightBp = shape[0];
var widthBp = shape[1];
var pixelBp = shape[2];
var heightT = tensorProxy.shape[1];
var widthT = tensorProxy.shape[2];
Expand Down
2 changes: 1 addition & 1 deletion UnitySDK/Assets/ML-Agents/Scripts/Sensor/CameraSensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ public CameraSensor(Camera camera, int width, int height, bool grayscale, string
m_Height = height;
m_Grayscale = grayscale;
m_Name = name;
m_Shape = new[] { width, height, grayscale ? 1 : 3 };
m_Shape = new[] { height, width, grayscale ? 1 : 3 };
}

public string GetName()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public override ISensor CreateSensor()

public override int[] GetObservationShape()
{
return new[] { width, height, grayscale ? 1 : 3 };
return new[] { height, width, grayscale ? 1 : 3 };
}
}
}
32 changes: 9 additions & 23 deletions UnitySDK/Assets/ML-Agents/Scripts/Sensor/RenderTextureSensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,18 @@ namespace MLAgents.Sensor
public class RenderTextureSensor : ISensor
{
RenderTexture m_RenderTexture;
int m_Width;
int m_Height;
bool m_Grayscale;
string m_Name;
int[] m_Shape;

public RenderTextureSensor(RenderTexture renderTexture, int width, int height, bool grayscale, string name)
public RenderTextureSensor(RenderTexture renderTexture, bool grayscale, string name)
{
m_RenderTexture = renderTexture;
m_Width = width;
m_Height = height;
var width = renderTexture != null ? renderTexture.width : 0;
var height = renderTexture != null ? renderTexture.height : 0;
m_Grayscale = grayscale;
m_Name = name;
m_Shape = new[] { width, height, grayscale ? 1 : 3 };
m_Shape = new[] { height, width, grayscale ? 1 : 3 };
}

public string GetName()
Expand All @@ -36,7 +34,7 @@ public byte[] GetCompressedObservation()
{
using(TimerStack.Instance.Scoped("RenderTexSensor.GetCompressedObservation"))
{
var texture = ObservationToTexture(m_RenderTexture, m_Width, m_Height);
var texture = ObservationToTexture(m_RenderTexture);
// TODO support more types here, e.g. JPG
var compressed = texture.EncodeToPNG();
UnityEngine.Object.Destroy(texture);
Expand All @@ -48,7 +46,7 @@ public int Write(WriteAdapter adapter)
{
using (TimerStack.Instance.Scoped("RenderTexSensor.GetCompressedObservation"))
{
var texture = ObservationToTexture(m_RenderTexture, m_Width, m_Height);
var texture = ObservationToTexture(m_RenderTexture);
var numWritten = Utilities.TextureToTensorProxy(texture, adapter, m_Grayscale);
UnityEngine.Object.Destroy(texture);
return numWritten;
Expand All @@ -67,25 +65,13 @@ public SensorCompressionType GetCompressionType()
/// </summary>
/// <returns>The 2D texture.</returns>
/// <param name="obsTexture">RenderTexture.</param>
/// <param name="width">Width of resulting 2D texture.</param>
/// <param name="height">Height of resulting 2D texture.</param>
/// <returns name="texture2D">Texture2D to render to.</returns>
public static Texture2D ObservationToTexture(RenderTexture obsTexture, int width, int height)
public static Texture2D ObservationToTexture(RenderTexture obsTexture)
{
var height = obsTexture.height;
var width = obsTexture.width;
var texture2D = new Texture2D(width, height, TextureFormat.RGB24, false);

if (width != texture2D.width || height != texture2D.height)
{
texture2D.Resize(width, height);
}

if (width != obsTexture.width || height != obsTexture.height)
{
throw new UnityAgentsException(string.Format(
"RenderTexture {0} : width/height is {1}/{2} brain is expecting {3}/{4}.",
obsTexture.name, obsTexture.width, obsTexture.height, width, height));
}

var prevActiveRt = RenderTexture.active;
RenderTexture.active = obsTexture;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,19 @@ public class RenderTextureSensorComponent : SensorComponent
{
public RenderTexture renderTexture;
public string sensorName = "RenderTextureSensor";
public int width = 84;
public int height = 84;
public bool grayscale;

public override ISensor CreateSensor()
{
return new RenderTextureSensor(renderTexture, width, height, grayscale, sensorName);
return new RenderTextureSensor(renderTexture, grayscale, sensorName);
}

public override int[] GetObservationShape()
{
return new[] { width, height, grayscale ? 1 : 3 };
var width = renderTexture != null ? renderTexture.width : 0;
var height = renderTexture != null ? renderTexture.height : 0;

return new[] { height, width, grayscale ? 1 : 3 };
}
}
}

0 comments on commit 7ed6b3a

Please sign in to comment.