diff --git a/.circleci/config.yml b/.circleci/config.yml
index a554cfb8ef..e36df80ec3 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -86,6 +86,11 @@ jobs:
markdown_link_check:
+ parameters:
+ precommit_command:
+ type: string
+ description: precommit hook to run
+ default: markdown-link-check
docker:
- image: circleci/node:12.6.0
working_directory: ~/repo
@@ -117,7 +122,7 @@ jobs:
name: Run markdown-link-check via precommit
command: |
. venv/bin/activate
- pre-commit run --hook-stage manual markdown-link-check --all-files
+ pre-commit run --hook-stage manual << parameters.precommit_command >> --all-files
protobuf_generation_check:
docker:
@@ -223,7 +228,13 @@ workflows:
executor: python373
pyversion: 3.7.3
# Test python 3.7 with the newest supported versions
- pip_constraints: test_constraints_max_version.txt
+ pip_constraints: test_constraints_max_tf1_version.txt
+ - build_python:
+ name: python_3.7.3+tf2
+ executor: python373
+ pyversion: 3.7.3
+ # Test python 3.7 with the newest supported versions
+ pip_constraints: test_constraints_max_tf2_version.txt
- markdown_link_check
- protobuf_generation_check
- deploy:
@@ -250,3 +261,15 @@ workflows:
only: /[0-9]+(\.[0-9]+)*(\.dev[0-9]+)*/
branches:
ignore: /.*/
+ nightly:
+ triggers:
+ - schedule:
+ cron: "0 0 * * *"
+ filters:
+ branches:
+ only:
+ - develop
+ jobs:
+ - markdown_link_check:
+ name: markdown-link-check full
+ precommit_command: markdown-link-check-full
diff --git a/.gitignore b/.gitignore
index c019b218d6..1a7b3d983f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,7 +4,6 @@
/UnitySDK/[Bb]uild/
/UnitySDK/[Bb]uilds/
/UnitySDK/[Pp]ackages/
-/UnitySDK/[Uu]nity[Pp]ackage[Mm]anager/
/UnitySDK/Assets/AssetStoreTools*
/UnitySDK/Assets/Plugins*
/UnitySDK/Assets/Demonstrations*
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index da921eb018..7f77c2d7f9 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -10,7 +10,7 @@ repos:
)$
- repo: https://github.com/pre-commit/mirrors-mypy
- rev: v0.720
+ rev: v0.740
hooks:
- id: mypy
name: mypy-ml-agents
@@ -28,7 +28,7 @@ repos:
args: [--ignore-missing-imports, --disallow-incomplete-defs]
- repo: https://github.com/pre-commit/pre-commit-hooks
- rev: v2.2.3
+ rev: v2.4.0
hooks:
- id: mixed-line-ending
exclude: >
@@ -44,17 +44,33 @@ repos:
.*_pb2.py|
.*_pb2_grpc.py
)$
- # temporarily pin flake8-comprehensions
- additional_dependencies: [flake8-comprehensions==3.0.1]
+ # flake8-tidy-imports is used for banned-modules, not actually tidying
+ additional_dependencies: [flake8-comprehensions, flake8-tidy-imports]
- id: trailing-whitespace
name: trailing-whitespace-markdown
types: [markdown]
+ - id: check-merge-conflict
+ args: [--assume-in-merge]
- repo: https://github.com/pre-commit/pygrep-hooks
- rev: v1.4.1 # Use the ref you want to point at
+ rev: v1.4.2
hooks:
- id: python-check-mock-methods
+
+
+- repo: https://github.com/pre-commit/mirrors-pylint
+ rev: v2.4.4
+ hooks:
+ - id: pylint
+ exclude: >
+ (?x)^(
+ .*_pb2.py|
+ .*_pb2_grpc.py|
+ .*/tests/.*
+ )$
+ args: [--score=n]
+
# "Local" hooks, see https://pre-commit.com/#repository-local-hooks
- repo: local
hooks:
@@ -63,15 +79,22 @@ repos:
# markdown-link-check doesn't support multiple files on the commandline, so this hacks around that.
# Note that you must install the package separately via npm. For example:
# brew install npm; npm install -g markdown-link-check
- entry: bash -xc 'echo "$@" | xargs -n1 -t markdown-link-check -c markdown-link-check.config.json' --
+ entry: bash -xc 'echo "$@" | xargs -n1 -t markdown-link-check -c markdown-link-check.fast.json' --
language: system
types: [markdown]
# Don't check localized files since their target might not be localized.
exclude: ".*localized.*"
# Only run manually, e.g. pre-commit run --hook-stage manual markdown-link-check
stages: [manual]
+ - id: markdown-link-check-full
+ name: markdown-link-check-full
+ entry: bash -xc 'echo "$@" | xargs -n1 -t markdown-link-check -c markdown-link-check.full.json' --
+ language: system
+ types: [markdown]
+ exclude: ".*localized.*"
+ stages: [manual]
- id: validate-versions
name: validate library versions
language: script
entry: utils/validate_versions.py
- files: ".*/setup.py"
+ files: ".*/__init__.py"
diff --git a/.pylintrc b/.pylintrc
new file mode 100644
index 0000000000..11d1f0b75c
--- /dev/null
+++ b/.pylintrc
@@ -0,0 +1,43 @@
+[MASTER]
+# Add files or directories to the blacklist. They should be base names, not
+# paths.
+ignore=CVS
+
+[MESSAGES CONTROL]
+#enable=
+
+disable =
+ # C0301: Line too long
+ # C0330: Wrong hanging indentation before block
+ # disabled because black handles this
+ C0301,C0330,
+
+ # C0114: Missing module docstring
+ # C0115: Missing class docstring
+ # C0116: Missing function or method docstring
+ C0114,C0115,C0116,
+
+ # All convention and refactor for now
+ C,R,
+
+ # W1201: Specify string format arguments as logging function parameters
+ # W1202: Use % formatting in logging functions and pass the % parameters as arguments
+ W1201,W1202,
+
+ # W0612: Unused variable
+ # W0613: Unused argument
+ W0612, W0613,
+
+ # W0107: Unnecessary pass statement
+ W0107,
+
+ # W0511 "TODO"
+ W0511,
+
+ # W0703: Catching too general exception Exception
+ W0703,
+
+ # E0401: Unable to import...
+ # E0611: No name '...' in module '...'
+ # need to look into these, probably namespace packages
+ E0401, E0611
diff --git a/.yamato/csharp-tests.yml b/.yamato/csharp-tests.yml
index c34d3b7c00..36053001e6 100644
--- a/.yamato/csharp-tests.yml
+++ b/.yamato/csharp-tests.yml
@@ -8,22 +8,17 @@ test_mac_editmode_{{ editor.version }}:
name: Test Mac EditMode {{ editor.version }}
agent:
type: Unity::VM::osx
- image: ml-agents/ml-agents-bokken-mac:release
+ image: ml-agents/ml-agents-bokken-mac:v0.1.2-440635
flavor: i1.small
variables:
UNITY_VERSION: {{ editor.version }}
commands:
- ./run-tests-editmode-osx-editor.sh
triggers:
- branches:
- only:
- - "/develop-.*/"
- targets:
- only:
- - "develop"
pull_requests:
- targets:
only:
+ - "develop"
- "master"
- "/release-.*/"
- "/hotfix-.*/"
diff --git a/.yamato/standalone-build-test.yml b/.yamato/standalone-build-test.yml
index e2a9147eb2..1bc2d68e27 100644
--- a/.yamato/standalone-build-test.yml
+++ b/.yamato/standalone-build-test.yml
@@ -8,23 +8,18 @@ test_mac_standalone_{{ editor.version }}:
name: Test Mac Standalone {{ editor.version }}
agent:
type: Unity::VM::osx
- image: ml-agents/ml-agents-bokken-mac:release
+ image: ml-agents/ml-agents-bokken-mac:v0.1.2-440635
flavor: i1.small
variables:
UNITY_VERSION: {{ editor.version }}
commands:
- ./run-standalone-build-osx.sh
triggers:
- branches:
- only:
- - "/develop-.*/"
- targets:
- only:
- - "develop"
pull_requests:
- targets:
only:
+ - "develop"
- "master"
- "/release-.*/"
- "/hotfix-.*/"
-{% endfor %}
\ No newline at end of file
+{% endfor %}
diff --git a/UnitySDK/Assets/ML-Agents/Editor/AgentEditor.cs b/UnitySDK/Assets/ML-Agents/Editor/AgentEditor.cs
index a58fc7035f..8494ed2d1d 100644
--- a/UnitySDK/Assets/ML-Agents/Editor/AgentEditor.cs
+++ b/UnitySDK/Assets/ML-Agents/Editor/AgentEditor.cs
@@ -1,6 +1,5 @@
using UnityEngine;
using UnityEditor;
-using Barracuda;
namespace MLAgents
{
diff --git a/UnitySDK/Assets/ML-Agents/Editor/BehaviorParametersEditor.cs b/UnitySDK/Assets/ML-Agents/Editor/BehaviorParametersEditor.cs
index d08e085059..d09b598942 100644
--- a/UnitySDK/Assets/ML-Agents/Editor/BehaviorParametersEditor.cs
+++ b/UnitySDK/Assets/ML-Agents/Editor/BehaviorParametersEditor.cs
@@ -1,6 +1,7 @@
using UnityEngine;
using UnityEditor;
using Barracuda;
+using MLAgents.Sensor;
namespace MLAgents
{
@@ -11,27 +12,27 @@ This code is meant to modify the behavior of the inspector on Agent Components.
[CanEditMultipleObjects]
public class BehaviorParametersEditor : Editor
{
- private const float k_TimeBetweenModelReloads = 2f;
+ const float k_TimeBetweenModelReloads = 2f;
// Time since the last reload of the model
- private float m_TimeSinceModelReload;
+ float m_TimeSinceModelReload;
// Whether or not the model needs to be reloaded
- private bool m_RequireReload;
+ bool m_RequireReload;
public override void OnInspectorGUI()
{
- var serializedObject = base.serializedObject;
- serializedObject.Update();
+ var so = serializedObject;
+ so.Update();
// Drawing the Behavior Parameters
EditorGUI.BeginChangeCheck();
EditorGUI.indentLevel++;
- EditorGUILayout.PropertyField(serializedObject.FindProperty("m_BehaviorName"));
- EditorGUILayout.PropertyField(serializedObject.FindProperty("m_BrainParameters"), true);
- EditorGUILayout.PropertyField(serializedObject.FindProperty("m_Model"), true);
+ EditorGUILayout.PropertyField(so.FindProperty("m_BehaviorName"));
+ EditorGUILayout.PropertyField(so.FindProperty("m_BrainParameters"), true);
+ EditorGUILayout.PropertyField(so.FindProperty("m_Model"), true);
EditorGUI.indentLevel++;
- EditorGUILayout.PropertyField(serializedObject.FindProperty("m_InferenceDevice"), true);
+ EditorGUILayout.PropertyField(so.FindProperty("m_InferenceDevice"), true);
EditorGUI.indentLevel--;
- EditorGUILayout.PropertyField(serializedObject.FindProperty("m_UseHeuristic"));
+ EditorGUILayout.PropertyField(so.FindProperty("m_BehaviorType"));
// EditorGUILayout.PropertyField(serializedObject.FindProperty("m_Heuristic"), true);
EditorGUI.indentLevel--;
if (EditorGUI.EndChangeCheck())
@@ -39,13 +40,13 @@ public override void OnInspectorGUI()
m_RequireReload = true;
}
DisplayFailedModelChecks();
- serializedObject.ApplyModifiedProperties();
+ so.ApplyModifiedProperties();
}
///
/// Must be called within OnEditorGUI()
///
- private void DisplayFailedModelChecks()
+ void DisplayFailedModelChecks()
{
if (m_RequireReload && m_TimeSinceModelReload > k_TimeBetweenModelReloads)
{
@@ -56,7 +57,9 @@ private void DisplayFailedModelChecks()
D.logEnabled = false;
Model barracudaModel = null;
var model = (NNModel)serializedObject.FindProperty("m_Model").objectReferenceValue;
- var brainParameters = ((BehaviorParameters)target).brainParameters;
+ var behaviorParameters = (BehaviorParameters)target;
+ var sensorComponents = behaviorParameters.GetComponents<SensorComponent>();
+ var brainParameters = behaviorParameters.brainParameters;
if (model != null)
{
barracudaModel = ModelLoader.Load(model.Value);
@@ -64,7 +67,7 @@ private void DisplayFailedModelChecks()
if (brainParameters != null)
{
var failedChecks = InferenceBrain.BarracudaModelParamLoader.CheckModel(
- barracudaModel, brainParameters);
+ barracudaModel, brainParameters, sensorComponents);
foreach (var check in failedChecks)
{
if (check != null)
diff --git a/UnitySDK/Assets/ML-Agents/Editor/BrainParametersDrawer.cs b/UnitySDK/Assets/ML-Agents/Editor/BrainParametersDrawer.cs
index 3890d69509..7f00c35b70 100644
--- a/UnitySDK/Assets/ML-Agents/Editor/BrainParametersDrawer.cs
+++ b/UnitySDK/Assets/ML-Agents/Editor/BrainParametersDrawer.cs
@@ -11,13 +11,13 @@ namespace MLAgents
public class BrainParametersDrawer : PropertyDrawer
{
// The height of a line in the Unity Inspectors
- private const float k_LineHeight = 17f;
- private const int k_VecObsNumLine = 3;
- private const string k_ActionSizePropName = "vectorActionSize";
- private const string k_ActionTypePropName = "vectorActionSpaceType";
- private const string k_ActionDescriptionPropName = "vectorActionDescriptions";
- private const string k_VecObsPropName = "vectorObservationSize";
- private const string k_NumVecObsPropName = "numStackedVectorObservations";
+ const float k_LineHeight = 17f;
+ const int k_VecObsNumLine = 3;
+ const string k_ActionSizePropName = "vectorActionSize";
+ const string k_ActionTypePropName = "vectorActionSpaceType";
+ const string k_ActionDescriptionPropName = "vectorActionDescriptions";
+ const string k_VecObsPropName = "vectorObservationSize";
+ const string k_NumVecObsPropName = "numStackedVectorObservations";
///
public override float GetPropertyHeight(SerializedProperty property, GUIContent label)
@@ -55,7 +55,7 @@ public override void OnGUI(Rect position, SerializedProperty property, GUIConten
/// Rectangle on the screen to use for the property GUI.
/// The SerializedProperty of the BrainParameters
/// to make the custom GUI for.
- private static void DrawVectorObservation(Rect position, SerializedProperty property)
+ static void DrawVectorObservation(Rect position, SerializedProperty property)
{
EditorGUI.LabelField(position, "Vector Observation");
position.y += k_LineHeight;
@@ -82,7 +82,7 @@ private static void DrawVectorObservation(Rect position, SerializedProperty prop
/// The Height required to draw the Vector Observations paramaters
///
/// The height of the drawer of the Vector Observations
- private static float GetHeightDrawVectorObservation()
+ static float GetHeightDrawVectorObservation()
{
return k_VecObsNumLine * k_LineHeight;
}
@@ -93,7 +93,7 @@ private static float GetHeightDrawVectorObservation()
/// Rectangle on the screen to use for the property GUI.
/// The SerializedProperty of the BrainParameters
/// to make the custom GUI for.
- private static void DrawVectorAction(Rect position, SerializedProperty property)
+ static void DrawVectorAction(Rect position, SerializedProperty property)
{
EditorGUI.LabelField(position, "Vector Action");
position.y += k_LineHeight;
@@ -122,7 +122,7 @@ private static void DrawVectorAction(Rect position, SerializedProperty property)
/// Rectangle on the screen to use for the property GUI.
/// The SerializedProperty of the BrainParameters
/// to make the custom GUI for.
- private static void DrawContinuousVectorAction(Rect position, SerializedProperty property)
+ static void DrawContinuousVectorAction(Rect position, SerializedProperty property)
{
var vecActionSize = property.FindPropertyRelative(k_ActionSizePropName);
vecActionSize.arraySize = 1;
@@ -140,7 +140,7 @@ private static void DrawContinuousVectorAction(Rect position, SerializedProperty
/// Rectangle on the screen to use for the property GUI.
/// The SerializedProperty of the BrainParameters
/// to make the custom GUI for.
- private static void DrawDiscreteVectorAction(Rect position, SerializedProperty property)
+ static void DrawDiscreteVectorAction(Rect position, SerializedProperty property)
{
var vecActionSize = property.FindPropertyRelative(k_ActionSizePropName);
vecActionSize.arraySize = EditorGUI.IntField(
@@ -168,7 +168,7 @@ private static void DrawDiscreteVectorAction(Rect position, SerializedProperty p
/// The Height required to draw the Vector Action parameters
///
/// The height of the drawer of the Vector Action
- private static float GetHeightDrawVectorAction(SerializedProperty property)
+ static float GetHeightDrawVectorAction(SerializedProperty property)
{
var actionSize = 2 + property.FindPropertyRelative(k_ActionSizePropName).arraySize;
if (property.FindPropertyRelative(k_ActionTypePropName).enumValueIndex == 0)
diff --git a/UnitySDK/Assets/ML-Agents/Editor/DemonstrationImporter.cs b/UnitySDK/Assets/ML-Agents/Editor/DemonstrationImporter.cs
index 1f1acab322..1492e42f93 100644
--- a/UnitySDK/Assets/ML-Agents/Editor/DemonstrationImporter.cs
+++ b/UnitySDK/Assets/ML-Agents/Editor/DemonstrationImporter.cs
@@ -13,7 +13,7 @@ namespace MLAgents
[ScriptedImporter(1, new[] {"demo"})]
public class DemonstrationImporter : ScriptedImporter
{
- private const string k_IconPath = "Assets/ML-Agents/Resources/DemoIcon.png";
+ const string k_IconPath = "Assets/ML-Agents/Resources/DemoIcon.png";
public override void OnImportAsset(AssetImportContext ctx)
{
diff --git a/UnitySDK/Assets/ML-Agents/Editor/ResetParameterDrawer.cs b/UnitySDK/Assets/ML-Agents/Editor/ResetParameterDrawer.cs
index c76657d15e..390041b058 100644
--- a/UnitySDK/Assets/ML-Agents/Editor/ResetParameterDrawer.cs
+++ b/UnitySDK/Assets/ML-Agents/Editor/ResetParameterDrawer.cs
@@ -14,11 +14,11 @@ namespace MLAgents
[CustomPropertyDrawer(typeof(ResetParameters))]
public class ResetParameterDrawer : PropertyDrawer
{
- private ResetParameters m_Parameters;
+ ResetParameters m_Parameters;
// The height of a line in the Unity Inspectors
- private const float k_LineHeight = 17f;
+ const float k_LineHeight = 17f;
// This is the prefix for the key when you add a reset parameter
- private const string k_NewKeyPrefix = "Param-";
+ const string k_NewKeyPrefix = "Param-";
///
/// Computes the height of the Drawer depending on the property it is showing
@@ -84,7 +84,7 @@ public override void OnGUI(Rect position, SerializedProperty property, GUIConten
///
/// The rectangle for the Add New button.
/// The rectangle for the Remove Last button.
- private void DrawAddRemoveButtons(Rect addRect, Rect removeRect)
+ void DrawAddRemoveButtons(Rect addRect, Rect removeRect)
{
// This is the Add button
if (m_Parameters.Count == 0)
@@ -119,7 +119,7 @@ private void DrawAddRemoveButtons(Rect addRect, Rect removeRect)
/// Signals that the property has been modified and requires the scene to be saved for
/// the changes to persist. Only works when the Editor is not playing.
///
- private static void MarkSceneAsDirty()
+ static void MarkSceneAsDirty()
{
if (!EditorApplication.isPlaying)
{
@@ -132,7 +132,7 @@ private static void MarkSceneAsDirty()
///
/// The SerializedProperty of the ResetParameters
/// to make the custom GUI for.
- private void LazyInitializeParameters(SerializedProperty property)
+ void LazyInitializeParameters(SerializedProperty property)
{
if (m_Parameters != null)
{
@@ -150,7 +150,7 @@ private void LazyInitializeParameters(SerializedProperty property)
///
/// Removes the last ResetParameter from the ResetParameters
///
- private void RemoveLastParameter()
+ void RemoveLastParameter()
{
if (m_Parameters.Count > 0)
{
@@ -162,7 +162,7 @@ private void RemoveLastParameter()
///
/// Adds a new ResetParameter to the ResetParameters with a default name.
///
- private void AddParameter()
+ void AddParameter()
{
var key = k_NewKeyPrefix + m_Parameters.Count;
var value = default(float);
diff --git a/UnitySDK/Assets/ML-Agents/Editor/Tests/DemonstrationTests.cs b/UnitySDK/Assets/ML-Agents/Editor/Tests/DemonstrationTests.cs
index 4500bd21ed..85086c7be2 100644
--- a/UnitySDK/Assets/ML-Agents/Editor/Tests/DemonstrationTests.cs
+++ b/UnitySDK/Assets/ML-Agents/Editor/Tests/DemonstrationTests.cs
@@ -1,15 +1,16 @@
-using System.Collections.Generic;
using NUnit.Framework;
using UnityEngine;
using System.IO.Abstractions.TestingHelpers;
+using System.Reflection;
+using MLAgents.CommunicatorObjects;
namespace MLAgents.Tests
{
public class DemonstrationTests : MonoBehaviour
{
- private const string k_DemoDirecory = "Assets/Demonstrations/";
- private const string k_ExtensionType = ".demo";
- private const string k_DemoName = "Test";
+ const string k_DemoDirecory = "Assets/Demonstrations/";
+ const string k_ExtensionType = ".demo";
+ const string k_DemoName = "Test";
[Test]
public void TestSanitization()
@@ -33,8 +34,8 @@ public void TestStoreInitalize()
{
vectorObservationSize = 3,
numStackedVectorObservations = 2,
- vectorActionDescriptions = new[] {"TestActionA", "TestActionB"},
- vectorActionSize = new[] {2, 2},
+ vectorActionDescriptions = new[] { "TestActionA", "TestActionB" },
+ vectorActionSize = new[] { 2, 2 },
vectorActionSpaceType = SpaceType.Discrete
};
@@ -46,19 +47,87 @@ public void TestStoreInitalize()
var agentInfo = new AgentInfo
{
reward = 1f,
- actionMasks = new[] {false, true},
+ actionMasks = new[] { false, true },
done = true,
id = 5,
maxStepReached = true,
- memories = new List<float>(),
- stackedVectorObservation = new List<float>() {1f, 1f, 1f},
- storedTextActions = "TestAction",
- storedVectorActions = new[] {0f, 1f},
- textObservation = "TestAction",
+ storedVectorActions = new[] { 0f, 1f },
};
demoStore.Record(agentInfo);
demoStore.Close();
}
+
+ public class ObservationAgent : TestAgent
+ {
+ public override void CollectObservations()
+ {
+ collectObservationsCalls += 1;
+ AddVectorObs(1f);
+ AddVectorObs(2f);
+ AddVectorObs(3f);
+ }
+ }
+
+ [Test]
+ public void TestAgentWrite()
+ {
+ var agentGo1 = new GameObject("TestAgent");
+ var bpA = agentGo1.AddComponent<BehaviorParameters>();
+ bpA.brainParameters.vectorObservationSize = 3;
+ bpA.brainParameters.numStackedVectorObservations = 1;
+ bpA.brainParameters.vectorActionDescriptions = new[] { "TestActionA", "TestActionB" };
+ bpA.brainParameters.vectorActionSize = new[] { 2, 2 };
+ bpA.brainParameters.vectorActionSpaceType = SpaceType.Discrete;
+
+ agentGo1.AddComponent<ObservationAgent>();
+ var agent1 = agentGo1.GetComponent<ObservationAgent>();
+
+ agentGo1.AddComponent<DemonstrationRecorder>();
+ var demoRecorder = agentGo1.GetComponent<DemonstrationRecorder>();
+ var fileSystem = new MockFileSystem();
+ demoRecorder.demonstrationName = "TestBrain";
+ demoRecorder.record = true;
+ demoRecorder.InitializeDemoStore(fileSystem);
+
+ var acaGo = new GameObject("TestAcademy");
+ acaGo.AddComponent<TestAcademy>();
+ var aca = acaGo.GetComponent<TestAcademy>();
+ aca.resetParameters = new ResetParameters();
+
+ var academyInitializeMethod = typeof(Academy).GetMethod("InitializeEnvironment",
+ BindingFlags.Instance | BindingFlags.NonPublic);
+ var agentEnableMethod = typeof(Agent).GetMethod("OnEnable",
+ BindingFlags.Instance | BindingFlags.NonPublic);
+ var agentSendInfo = typeof(Agent).GetMethod("SendInfo",
+ BindingFlags.Instance | BindingFlags.NonPublic);
+
+ agentEnableMethod?.Invoke(agent1, new object[] { });
+ academyInitializeMethod?.Invoke(aca, new object[] { });
+
+ // Step the agent
+ agent1.RequestDecision();
+ agentSendInfo?.Invoke(agent1, new object[] { });
+
+ demoRecorder.Close();
+
+ // Read back the demo file and make sure observations were written
+ var reader = fileSystem.File.OpenRead("Assets/Demonstrations/TestBrain.demo");
+ reader.Seek(DemonstrationStore.MetaDataBytes + 1, 0);
+ BrainParametersProto.Parser.ParseDelimitedFrom(reader);
+
+ var agentInfoProto = AgentInfoActionPairProto.Parser.ParseDelimitedFrom(reader).AgentInfo;
+ var obs = agentInfoProto.Observations[2]; // skip dummy sensors
+ {
+ var vecObs = obs.FloatData.Data;
+ Assert.AreEqual(bpA.brainParameters.vectorObservationSize, vecObs.Count);
+ for (var i = 0; i < vecObs.Count; i++)
+ {
+ Assert.AreEqual((float)i + 1, vecObs[i]);
+ }
+ }
+
+
+ }
}
}
diff --git a/UnitySDK/Assets/ML-Agents/Editor/Tests/EditModeTestInternalBrainTensorApplier.cs b/UnitySDK/Assets/ML-Agents/Editor/Tests/EditModeTestInternalBrainTensorApplier.cs
index 3b3ef2245d..9c504ac4ad 100644
--- a/UnitySDK/Assets/ML-Agents/Editor/Tests/EditModeTestInternalBrainTensorApplier.cs
+++ b/UnitySDK/Assets/ML-Agents/Editor/Tests/EditModeTestInternalBrainTensorApplier.cs
@@ -9,24 +9,24 @@ namespace MLAgents.Tests
{
public class EditModeTestInternalBrainTensorApplier
{
- private class TestAgent : Agent
+ class TestAgent : Agent
{
public AgentAction GetAction()
{
- var f = typeof(Agent).GetField(
+ var f = typeof(Agent).GetField(
"m_Action", BindingFlags.Instance | BindingFlags.NonPublic);
return (AgentAction)f.GetValue(this);
}
}
- private List<Agent> GetFakeAgentInfos()
+ List<Agent> GetFakeAgentInfos()
{
var goA = new GameObject("goA");
var agentA = goA.AddComponent<TestAgent>();
var goB = new GameObject("goB");
var agentB = goB.AddComponent<TestAgent>();
- return new List<Agent> {agentA, agentB};
+ return new List<Agent> { agentA, agentB };
}
[Test]
@@ -34,7 +34,8 @@ public void Construction()
{
var bp = new BrainParameters();
var alloc = new TensorCachingAllocator();
- var tensorGenerator = new TensorApplier(bp, 0, alloc);
+ var mem = new Dictionary<int, List<float>>();
+ var tensorGenerator = new TensorApplier(bp, 0, alloc, mem);
Assert.IsNotNull(tensorGenerator);
alloc.Dispose();
}
@@ -44,8 +45,8 @@ public void ApplyContinuousActionOutput()
{
var inputTensor = new TensorProxy()
{
- shape = new long[] {2, 3},
- data = new Tensor(2, 3, new float[] {1, 2, 3, 4, 5, 6})
+ shape = new long[] { 2, 3 },
+ data = new Tensor(2, 3, new float[] { 1, 2, 3, 4, 5, 6 })
};
var agentInfos = GetFakeAgentInfos();
@@ -73,15 +74,15 @@ public void ApplyDiscreteActionOutput()
{
var inputTensor = new TensorProxy()
{
- shape = new long[] {2, 5},
+ shape = new long[] { 2, 5 },
data = new Tensor(
2,
5,
- new[] {0.5f, 22.5f, 0.1f, 5f, 1f, 4f, 5f, 6f, 7f, 8f})
+ new[] { 0.5f, 22.5f, 0.1f, 5f, 1f, 4f, 5f, 6f, 7f, 8f })
};
var agentInfos = GetFakeAgentInfos();
var alloc = new TensorCachingAllocator();
- var applier = new DiscreteActionOutputApplier(new[] {2, 3}, 0, alloc);
+ var applier = new DiscreteActionOutputApplier(new[] { 2, 3 }, 0, alloc);
applier.Apply(inputTensor, agentInfos);
var agents = agentInfos;
@@ -99,43 +100,13 @@ public void ApplyDiscreteActionOutput()
alloc.Dispose();
}
- [Test]
- public void ApplyMemoryOutput()
- {
- var inputTensor = new TensorProxy()
- {
- shape = new long[] {2, 5},
- data = new Tensor(
- 2,
- 5,
- new[] {0.5f, 22.5f, 0.1f, 5f, 1f, 4f, 5f, 6f, 7f, 8f})
- };
- var agentInfos = GetFakeAgentInfos();
-
- var applier = new MemoryOutputApplier();
- applier.Apply(inputTensor, agentInfos);
- var agents = agentInfos;
-
- var agent = agents[0] as TestAgent;
- Assert.NotNull(agent);
- var action = agent.GetAction();
- Assert.AreEqual(action.memories[0], 0.5f);
- Assert.AreEqual(action.memories[1], 22.5f);
-
- agent = agents[1] as TestAgent;
- Assert.NotNull(agent);
- action = agent.GetAction();
- Assert.AreEqual(action.memories[2], 6);
- Assert.AreEqual(action.memories[3], 7);
- }
-
[Test]
public void ApplyValueEstimate()
{
var inputTensor = new TensorProxy()
{
- shape = new long[] {2, 1},
- data = new Tensor(2, 1, new[] {0.5f, 8f})
+ shape = new long[] { 2, 1 },
+ data = new Tensor(2, 1, new[] { 0.5f, 8f })
};
var agentInfos = GetFakeAgentInfos();
diff --git a/UnitySDK/Assets/ML-Agents/Editor/Tests/EditModeTestInternalBrainTensorGenerator.cs b/UnitySDK/Assets/ML-Agents/Editor/Tests/EditModeTestInternalBrainTensorGenerator.cs
index 0d64b92214..79bab1bd76 100644
--- a/UnitySDK/Assets/ML-Agents/Editor/Tests/EditModeTestInternalBrainTensorGenerator.cs
+++ b/UnitySDK/Assets/ML-Agents/Editor/Tests/EditModeTestInternalBrainTensorGenerator.cs
@@ -1,46 +1,68 @@
using System.Collections.Generic;
-using System.Linq;
using Barracuda;
using NUnit.Framework;
using UnityEngine;
using MLAgents.InferenceBrain;
+using System.Reflection;
+
namespace MLAgents.Tests
{
public class EditModeTestInternalBrainTensorGenerator
{
- private static IEnumerable<Agent> GetFakeAgentInfos()
+ static IEnumerable<Agent> GetFakeAgents()
{
+ var acaGo = new GameObject("TestAcademy");
+ acaGo.AddComponent<TestAcademy>();
+ var aca = acaGo.GetComponent<TestAcademy>();
+ aca.resetParameters = new ResetParameters();
+
var goA = new GameObject("goA");
+ var bpA = goA.AddComponent<BehaviorParameters>();
+ bpA.brainParameters.vectorObservationSize = 3;
+ bpA.brainParameters.numStackedVectorObservations = 1;
var agentA = goA.AddComponent<TestAgent>();
+
+ var goB = new GameObject("goB");
+ var bpB = goB.AddComponent<BehaviorParameters>();
+ bpB.brainParameters.vectorObservationSize = 3;
+ bpB.brainParameters.numStackedVectorObservations = 1;
+ var agentB = goB.AddComponent<TestAgent>();
+
+ var agents = new List<Agent> { agentA, agentB };
+ foreach (var agent in agents)
+ {
+ var agentEnableMethod = typeof(Agent).GetMethod("OnEnableHelper",
+ BindingFlags.Instance | BindingFlags.NonPublic);
+ agentEnableMethod?.Invoke(agent, new object[] { aca });
+ }
+ agentA.collectObservationsSensor.AddObservation(new Vector3(1, 2, 3));
+ agentB.collectObservationsSensor.AddObservation(new Vector3(4, 5, 6));
+
var infoA = new AgentInfo
{
- stackedVectorObservation = new[] { 1f, 2f, 3f }.ToList(),
- memories = null,
storedVectorActions = new[] { 1f, 2f },
actionMasks = null
};
- var goB = new GameObject("goB");
- var agentB = goB.AddComponent();
+
var infoB = new AgentInfo
{
- stackedVectorObservation = new[] { 4f, 5f, 6f }.ToList(),
- memories = new[] { 1f, 1f, 1f }.ToList(),
storedVectorActions = new[] { 3f, 4f },
actionMasks = new[] { true, false, false, false, false },
};
+
agentA.Info = infoA;
agentB.Info = infoB;
- return new List<Agent> { agentA, agentB };
+ return agents;
}
[Test]
public void Construction()
{
- var bp = new BrainParameters();
var alloc = new TensorCachingAllocator();
- var tensorGenerator = new TensorGenerator(bp, 0, alloc);
+ var mem = new Dictionary<int, List<float>>();
+ var tensorGenerator = new TensorGenerator(0, alloc, mem);
Assert.IsNotNull(tensorGenerator);
alloc.Dispose();
}
@@ -79,9 +101,12 @@ public void GenerateVectorObservation()
shape = new long[] { 2, 3 }
};
const int batchSize = 4;
- var agentInfos = GetFakeAgentInfos();
+ var agentInfos = GetFakeAgents();
var alloc = new TensorCachingAllocator();
var generator = new VectorObservationGenerator(alloc);
+ generator.AddSensorIndex(0);
+ generator.AddSensorIndex(1);
+ generator.AddSensorIndex(2);
generator.Generate(inputTensor, batchSize, agentInfos);
Assert.IsNotNull(inputTensor.data);
Assert.AreEqual(inputTensor.data[0, 0], 1);
@@ -91,26 +116,6 @@ public void GenerateVectorObservation()
alloc.Dispose();
}
- [Test]
- public void GenerateRecurrentInput()
- {
- var inputTensor = new TensorProxy
- {
- shape = new long[] { 2, 5 }
- };
- const int batchSize = 4;
- var agentInfos = GetFakeAgentInfos();
- var alloc = new TensorCachingAllocator();
- var generator = new RecurrentInputGenerator(alloc);
- generator.Generate(inputTensor, batchSize, agentInfos);
- Assert.IsNotNull(inputTensor.data);
- Assert.AreEqual(inputTensor.data[0, 0], 0);
- Assert.AreEqual(inputTensor.data[0, 4], 0);
- Assert.AreEqual(inputTensor.data[1, 0], 1);
- Assert.AreEqual(inputTensor.data[1, 4], 0);
- alloc.Dispose();
- }
-
[Test]
public void GeneratePreviousActionInput()
{
@@ -120,7 +125,7 @@ public void GeneratePreviousActionInput()
valueType = TensorProxy.TensorType.Integer
};
const int batchSize = 4;
- var agentInfos = GetFakeAgentInfos();
+ var agentInfos = GetFakeAgents();
var alloc = new TensorCachingAllocator();
var generator = new PreviousActionInputGenerator(alloc);
@@ -142,7 +147,7 @@ public void GenerateActionMaskInput()
valueType = TensorProxy.TensorType.FloatingPoint
};
const int batchSize = 4;
- var agentInfos = GetFakeAgentInfos();
+ var agentInfos = GetFakeAgents();
var alloc = new TensorCachingAllocator();
var generator = new ActionMaskInputGenerator(alloc);
generator.Generate(inputTensor, batchSize, agentInfos);
diff --git a/UnitySDK/Assets/ML-Agents/Editor/Tests/MLAgentsEditModeTest.cs b/UnitySDK/Assets/ML-Agents/Editor/Tests/MLAgentsEditModeTest.cs
index a5602f2a03..e64fe937cd 100644
--- a/UnitySDK/Assets/ML-Agents/Editor/Tests/MLAgentsEditModeTest.cs
+++ b/UnitySDK/Assets/ML-Agents/Editor/Tests/MLAgentsEditModeTest.cs
@@ -2,7 +2,6 @@
using NUnit.Framework;
using System.Reflection;
using MLAgents.Sensor;
-using MLAgents.InferenceBrain;
namespace MLAgents.Tests
{
@@ -36,12 +35,12 @@ public override void InitializeAgent()
{
initializeAgentCalls += 1;
- // Add in some custom sensors so we can confirm they get sorted as expected.
+ // Add in some custom Sensors so we can confirm they get sorted as expected.
var sensor1 = new TestSensor("testsensor1");
var sensor2 = new TestSensor("testsensor2");
- m_Sensors.Add(sensor2);
- m_Sensors.Add(sensor1);
+ sensors.Add(sensor2);
+ sensors.Add(sensor1);
}
public override void CollectObservations()
@@ -50,7 +49,7 @@ public override void CollectObservations()
AddVectorObs(0f);
}
- public override void AgentAction(float[] vectorAction, string textAction)
+ public override void AgentAction(float[] vectorAction)
{
agentActionCalls += 1;
AddReward(0.1f);
@@ -83,10 +82,14 @@ public TestSensor(string n)
public int[] GetFloatObservationShape()
{
- return new[] { 1 };
+ return new[] { 0 };
}
- public void WriteToTensor(TensorProxy tensorProxy, int agentIndex) { }
+ public int Write(WriteAdapter adapter)
+ {
+ // No-op
+ return 0;
+ }
public byte[] GetCompressedObservation()
{
@@ -102,6 +105,8 @@ public string GetName()
{
return sensorName;
}
+
+ public void Update() { }
}
public class EditModeTestGeneration
@@ -196,9 +201,9 @@ public void TestAgent()
Assert.AreEqual(0, agent1.agentActionCalls);
Assert.AreEqual(0, agent2.agentActionCalls);
- // Make sure the sensors were sorted
- Assert.AreEqual(agent1.m_Sensors[0].GetName(), "testsensor1");
- Assert.AreEqual(agent1.m_Sensors[1].GetName(), "testsensor2");
+ // Make sure the Sensors were sorted
+ Assert.AreEqual(agent1.sensors[0].GetName(), "testsensor1");
+ Assert.AreEqual(agent1.sensors[1].GetName(), "testsensor2");
}
}
diff --git a/UnitySDK/Assets/ML-Agents/Editor/Tests/RandomNormalTest.cs b/UnitySDK/Assets/ML-Agents/Editor/Tests/RandomNormalTest.cs
index 6c6ced2d5d..3284caca0d 100644
--- a/UnitySDK/Assets/ML-Agents/Editor/Tests/RandomNormalTest.cs
+++ b/UnitySDK/Assets/ML-Agents/Editor/Tests/RandomNormalTest.cs
@@ -6,9 +6,9 @@ namespace MLAgents.Tests
{
public class RandomNormalTest
{
- private const float k_FirstValue = -1.19580f;
- private const float k_SecondValue = -0.97345f;
- private const double k_Epsilon = 0.0001;
+ const float k_FirstValue = -1.19580f;
+ const float k_SecondValue = -0.97345f;
+ const double k_Epsilon = 0.0001;
[Test]
public void RandomNormalTestTwoDouble()
diff --git a/UnitySDK/Assets/ML-Agents/Editor/Tests/RayPerceptionTests.cs b/UnitySDK/Assets/ML-Agents/Editor/Tests/RayPerceptionTests.cs
index 75e3cc6bff..dd7ce5745f 100644
--- a/UnitySDK/Assets/ML-Agents/Editor/Tests/RayPerceptionTests.cs
+++ b/UnitySDK/Assets/ML-Agents/Editor/Tests/RayPerceptionTests.cs
@@ -14,9 +14,7 @@ public void TestPerception3D()
var go = new GameObject("MyGameObject");
var rayPer3D = go.AddComponent<RayPerception3D>();
- var result = rayPer3D.Perceive(1f, angles ,
- tags, 0f, 0f);
- Debug.Log(result.Count);
+ var result = rayPer3D.Perceive(1f, angles, tags);
Assert.IsTrue(result.Count == angles.Length * (tags.Length + 2));
}
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core.meta b/UnitySDK/Assets/ML-Agents/Editor/Tests/Sensor.meta
similarity index 77%
rename from UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core.meta
rename to UnitySDK/Assets/ML-Agents/Editor/Tests/Sensor.meta
index 42930051ad..aa087de946 100644
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core.meta
+++ b/UnitySDK/Assets/ML-Agents/Editor/Tests/Sensor.meta
@@ -1,5 +1,5 @@
fileFormatVersion: 2
-guid: 13df47c141a644f57bdb0a667879ef0b
+guid: 1b196836e6e3a4361bc62265ec88ebed
folderAsset: yes
DefaultImporter:
externalObjects: {}
diff --git a/UnitySDK/Assets/ML-Agents/Editor/Tests/Sensor/RayPerceptionSensorTests.cs b/UnitySDK/Assets/ML-Agents/Editor/Tests/Sensor/RayPerceptionSensorTests.cs
new file mode 100644
index 0000000000..8d8f689b11
--- /dev/null
+++ b/UnitySDK/Assets/ML-Agents/Editor/Tests/Sensor/RayPerceptionSensorTests.cs
@@ -0,0 +1,21 @@
+using NUnit.Framework;
+using UnityEngine;
+using MLAgents.Sensor;
+
+namespace MLAgents.Tests
+{
+ public class RayPerceptionSensorTests
+ {
+ [Test]
+ public void TestGetRayAngles()
+ {
+ var angles = RayPerceptionSensorComponentBase.GetRayAngles(3, 90f);
+ var expectedAngles = new [] { 90f, 60f, 120f, 30f, 150f, 0f, 180f };
+ Assert.AreEqual(expectedAngles.Length, angles.Length);
+ for (var i = 0; i < angles.Length; i++)
+ {
+ Assert.AreEqual(expectedAngles[i], angles[i], .01);
+ }
+ }
+ }
+}
diff --git a/UnitySDK/Assets/ML-Agents/Editor/Tests/Sensor/RayPerceptionSensorTests.cs.meta b/UnitySDK/Assets/ML-Agents/Editor/Tests/Sensor/RayPerceptionSensorTests.cs.meta
new file mode 100644
index 0000000000..ae0be2c197
--- /dev/null
+++ b/UnitySDK/Assets/ML-Agents/Editor/Tests/Sensor/RayPerceptionSensorTests.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: d2983e2bca9a40398f287727dc0472a5
+timeCreated: 1573242741
\ No newline at end of file
diff --git a/UnitySDK/Assets/ML-Agents/Editor/Tests/Sensor/StackingSensorTests.cs b/UnitySDK/Assets/ML-Agents/Editor/Tests/Sensor/StackingSensorTests.cs
new file mode 100644
index 0000000000..894b6bfece
--- /dev/null
+++ b/UnitySDK/Assets/ML-Agents/Editor/Tests/Sensor/StackingSensorTests.cs
@@ -0,0 +1,49 @@
+using NUnit.Framework;
+using UnityEngine;
+using MLAgents.Sensor;
+
+namespace MLAgents.Tests
+{
+ public class StackingSensorTests
+ {
+ [Test]
+ public void TestCtor()
+ {
+ ISensor wrapped = new VectorSensor(4);
+ ISensor sensor = new StackingSensor(wrapped, 4);
+ Assert.AreEqual("StackingSensor_size4_VectorSensor_size4", sensor.GetName());
+ Assert.AreEqual(sensor.GetFloatObservationShape(), new [] {16});
+ }
+
+ [Test]
+ public void TestStacking()
+ {
+ VectorSensor wrapped = new VectorSensor(2);
+ ISensor sensor = new StackingSensor(wrapped, 3);
+
+ wrapped.AddObservation(new [] {1f, 2f});
+ SensorTestHelper.CompareObservation(sensor, new [] {0f, 0f, 0f, 0f, 1f, 2f});
+
+ sensor.Update();
+ wrapped.AddObservation(new [] {3f, 4f});
+ SensorTestHelper.CompareObservation(sensor, new [] {0f, 0f, 1f, 2f, 3f, 4f});
+
+ sensor.Update();
+ wrapped.AddObservation(new [] {5f, 6f});
+ SensorTestHelper.CompareObservation(sensor, new [] {1f, 2f, 3f, 4f, 5f, 6f});
+
+ sensor.Update();
+ wrapped.AddObservation(new [] {7f, 8f});
+ SensorTestHelper.CompareObservation(sensor, new [] {3f, 4f, 5f, 6f, 7f, 8f});
+
+ sensor.Update();
+ wrapped.AddObservation(new [] {9f, 10f});
+ SensorTestHelper.CompareObservation(sensor, new [] {5f, 6f, 7f, 8f, 9f, 10f});
+
+ // Check that if we don't call Update(), the same observations are produced
+ SensorTestHelper.CompareObservation(sensor, new [] {5f, 6f, 7f, 8f, 9f, 10f});
+ }
+
+
+ }
+}
diff --git a/UnitySDK/Assets/ML-Agents/Editor/Tests/Sensor/StackingSensorTests.cs.meta b/UnitySDK/Assets/ML-Agents/Editor/Tests/Sensor/StackingSensorTests.cs.meta
new file mode 100644
index 0000000000..81723dd4cc
--- /dev/null
+++ b/UnitySDK/Assets/ML-Agents/Editor/Tests/Sensor/StackingSensorTests.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: 7b071fdf91474d18a05ea20175c6b3bd
+timeCreated: 1572564843
\ No newline at end of file
diff --git a/UnitySDK/Assets/ML-Agents/Editor/Tests/Sensor/VectorSensorTests.cs b/UnitySDK/Assets/ML-Agents/Editor/Tests/Sensor/VectorSensorTests.cs
new file mode 100644
index 0000000000..da9eaec0ec
--- /dev/null
+++ b/UnitySDK/Assets/ML-Agents/Editor/Tests/Sensor/VectorSensorTests.cs
@@ -0,0 +1,145 @@
+using NUnit.Framework;
+using UnityEngine;
+using MLAgents.Sensor;
+
+namespace MLAgents.Tests
+{
+ public class SensorTestHelper
+ {
+ public static void CompareObservation(ISensor sensor, float[] expected)
+ {
+ var numExpected = expected.Length;
+ const float fill = -1337f;
+ var output = new float[numExpected];
+ for (var i = 0; i < numExpected; i++)
+ {
+ output[i] = fill;
+ }
+ Assert.AreEqual(fill, output[0]);
+
+ WriteAdapter writer = new WriteAdapter();
+ writer.SetTarget(output, 0);
+
+ // Make sure WriteAdapter didn't touch anything
+ Assert.AreEqual(fill, output[0]);
+
+ sensor.Write(writer);
+ for (var i = 0; i < numExpected; i++)
+ {
+ Assert.AreEqual(expected[i], output[i]);
+ }
+ }
+ }
+
+ public class VectorSensorTests
+ {
+ [Test]
+ public void TestCtor()
+ {
+ ISensor sensor = new VectorSensor(4);
+ Assert.AreEqual("VectorSensor_size4", sensor.GetName());
+
+ sensor = new VectorSensor(3, "test_sensor");
+ Assert.AreEqual("test_sensor", sensor.GetName());
+ }
+
+ [Test]
+ public void TestWrite()
+ {
+ var sensor = new VectorSensor(4);
+ sensor.AddObservation(1f);
+ sensor.AddObservation(2f);
+ sensor.AddObservation(3f);
+ sensor.AddObservation(4f);
+
+ SensorTestHelper.CompareObservation(sensor, new[] { 1f, 2f, 3f, 4f });
+ // Check that if we don't call Update(), the same observations are produced
+ SensorTestHelper.CompareObservation(sensor, new[] { 1f, 2f, 3f, 4f });
+
+ // Check that Update() clears the data
+ sensor.Update();
+ SensorTestHelper.CompareObservation(sensor, new[] { 0f, 0f, 0f, 0f });
+
+ }
+
+ [Test]
+ public void TestAddObservationFloat()
+ {
+ var sensor = new VectorSensor(1);
+ sensor.AddObservation(1.2f);
+ SensorTestHelper.CompareObservation(sensor, new []{1.2f});
+ }
+
+ [Test]
+ public void TestAddObservationInt()
+ {
+ var sensor = new VectorSensor(1);
+ sensor.AddObservation(42);
+ SensorTestHelper.CompareObservation(sensor, new []{42f});
+ }
+
+ [Test]
+ public void TestAddObservationVec()
+ {
+ var sensor = new VectorSensor(3);
+ sensor.AddObservation(new Vector3(1,2,3));
+ SensorTestHelper.CompareObservation(sensor, new []{1f, 2f, 3f});
+
+ sensor = new VectorSensor(2);
+ sensor.AddObservation(new Vector2(4,5));
+ SensorTestHelper.CompareObservation(sensor, new[] { 4f, 5f });
+ }
+
+ [Test]
+ public void TestAddObservationQuaternion()
+ {
+ var sensor = new VectorSensor(4);
+ sensor.AddObservation(Quaternion.identity);
+ SensorTestHelper.CompareObservation(sensor, new []{0f, 0f, 0f, 1f});
+ }
+
+ [Test]
+ public void TestWriteEnumerable()
+ {
+ var sensor = new VectorSensor(4);
+ sensor.AddObservation(new [] {1f, 2f, 3f, 4f});
+
+ SensorTestHelper.CompareObservation(sensor, new[] { 1f, 2f, 3f, 4f });
+ }
+
+ [Test]
+ public void TestAddObservationBool()
+ {
+ var sensor = new VectorSensor(1);
+ sensor.AddObservation(true);
+ SensorTestHelper.CompareObservation(sensor, new []{1f});
+ }
+
+ [Test]
+ public void TestAddObservationOneHot()
+ {
+ var sensor = new VectorSensor(4);
+ sensor.AddOneHotObservation(2, 4);
+ SensorTestHelper.CompareObservation(sensor, new []{0f, 0f, 1f, 0f});
+ }
+
+ [Test]
+ public void TestWriteTooMany()
+ {
+ var sensor = new VectorSensor(2);
+ sensor.AddObservation(new [] {1f, 2f, 3f, 4f});
+
+ SensorTestHelper.CompareObservation(sensor, new[] { 1f, 2f});
+ }
+
+ [Test]
+ public void TestWriteNotEnough()
+ {
+ var sensor = new VectorSensor(4);
+ sensor.AddObservation(new [] {1f, 2f});
+
+ // Make sure extra zeros are added
+ SensorTestHelper.CompareObservation(sensor, new[] { 1f, 2f, 0f, 0f});
+ }
+ }
+}
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/Editor/BarracudaEditor/NNModelImporter.cs.meta b/UnitySDK/Assets/ML-Agents/Editor/Tests/Sensor/VectorSensorTests.cs.meta
similarity index 83%
rename from UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/Editor/BarracudaEditor/NNModelImporter.cs.meta
rename to UnitySDK/Assets/ML-Agents/Editor/Tests/Sensor/VectorSensorTests.cs.meta
index 98a74a1038..05c14f9206 100644
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/Editor/BarracudaEditor/NNModelImporter.cs.meta
+++ b/UnitySDK/Assets/ML-Agents/Editor/Tests/Sensor/VectorSensorTests.cs.meta
@@ -1,5 +1,5 @@
fileFormatVersion: 2
-guid: 19ed1486aa27d4903b34839f37b8f69f
+guid: 18c0d390ce4c5464ab48b96db0392eb0
MonoImporter:
externalObjects: {}
serializedVersion: 2
diff --git a/UnitySDK/Assets/ML-Agents/Editor/Tests/Sensor/WriterAdapterTests.cs b/UnitySDK/Assets/ML-Agents/Editor/Tests/Sensor/WriterAdapterTests.cs
new file mode 100644
index 0000000000..424198dcd7
--- /dev/null
+++ b/UnitySDK/Assets/ML-Agents/Editor/Tests/Sensor/WriterAdapterTests.cs
@@ -0,0 +1,99 @@
+using NUnit.Framework;
+using UnityEngine;
+using MLAgents.Sensor;
+
+using Barracuda;
+using MLAgents.InferenceBrain;
+
+
+namespace MLAgents.Tests
+{
+ public class WriteAdapterTests
+ {
+ [Test]
+ public void TestWritesToIList()
+ {
+ WriteAdapter writer = new WriteAdapter();
+ var buffer = new[] { 0f, 0f, 0f };
+
+ writer.SetTarget(buffer, 0);
+ // Elementwise writes
+ writer[0] = 1f;
+ writer[2] = 2f;
+ Assert.AreEqual(new[] { 1f, 0f, 2f }, buffer);
+
+ // Elementwise writes with offset
+ writer.SetTarget(buffer, 1);
+ writer[0] = 3f;
+ Assert.AreEqual(new[] { 1f, 3f, 2f }, buffer);
+
+ // AddRange
+ writer.SetTarget(buffer, 0);
+ writer.AddRange(new [] {4f, 5f});
+ Assert.AreEqual(new[] { 4f, 5f, 2f }, buffer);
+
+ // AddRange with offset
+ writer.SetTarget(buffer, 1);
+ writer.AddRange(new [] {6f, 7f});
+ Assert.AreEqual(new[] { 4f, 6f, 7f }, buffer);
+ }
+
+ [Test]
+ public void TestWritesToTensor()
+ {
+ WriteAdapter writer = new WriteAdapter();
+ var t = new TensorProxy
+ {
+ valueType = TensorProxy.TensorType.FloatingPoint,
+ data = new Tensor(2, 3)
+ };
+ writer.SetTarget(t, 0, 0);
+ Assert.AreEqual(0f, t.data[0, 0]);
+ writer[0] = 1f;
+ Assert.AreEqual(1f, t.data[0, 0]);
+
+ writer.SetTarget(t, 1, 1);
+ writer[0] = 2f;
+ writer[1] = 3f;
+ // [0, 0] shouldn't change
+ Assert.AreEqual(1f, t.data[0, 0]);
+ Assert.AreEqual(2f, t.data[1, 1]);
+ Assert.AreEqual(3f, t.data[1, 2]);
+
+ // AddRange
+ t = new TensorProxy
+ {
+ valueType = TensorProxy.TensorType.FloatingPoint,
+ data = new Tensor(2, 3)
+ };
+
+ writer.SetTarget(t, 1, 1);
+ writer.AddRange(new [] {-1f, -2f});
+ Assert.AreEqual(0f, t.data[0, 0]);
+ Assert.AreEqual(0f, t.data[0, 1]);
+ Assert.AreEqual(0f, t.data[0, 2]);
+ Assert.AreEqual(0f, t.data[1, 0]);
+ Assert.AreEqual(-1f, t.data[1, 1]);
+ Assert.AreEqual(-2f, t.data[1, 2]);
+ }
+
+ [Test]
+ public void TestWritesToTensor3D()
+ {
+ WriteAdapter writer = new WriteAdapter();
+ var t = new TensorProxy
+ {
+ valueType = TensorProxy.TensorType.FloatingPoint,
+ data = new Tensor(2, 2, 2, 3)
+ };
+
+ writer.SetTarget(t, 0, 0);
+ writer[1, 0, 1] = 1f;
+ Assert.AreEqual(1f, t.data[0, 1, 0, 1]);
+
+ writer.SetTarget(t, 0, 1);
+ writer[1, 0, 0] = 2f;
+ Assert.AreEqual(2f, t.data[0, 1, 0, 1]);
+ }
+ }
+}
diff --git a/UnitySDK/Assets/ML-Agents/Editor/Tests/Sensor/WriterAdapterTests.cs.meta b/UnitySDK/Assets/ML-Agents/Editor/Tests/Sensor/WriterAdapterTests.cs.meta
new file mode 100644
index 0000000000..31e61311e9
--- /dev/null
+++ b/UnitySDK/Assets/ML-Agents/Editor/Tests/Sensor/WriterAdapterTests.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: 3de9cbda816e4d7b907e765577dd54f7
+timeCreated: 1572568337
\ No newline at end of file
diff --git a/UnitySDK/Assets/ML-Agents/Editor/Tests/StandaloneBuildTest.cs b/UnitySDK/Assets/ML-Agents/Editor/Tests/StandaloneBuildTest.cs
index d88aacbd43..3acd01bd9a 100644
--- a/UnitySDK/Assets/ML-Agents/Editor/Tests/StandaloneBuildTest.cs
+++ b/UnitySDK/Assets/ML-Agents/Editor/Tests/StandaloneBuildTest.cs
@@ -14,7 +14,7 @@ static void BuildStandalonePlayerOSX()
string[] scenes = { "Assets/ML-Agents/Examples/3DBall/Scenes/3DBall.unity" };
var buildResult = BuildPipeline.BuildPlayer(scenes, "testPlayer", BuildTarget.StandaloneOSX, BuildOptions.None);
#if UNITY_2018_1_OR_NEWER
- var isOK = buildResult.summary.result == BuildResult.Succeeded;
+ var isOk = buildResult.summary.result == BuildResult.Succeeded;
var error = "";
foreach (var stepInfo in buildResult.steps)
{
@@ -28,9 +28,9 @@ static void BuildStandalonePlayerOSX()
}
#else
var error = buildResult;
- var isOK = string.IsNullOrEmpty(error);
+ var isOk = string.IsNullOrEmpty(error);
#endif
- if (isOK)
+ if (isOk)
{
EditorApplication.Exit(0);
}
diff --git a/UnitySDK/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DAgent.cs b/UnitySDK/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DAgent.cs
index 2ab36cdbd9..8e91858fb0 100644
--- a/UnitySDK/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DAgent.cs
+++ b/UnitySDK/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DAgent.cs
@@ -5,8 +5,8 @@ public class Ball3DAgent : Agent
{
[Header("Specific to Ball3D")]
public GameObject ball;
- private Rigidbody m_BallRb;
- private ResetParameters m_ResetParams;
+ Rigidbody m_BallRb;
+ ResetParameters m_ResetParams;
public override void InitializeAgent()
{
@@ -24,7 +24,7 @@ public override void CollectObservations()
AddVectorObs(m_BallRb.velocity);
}
- public override void AgentAction(float[] vectorAction, string textAction)
+ public override void AgentAction(float[] vectorAction)
{
var actionZ = 2f * Mathf.Clamp(vectorAction[0], -1f, 1f);
var actionX = 2f * Mathf.Clamp(vectorAction[1], -1f, 1f);
diff --git a/UnitySDK/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DHardAgent.cs b/UnitySDK/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DHardAgent.cs
index 97be575f78..972d6b11e0 100644
--- a/UnitySDK/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DHardAgent.cs
+++ b/UnitySDK/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DHardAgent.cs
@@ -5,8 +5,8 @@ public class Ball3DHardAgent : Agent
{
[Header("Specific to Ball3DHard")]
public GameObject ball;
- private Rigidbody m_BallRb;
- private ResetParameters m_ResetParams;
+ Rigidbody m_BallRb;
+ ResetParameters m_ResetParams;
public override void InitializeAgent()
{
@@ -23,7 +23,7 @@ public override void CollectObservations()
AddVectorObs((ball.transform.position - gameObject.transform.position));
}
- public override void AgentAction(float[] vectorAction, string textAction)
+ public override void AgentAction(float[] vectorAction)
{
var actionZ = 2f * Mathf.Clamp(vectorAction[0], -1f, 1f);
var actionX = 2f * Mathf.Clamp(vectorAction[1], -1f, 1f);
diff --git a/UnitySDK/Assets/ML-Agents/Examples/3DBall/TFModels/3DBall.nn b/UnitySDK/Assets/ML-Agents/Examples/3DBall/TFModels/3DBall.nn
index 22af655d84..a98273d84c 100644
Binary files a/UnitySDK/Assets/ML-Agents/Examples/3DBall/TFModels/3DBall.nn and b/UnitySDK/Assets/ML-Agents/Examples/3DBall/TFModels/3DBall.nn differ
diff --git a/UnitySDK/Assets/ML-Agents/Examples/3DBall/TFModels/3DBallHard.nn b/UnitySDK/Assets/ML-Agents/Examples/3DBall/TFModels/3DBallHard.nn
index 66cfc5120b..dd32c10718 100644
Binary files a/UnitySDK/Assets/ML-Agents/Examples/3DBall/TFModels/3DBallHard.nn and b/UnitySDK/Assets/ML-Agents/Examples/3DBall/TFModels/3DBallHard.nn differ
diff --git a/UnitySDK/Assets/ML-Agents/Examples/Basic/Prefabs/Basic.prefab b/UnitySDK/Assets/ML-Agents/Examples/Basic/Prefabs/Basic.prefab
index f9d4ab6e74..10bea835b0 100644
--- a/UnitySDK/Assets/ML-Agents/Examples/Basic/Prefabs/Basic.prefab
+++ b/UnitySDK/Assets/ML-Agents/Examples/Basic/Prefabs/Basic.prefab
@@ -912,12 +912,11 @@ MonoBehaviour:
vectorObservationSize: 20
numStackedVectorObservations: 1
vectorActionSize: 03000000
- cameraResolutions: []
vectorActionDescriptions: []
vectorActionSpaceType: 0
- m_Model: {fileID: 11400000, guid: 53fa7c392ce3c492281be273668f6aaf, type: 3}
+ m_Model: {fileID: 11400000, guid: 468c183196f1844f69e125c99dd135a1, type: 3}
m_InferenceDevice: 0
- m_UseHeuristic: 0
+ m_BehaviorType: 0
m_BehaviorName: Basic
--- !u!114 &114827551040495112
MonoBehaviour:
@@ -931,8 +930,6 @@ MonoBehaviour:
m_Name:
m_EditorClassIdentifier:
agentParameters:
- agentCameras: []
- agentRenderTextures: []
maxStep: 0
resetOnDone: 1
onDemandDecision: 1
diff --git a/UnitySDK/Assets/ML-Agents/Examples/Basic/Scripts/BasicAgent.cs b/UnitySDK/Assets/ML-Agents/Examples/Basic/Scripts/BasicAgent.cs
index bc7ed3f1d0..71ecb3ff9d 100644
--- a/UnitySDK/Assets/ML-Agents/Examples/Basic/Scripts/BasicAgent.cs
+++ b/UnitySDK/Assets/ML-Agents/Examples/Basic/Scripts/BasicAgent.cs
@@ -4,9 +4,9 @@
public class BasicAgent : Agent
{
[Header("Specific to Basic")]
- private BasicAcademy m_Academy;
+ BasicAcademy m_Academy;
public float timeBetweenDecisionsAtInference;
- private float m_TimeSinceDecision;
+ float m_TimeSinceDecision;
int m_Position;
int m_SmallGoalPosition;
int m_LargeGoalPosition;
@@ -25,7 +25,7 @@ public override void CollectObservations()
AddVectorObs(m_Position, 20);
}
- public override void AgentAction(float[] vectorAction, string textAction)
+ public override void AgentAction(float[] vectorAction)
{
var movement = (int)vectorAction[0];
@@ -95,7 +95,7 @@ public void FixedUpdate()
WaitTimeInference();
}
- private void WaitTimeInference()
+ void WaitTimeInference()
{
if (!m_Academy.GetIsInference())
{
diff --git a/UnitySDK/Assets/ML-Agents/Examples/Basic/TFModels/Basic.nn b/UnitySDK/Assets/ML-Agents/Examples/Basic/TFModels/Basic.nn
index d30b05bfa0..992e053099 100644
Binary files a/UnitySDK/Assets/ML-Agents/Examples/Basic/TFModels/Basic.nn and b/UnitySDK/Assets/ML-Agents/Examples/Basic/TFModels/Basic.nn differ
diff --git a/UnitySDK/Assets/ML-Agents/Examples/Bouncer/Prefabs/Environment.prefab b/UnitySDK/Assets/ML-Agents/Examples/Bouncer/Prefabs/Environment.prefab
index dc7003a97a..dac829b2cb 100644
--- a/UnitySDK/Assets/ML-Agents/Examples/Bouncer/Prefabs/Environment.prefab
+++ b/UnitySDK/Assets/ML-Agents/Examples/Bouncer/Prefabs/Environment.prefab
@@ -828,8 +828,6 @@ MonoBehaviour:
m_Name:
m_EditorClassIdentifier:
agentParameters:
- agentCameras: []
- agentRenderTextures: []
maxStep: 0
resetOnDone: 1
onDemandDecision: 1
@@ -852,10 +850,9 @@ MonoBehaviour:
vectorObservationSize: 6
numStackedVectorObservations: 3
vectorActionSize: 03000000
- cameraResolutions: []
vectorActionDescriptions: []
vectorActionSpaceType: 1
- m_Model: {fileID: 11400000, guid: f5250a39cb2134db49b833e3c92527a1, type: 3}
+ m_Model: {fileID: 11400000, guid: 6c4ee6ab37d9b49b492a5cc49ed47ca0, type: 3}
m_InferenceDevice: 0
- m_UseHeuristic: 0
+ m_BehaviorType: 0
m_BehaviorName: Bouncer
diff --git a/UnitySDK/Assets/ML-Agents/Examples/Bouncer/Scripts/BouncerAgent.cs b/UnitySDK/Assets/ML-Agents/Examples/Bouncer/Scripts/BouncerAgent.cs
index e72bca3621..f885236d78 100644
--- a/UnitySDK/Assets/ML-Agents/Examples/Bouncer/Scripts/BouncerAgent.cs
+++ b/UnitySDK/Assets/ML-Agents/Examples/Bouncer/Scripts/BouncerAgent.cs
@@ -32,7 +32,7 @@ public override void CollectObservations()
AddVectorObs(target.transform.localPosition);
}
- public override void AgentAction(float[] vectorAction, string textAction)
+ public override void AgentAction(float[] vectorAction)
{
for (var i = 0; i < vectorAction.Length; i++)
{
@@ -72,7 +72,7 @@ public override void AgentOnDone()
{
}
- private void FixedUpdate()
+ void FixedUpdate()
{
if (Physics.Raycast(transform.position, new Vector3(0f, -1f, 0f), 0.51f) && m_JumpCooldown <= 0f)
{
@@ -114,7 +114,7 @@ public override float[] Heuristic()
return action;
}
- private void Update()
+ void Update()
{
if (m_LookDir.magnitude > float.Epsilon)
{
diff --git a/UnitySDK/Assets/ML-Agents/Examples/Bouncer/Scripts/BouncerTarget.cs b/UnitySDK/Assets/ML-Agents/Examples/Bouncer/Scripts/BouncerTarget.cs
index 9432e812c0..84404b99f1 100644
--- a/UnitySDK/Assets/ML-Agents/Examples/Bouncer/Scripts/BouncerTarget.cs
+++ b/UnitySDK/Assets/ML-Agents/Examples/Bouncer/Scripts/BouncerTarget.cs
@@ -9,7 +9,7 @@ void FixedUpdate()
gameObject.transform.Rotate(new Vector3(1, 0, 0), 0.5f);
}
- private void OnTriggerEnter(Collider collision)
+ void OnTriggerEnter(Collider collision)
{
var agent = collision.gameObject.GetComponent<BouncerAgent>();
if (agent != null)
diff --git a/UnitySDK/Assets/ML-Agents/Examples/Bouncer/TFModels/Bouncer.nn b/UnitySDK/Assets/ML-Agents/Examples/Bouncer/TFModels/Bouncer.nn
index 4d65955bb1..895b624ac1 100644
Binary files a/UnitySDK/Assets/ML-Agents/Examples/Bouncer/TFModels/Bouncer.nn and b/UnitySDK/Assets/ML-Agents/Examples/Bouncer/TFModels/Bouncer.nn differ
diff --git a/UnitySDK/Assets/ML-Agents/Examples/Crawler/Prefabs/DynamicPlatform.prefab b/UnitySDK/Assets/ML-Agents/Examples/Crawler/Prefabs/DynamicPlatform.prefab
index 7f34ab30e9..a08e6d6afd 100644
--- a/UnitySDK/Assets/ML-Agents/Examples/Crawler/Prefabs/DynamicPlatform.prefab
+++ b/UnitySDK/Assets/ML-Agents/Examples/Crawler/Prefabs/DynamicPlatform.prefab
@@ -2102,12 +2102,11 @@ MonoBehaviour:
vectorObservationSize: 126
numStackedVectorObservations: 1
vectorActionSize: 14000000
- cameraResolutions: []
vectorActionDescriptions: []
vectorActionSpaceType: 1
- m_Model: {fileID: 11400000, guid: abc9c8f2180154ed7ba3f116ab0beb90, type: 3}
+ m_Model: {fileID: 11400000, guid: 039557e683d584183a2a82cf8b1904c0, type: 3}
m_InferenceDevice: 0
- m_UseHeuristic: 0
+ m_BehaviorType: 0
m_BehaviorName: CrawlerDynamic
--- !u!114 &114157055237627828
MonoBehaviour:
@@ -2216,8 +2215,6 @@ MonoBehaviour:
m_Name:
m_EditorClassIdentifier:
agentParameters:
- agentCameras: []
- agentRenderTextures: []
maxStep: 5000
resetOnDone: 1
onDemandDecision: 0
diff --git a/UnitySDK/Assets/ML-Agents/Examples/Crawler/Prefabs/FixedPlatform.prefab b/UnitySDK/Assets/ML-Agents/Examples/Crawler/Prefabs/FixedPlatform.prefab
index 0f97e6d3a8..654094ff4b 100644
--- a/UnitySDK/Assets/ML-Agents/Examples/Crawler/Prefabs/FixedPlatform.prefab
+++ b/UnitySDK/Assets/ML-Agents/Examples/Crawler/Prefabs/FixedPlatform.prefab
@@ -1812,8 +1812,6 @@ MonoBehaviour:
m_Name:
m_EditorClassIdentifier:
agentParameters:
- agentCameras: []
- agentRenderTextures: []
maxStep: 5000
resetOnDone: 1
onDemandDecision: 0
@@ -1953,12 +1951,11 @@ MonoBehaviour:
vectorObservationSize: 126
numStackedVectorObservations: 1
vectorActionSize: 14000000
- cameraResolutions: []
vectorActionDescriptions: []
vectorActionSpaceType: 1
- m_Model: {fileID: 11400000, guid: 48982d8fa360a4ed0bb265495e4f378b, type: 3}
+ m_Model: {fileID: 11400000, guid: ac4a23ff4713140198629ae0844926ee, type: 3}
m_InferenceDevice: 0
- m_UseHeuristic: 0
+ m_BehaviorType: 0
m_BehaviorName: CrawlerStatic
--- !u!114 &114954029223843696
MonoBehaviour:
diff --git a/UnitySDK/Assets/ML-Agents/Examples/Crawler/Scripts/CrawlerAgent.cs b/UnitySDK/Assets/ML-Agents/Examples/Crawler/Scripts/CrawlerAgent.cs
index 2983ed9e83..9046383086 100644
--- a/UnitySDK/Assets/ML-Agents/Examples/Crawler/Scripts/CrawlerAgent.cs
+++ b/UnitySDK/Assets/ML-Agents/Examples/Crawler/Scripts/CrawlerAgent.cs
@@ -163,7 +163,7 @@ public void GetRandomTargetPos()
target.position = newTargetPos + ground.position;
}
- public override void AgentAction(float[] vectorAction, string textAction)
+ public override void AgentAction(float[] vectorAction)
{
if (detectTargets)
{
diff --git a/UnitySDK/Assets/ML-Agents/Examples/Crawler/TFModels/CrawlerDynamic.nn b/UnitySDK/Assets/ML-Agents/Examples/Crawler/TFModels/CrawlerDynamic.nn
index fc8e03a99c..3f011db176 100644
Binary files a/UnitySDK/Assets/ML-Agents/Examples/Crawler/TFModels/CrawlerDynamic.nn and b/UnitySDK/Assets/ML-Agents/Examples/Crawler/TFModels/CrawlerDynamic.nn differ
diff --git a/UnitySDK/Assets/ML-Agents/Examples/Crawler/TFModels/CrawlerStatic.nn b/UnitySDK/Assets/ML-Agents/Examples/Crawler/TFModels/CrawlerStatic.nn
index 77cf29377c..d77ff7c947 100644
Binary files a/UnitySDK/Assets/ML-Agents/Examples/Crawler/TFModels/CrawlerStatic.nn and b/UnitySDK/Assets/ML-Agents/Examples/Crawler/TFModels/CrawlerStatic.nn differ
diff --git a/UnitySDK/Assets/ML-Agents/Examples/FoodCollector/Prefabs/FoodCollectorArea.prefab b/UnitySDK/Assets/ML-Agents/Examples/FoodCollector/Prefabs/FoodCollectorArea.prefab
index df1d5ff597..48980bc0a0 100644
--- a/UnitySDK/Assets/ML-Agents/Examples/FoodCollector/Prefabs/FoodCollectorArea.prefab
+++ b/UnitySDK/Assets/ML-Agents/Examples/FoodCollector/Prefabs/FoodCollectorArea.prefab
@@ -495,8 +495,8 @@ GameObject:
- component: {fileID: 65550728419070768}
- component: {fileID: 54936164982484646}
- component: {fileID: 114374774605792098}
- - component: {fileID: 114762047763154270}
- component: {fileID: 114176228333253036}
+ - component: {fileID: 114725457980523372}
m_Layer: 0
m_Name: Agent
m_TagString: agent
@@ -549,8 +549,8 @@ GameObject:
- component: {fileID: 65905012397919158}
- component: {fileID: 54504078365531932}
- component: {fileID: 114522573150607728}
- - component: {fileID: 114416645532260476}
- component: {fileID: 114711827726849508}
+ - component: {fileID: 114443152683847924}
m_Layer: 0
m_Name: Agent (1)
m_TagString: agent
@@ -604,8 +604,8 @@ GameObject:
- component: {fileID: 65152194455140476}
- component: {fileID: 54961653455021136}
- component: {fileID: 114980787530065684}
- - component: {fileID: 114192565006091356}
- component: {fileID: 114542632553128056}
+ - component: {fileID: 114986980423924774}
m_Layer: 0
m_Name: Agent (2)
m_TagString: agent
@@ -725,8 +725,8 @@ GameObject:
- component: {fileID: 65761952312736034}
- component: {fileID: 54819001862035794}
- component: {fileID: 114878550018296316}
- - component: {fileID: 114661830999747712}
- component: {fileID: 114189751434580810}
+ - component: {fileID: 114644889237473510}
m_Layer: 0
m_Name: Agent (4)
m_TagString: agent
@@ -779,8 +779,8 @@ GameObject:
- component: {fileID: 65367560123033576}
- component: {fileID: 54895479068989492}
- component: {fileID: 114035338027591536}
- - component: {fileID: 114821937036444478}
- component: {fileID: 114235147148547996}
+ - component: {fileID: 114276061479012222}
m_Layer: 0
m_Name: Agent (3)
m_TagString: agent
@@ -3683,7 +3683,7 @@ MonoBehaviour:
m_Name:
m_EditorClassIdentifier:
m_BrainParameters:
- vectorObservationSize: 53
+ vectorObservationSize: 4
numStackedVectorObservations: 1
vectorActionSize: 03000000030000000300000002000000
vectorActionDescriptions: []
@@ -3761,17 +3761,6 @@ MonoBehaviour:
myLaser: {fileID: 1617924810425504}
contribute: 0
useVectorObs: 1
---- !u!114 &114192565006091356
-MonoBehaviour:
- m_ObjectHideFlags: 1
- m_PrefabParentObject: {fileID: 0}
- m_PrefabInternal: {fileID: 100100000}
- m_GameObject: {fileID: 1601500200010266}
- m_Enabled: 1
- m_EditorHideFlags: 0
- m_Script: {fileID: 11500000, guid: bb172294dbbcc408286b156a2c4b553c, type: 3}
- m_Name:
- m_EditorClassIdentifier:
--- !u!114 &114235147148547996
MonoBehaviour:
m_ObjectHideFlags: 1
@@ -3798,6 +3787,34 @@ MonoBehaviour:
myLaser: {fileID: 1045923826166930}
contribute: 0
useVectorObs: 1
+--- !u!114 &114276061479012222
+MonoBehaviour:
+ m_ObjectHideFlags: 1
+ m_PrefabParentObject: {fileID: 0}
+ m_PrefabInternal: {fileID: 100100000}
+ m_GameObject: {fileID: 1706274796045088}
+ m_Enabled: 1
+ m_EditorHideFlags: 0
+ m_Script: {fileID: 11500000, guid: 6bb6b867a41448888c1cd4f99643ad71, type: 3}
+ m_Name:
+ m_EditorClassIdentifier:
+ sensorName: RayPerceptionSensor
+ detectableTags:
+ - food
+ - agent
+ - wall
+ - badFood
+ - frozenAgent
+ raysPerDirection: 3
+ maxRayDegrees: 70
+ sphereCastRadius: 0.5
+ rayLength: 50
+ observationStacks: 1
+ rayHitColor: {r: 1, g: 0, b: 0, a: 1}
+ rayMissColor: {r: 1, g: 1, b: 1, a: 1}
+ useWorldPositions: 1
+ startVerticalOffset: 0
+ endVerticalOffset: 0
--- !u!114 &114374774605792098
MonoBehaviour:
m_ObjectHideFlags: 1
@@ -3810,7 +3827,7 @@ MonoBehaviour:
m_Name:
m_EditorClassIdentifier:
m_BrainParameters:
- vectorObservationSize: 53
+ vectorObservationSize: 4
numStackedVectorObservations: 1
vectorActionSize: 03000000030000000300000002000000
vectorActionDescriptions: []
@@ -3819,7 +3836,7 @@ MonoBehaviour:
m_InferenceDevice: 0
m_UseHeuristic: 0
m_BehaviorName: FoodCollector
---- !u!114 &114416645532260476
+--- !u!114 &114443152683847924
MonoBehaviour:
m_ObjectHideFlags: 1
m_PrefabParentObject: {fileID: 0}
@@ -3827,9 +3844,26 @@ MonoBehaviour:
m_GameObject: {fileID: 1495617568563208}
m_Enabled: 1
m_EditorHideFlags: 0
- m_Script: {fileID: 11500000, guid: bb172294dbbcc408286b156a2c4b553c, type: 3}
+ m_Script: {fileID: 11500000, guid: 6bb6b867a41448888c1cd4f99643ad71, type: 3}
m_Name:
m_EditorClassIdentifier:
+ sensorName: RayPerceptionSensor
+ detectableTags:
+ - food
+ - agent
+ - wall
+ - badFood
+ - frozenAgent
+ raysPerDirection: 3
+ maxRayDegrees: 70
+ sphereCastRadius: 0.5
+ rayLength: 50
+ observationStacks: 1
+ rayHitColor: {r: 1, g: 0, b: 0, a: 1}
+ rayMissColor: {r: 1, g: 1, b: 1, a: 1}
+ useWorldPositions: 1
+ startVerticalOffset: 0
+ endVerticalOffset: 0
--- !u!114 &114522573150607728
MonoBehaviour:
m_ObjectHideFlags: 1
@@ -3842,7 +3876,7 @@ MonoBehaviour:
m_Name:
m_EditorClassIdentifier:
m_BrainParameters:
- vectorObservationSize: 53
+ vectorObservationSize: 4
numStackedVectorObservations: 1
vectorActionSize: 03000000030000000300000002000000
vectorActionDescriptions: []
@@ -3877,7 +3911,7 @@ MonoBehaviour:
myLaser: {fileID: 1421240237750412}
contribute: 0
useVectorObs: 1
---- !u!114 &114661830999747712
+--- !u!114 &114644889237473510
MonoBehaviour:
m_ObjectHideFlags: 1
m_PrefabParentObject: {fileID: 0}
@@ -3885,9 +3919,26 @@ MonoBehaviour:
m_GameObject: {fileID: 1672905243433088}
m_Enabled: 1
m_EditorHideFlags: 0
- m_Script: {fileID: 11500000, guid: bb172294dbbcc408286b156a2c4b553c, type: 3}
+ m_Script: {fileID: 11500000, guid: 6bb6b867a41448888c1cd4f99643ad71, type: 3}
m_Name:
m_EditorClassIdentifier:
+ sensorName: RayPerceptionSensor
+ detectableTags:
+ - food
+ - agent
+ - wall
+ - badFood
+ - frozenAgent
+ raysPerDirection: 3
+ maxRayDegrees: 70
+ sphereCastRadius: 0.5
+ rayLength: 50
+ observationStacks: 1
+ rayHitColor: {r: 1, g: 0, b: 0, a: 1}
+ rayMissColor: {r: 1, g: 1, b: 1, a: 1}
+ useWorldPositions: 1
+ startVerticalOffset: 0
+ endVerticalOffset: 0
--- !u!114 &114711827726849508
MonoBehaviour:
m_ObjectHideFlags: 1
@@ -3914,7 +3965,7 @@ MonoBehaviour:
myLaser: {fileID: 1941433838307300}
contribute: 0
useVectorObs: 1
---- !u!114 &114762047763154270
+--- !u!114 &114725457980523372
MonoBehaviour:
m_ObjectHideFlags: 1
m_PrefabParentObject: {fileID: 0}
@@ -3922,20 +3973,26 @@ MonoBehaviour:
m_GameObject: {fileID: 1464820575638702}
m_Enabled: 1
m_EditorHideFlags: 0
- m_Script: {fileID: 11500000, guid: bb172294dbbcc408286b156a2c4b553c, type: 3}
- m_Name:
- m_EditorClassIdentifier:
---- !u!114 &114821937036444478
-MonoBehaviour:
- m_ObjectHideFlags: 1
- m_PrefabParentObject: {fileID: 0}
- m_PrefabInternal: {fileID: 100100000}
- m_GameObject: {fileID: 1706274796045088}
- m_Enabled: 1
- m_EditorHideFlags: 0
- m_Script: {fileID: 11500000, guid: bb172294dbbcc408286b156a2c4b553c, type: 3}
+ m_Script: {fileID: 11500000, guid: 6bb6b867a41448888c1cd4f99643ad71, type: 3}
m_Name:
m_EditorClassIdentifier:
+ sensorName: RayPerceptionSensor
+ detectableTags:
+ - food
+ - agent
+ - wall
+ - badFood
+ - frozenAgent
+ raysPerDirection: 3
+ maxRayDegrees: 70
+ sphereCastRadius: 0.5
+ rayLength: 50
+ observationStacks: 1
+ rayHitColor: {r: 1, g: 0, b: 0, a: 1}
+ rayMissColor: {r: 1, g: 1, b: 1, a: 1}
+ useWorldPositions: 1
+ startVerticalOffset: 0
+ endVerticalOffset: 0
--- !u!114 &114878550018296316
MonoBehaviour:
m_ObjectHideFlags: 1
@@ -3948,7 +4005,7 @@ MonoBehaviour:
m_Name:
m_EditorClassIdentifier:
m_BrainParameters:
- vectorObservationSize: 53
+ vectorObservationSize: 4
numStackedVectorObservations: 1
vectorActionSize: 03000000030000000300000002000000
vectorActionDescriptions: []
@@ -3969,7 +4026,7 @@ MonoBehaviour:
m_Name:
m_EditorClassIdentifier:
m_BrainParameters:
- vectorObservationSize: 53
+ vectorObservationSize: 4
numStackedVectorObservations: 1
vectorActionSize: 03000000030000000300000002000000
vectorActionDescriptions: []
@@ -3978,3 +4035,31 @@ MonoBehaviour:
m_InferenceDevice: 0
m_UseHeuristic: 0
m_BehaviorName: FoodCollector
+--- !u!114 &114986980423924774
+MonoBehaviour:
+ m_ObjectHideFlags: 1
+ m_PrefabParentObject: {fileID: 0}
+ m_PrefabInternal: {fileID: 100100000}
+ m_GameObject: {fileID: 1601500200010266}
+ m_Enabled: 1
+ m_EditorHideFlags: 0
+ m_Script: {fileID: 11500000, guid: 6bb6b867a41448888c1cd4f99643ad71, type: 3}
+ m_Name:
+ m_EditorClassIdentifier:
+ sensorName: RayPerceptionSensor
+ detectableTags:
+ - food
+ - agent
+ - wall
+ - badFood
+ - frozenAgent
+ raysPerDirection: 3
+ maxRayDegrees: 70
+ sphereCastRadius: 0.5
+ rayLength: 50
+ observationStacks: 1
+ rayHitColor: {r: 1, g: 0, b: 0, a: 1}
+ rayMissColor: {r: 1, g: 1, b: 1, a: 1}
+ useWorldPositions: 1
+ startVerticalOffset: 0
+ endVerticalOffset: 0
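The prefab changes above replace the old ray-perception MonoBehaviour on each agent with the new RayPerceptionSensor component and shrink the brain's vectorObservationSize from 53 to 4. A rough sanity check of those numbers, assuming the (tags + 2) slots-per-ray layout used by the refactored RayPerception classes later in this patch; this is illustrative arithmetic read off the diff, not an API.

// Illustrative arithmetic only; all values are read off this diff.
static class FoodCollectorObsCheck
{
    public static void Check()
    {
        int rayAngles = 7;           // { 20, 90, 160, 45, 135, 70, 110 }, removed from FoodCollectorAgent.cs below
        int detectableTags = 5;      // food, agent, wall, badFood, frozenAgent
        int slotsPerRay = detectableTags + 2;         // one-hot tag hits + "missed" flag + normalized distance
        int removedRayObs = rayAngles * slotsPerRay;  // 49 values, now produced by the sensor component instead
        System.Console.WriteLine(53 - removedRayObs); // 4 -> the new vectorObservationSize in this prefab
    }
}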
diff --git a/UnitySDK/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorAgent.cs b/UnitySDK/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorAgent.cs
index c00e27adb0..9cab47ce5e 100644
--- a/UnitySDK/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorAgent.cs
+++ b/UnitySDK/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorAgent.cs
@@ -3,7 +3,7 @@
public class FoodCollectorAgent : Agent
{
- private FoodCollectorAcademy m_MyAcademy;
+ FoodCollectorAcademy m_MyAcademy;
public GameObject area;
FoodCollectorArea m_MyArea;
bool m_Frozen;
@@ -13,7 +13,7 @@ public class FoodCollectorAgent : Agent
float m_FrozenTime;
float m_EffectTime;
Rigidbody m_AgentRb;
- private float m_LaserLength;
+ float m_LaserLength;
// Speed of agent rotation.
public float turnSpeed = 300;
@@ -25,7 +25,6 @@ public class FoodCollectorAgent : Agent
public Material frozenMaterial;
public GameObject myLaser;
public bool contribute;
- private RayPerception3D m_RayPer;
public bool useVectorObs;
@@ -35,7 +34,6 @@ public override void InitializeAgent()
m_AgentRb = GetComponent<Rigidbody>();
Monitor.verticalOffset = 1f;
m_MyArea = area.GetComponent<FoodCollectorArea>();
- m_RayPer = GetComponent<RayPerception3D>();
m_MyAcademy = FindObjectOfType<FoodCollectorAcademy>();
SetResetParameters();
@@ -45,10 +43,6 @@ public override void CollectObservations()
{
if (useVectorObs)
{
- const float rayDistance = 50f;
- float[] rayAngles = { 20f, 90f, 160f, 45f, 135f, 70f, 110f };
- string[] detectableObjects = { "food", "agent", "wall", "badFood", "frozenAgent" };
- AddVectorObs(m_RayPer.Perceive(rayDistance, rayAngles, detectableObjects, 0f, 0f));
var localVelocity = transform.InverseTransformDirection(m_AgentRb.velocity);
AddVectorObs(localVelocity.x);
AddVectorObs(localVelocity.z);
@@ -150,10 +144,10 @@ public void MoveAgent(float[] act)
{
var myTransform = transform;
myLaser.transform.localScale = new Vector3(1f, 1f, m_LaserLength);
- var position = myTransform.TransformDirection(RayPerception3D.PolarToCartesian(25f, 90f));
- Debug.DrawRay(myTransform.position, position, Color.red, 0f, true);
+ var rayDir = 25.0f * myTransform.forward;
+ Debug.DrawRay(myTransform.position, rayDir, Color.red, 0f, true);
RaycastHit hit;
- if (Physics.SphereCast(transform.position, 2f, position, out hit, 25f))
+ if (Physics.SphereCast(transform.position, 2f, rayDir, out hit, 25f))
{
if (hit.collider.gameObject.CompareTag("agent"))
{
@@ -208,7 +202,7 @@ void Unsatiate()
gameObject.GetComponentInChildren<Renderer>().material = normalMaterial;
}
- public override void AgentAction(float[] vectorAction, string textAction)
+ public override void AgentAction(float[] vectorAction)
{
MoveAgent(vectorAction);
}
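The MoveAgent hunk above replaces TransformDirection(RayPerception3D.PolarToCartesian(25f, 90f)) with 25.0f * myTransform.forward for the laser's sphere cast. The two are the same direction: at 90 degrees the polar helper returns local (0, 0, 25), and transforming that local vector is exactly forward * 25. A small sketch that checks this, reproducing the PolarToCartesian helper this patch deletes from RayPerception3D.cs; the class and method names here are illustrative only.

using UnityEngine;

public static class LaserDirectionCheck
{
    // Reproduction of the PolarToCartesian helper removed from RayPerception3D.cs.
    static Vector3 PolarToCartesian(float radius, float angle)
    {
        var x = radius * Mathf.Cos(angle * Mathf.Deg2Rad);
        var z = radius * Mathf.Sin(angle * Mathf.Deg2Rad);
        return new Vector3(x, 0f, z);
    }

    public static void Check(Transform t)
    {
        // Old code: rotate the local-space polar vector into world space; 90 degrees is local +Z.
        var oldDir = t.TransformDirection(PolarToCartesian(25f, 90f));
        // New code: scale the transform's forward axis directly.
        var newDir = 25.0f * t.forward;
        // At 90 degrees the two directions agree up to floating-point error.
        Debug.Assert((oldDir - newDir).sqrMagnitude < 1e-4f);
    }
}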
diff --git a/UnitySDK/Assets/ML-Agents/Examples/FoodCollector/TFModels/FoodCollector.nn b/UnitySDK/Assets/ML-Agents/Examples/FoodCollector/TFModels/FoodCollector.nn
index 6876fd633f..0463686575 100644
Binary files a/UnitySDK/Assets/ML-Agents/Examples/FoodCollector/TFModels/FoodCollector.nn and b/UnitySDK/Assets/ML-Agents/Examples/FoodCollector/TFModels/FoodCollector.nn differ
diff --git a/UnitySDK/Assets/ML-Agents/Examples/GridWorld/Prefabs/Area.prefab b/UnitySDK/Assets/ML-Agents/Examples/GridWorld/Prefabs/Area.prefab
index 7d6c5073a9..e91aade640 100644
--- a/UnitySDK/Assets/ML-Agents/Examples/GridWorld/Prefabs/Area.prefab
+++ b/UnitySDK/Assets/ML-Agents/Examples/GridWorld/Prefabs/Area.prefab
@@ -1005,7 +1005,7 @@ MonoBehaviour:
camera: {fileID: 20743940359151984}
sensorName: CameraSensor
width: 84
- height: 84
+ height: 64
grayscale: 0
--- !u!114 &114935253044749092
MonoBehaviour:
diff --git a/UnitySDK/Assets/ML-Agents/Examples/GridWorld/Prefabs/agentRenderTexture.renderTexture b/UnitySDK/Assets/ML-Agents/Examples/GridWorld/Prefabs/agentRenderTexture.renderTexture
index 57e68f3f68..3c2366415f 100644
--- a/UnitySDK/Assets/ML-Agents/Examples/GridWorld/Prefabs/agentRenderTexture.renderTexture
+++ b/UnitySDK/Assets/ML-Agents/Examples/GridWorld/Prefabs/agentRenderTexture.renderTexture
@@ -12,7 +12,7 @@ RenderTexture:
m_ForcedFallbackFormat: 4
m_DownscaleFallback: 0
m_Width: 84
- m_Height: 84
+ m_Height: 64
m_AntiAliasing: 1
m_DepthFormat: 1
m_ColorFormat: 0
diff --git a/UnitySDK/Assets/ML-Agents/Examples/GridWorld/Scenes/GridWorld.unity b/UnitySDK/Assets/ML-Agents/Examples/GridWorld/Scenes/GridWorld.unity
index 1d5e451569..cf1412870a 100644
--- a/UnitySDK/Assets/ML-Agents/Examples/GridWorld/Scenes/GridWorld.unity
+++ b/UnitySDK/Assets/ML-Agents/Examples/GridWorld/Scenes/GridWorld.unity
@@ -38,7 +38,7 @@ RenderSettings:
m_ReflectionIntensity: 1
m_CustomReflection: {fileID: 0}
m_Sun: {fileID: 0}
- m_IndirectSpecularColor: {r: 0.44971162, g: 0.49977726, b: 0.5756362, a: 1}
+ m_IndirectSpecularColor: {r: 0.4497121, g: 0.49977785, b: 0.57563704, a: 1}
--- !u!157 &3
LightmapSettings:
m_ObjectHideFlags: 0
@@ -354,7 +354,7 @@ MonoBehaviour:
vectorActionSize: 05000000
vectorActionDescriptions: []
vectorActionSpaceType: 0
- m_Model: {fileID: 11400000, guid: 07afbd1d35ed345eeb850fcbb59eae0b, type: 3}
+ m_Model: {fileID: 11400000, guid: a812f1ce7763a4a0c912717f3594fe20, type: 3}
m_InferenceDevice: 0
m_UseHeuristic: 0
m_BehaviorName: GridWorld
@@ -372,7 +372,7 @@ MonoBehaviour:
renderTexture: {fileID: 8400000, guid: 114608d5384404f89bff4b6f88432958, type: 2}
sensorName: RenderTextureSensor
width: 84
- height: 84
+ height: 64
grayscale: 0
--- !u!1 &260425459
GameObject:
@@ -1584,7 +1584,7 @@ RectTransform:
m_AnchorMin: {x: 0.5, y: 0.5}
m_AnchorMax: {x: 0.5, y: 0.5}
m_AnchoredPosition: {x: -369.5, y: -197}
- m_SizeDelta: {x: 200, y: 200}
+ m_SizeDelta: {x: 200, y: 152}
m_Pivot: {x: 0.5, y: 0.5}
--- !u!114 &1305247361
MonoBehaviour:
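The GridWorld visual observation height changes from 84 to 64 in three places that must stay in sync: the CameraSensor on the Area prefab, the agentRenderTexture asset, and the RenderTextureSensor in the scene. The RectTransform resize (200x200 to 200x152) presumably belongs to the UI element that displays that texture, and the new height is just the old width scaled by the new aspect ratio; illustrative arithmetic below.

// Illustrative arithmetic only.
const float widthPx = 200f;            // m_SizeDelta.x, unchanged
const float newAspect = 64f / 84f;     // new height / width of the visual observation
float heightPx = widthPx * newAspect;  // ~152.4, rounded to the 152 in m_SizeDelta above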
diff --git a/UnitySDK/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs b/UnitySDK/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs
index c142413d2d..7bb1ab01f1 100644
--- a/UnitySDK/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs
+++ b/UnitySDK/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs
@@ -6,12 +6,12 @@
public class GridAgent : Agent
{
- private Academy m_Academy;
+ Academy m_Academy;
[FormerlySerializedAs("m_Area")]
[Header("Specific to GridWorld")]
public GridArea area;
public float timeBetweenDecisionsAtInference;
- private float m_TimeSinceDecision;
+ float m_TimeSinceDecision;
[Tooltip("Because we want an observation right before making a decision, we can force " +
"a camera to render before making a decision. Place the agentCam here if using " +
@@ -22,11 +22,11 @@ public class GridAgent : Agent
"masking turned on may not behave optimally when action masking is turned off.")]
public bool maskActions = true;
- private const int k_NoAction = 0; // do nothing!
- private const int k_Up = 1;
- private const int k_Down = 2;
- private const int k_Left = 3;
- private const int k_Right = 4;
+ const int k_NoAction = 0; // do nothing!
+ const int k_Up = 1;
+ const int k_Down = 2;
+ const int k_Left = 3;
+ const int k_Right = 4;
public override void InitializeAgent()
{
@@ -48,7 +48,7 @@ public override void CollectObservations()
/// <summary>
/// Applies the mask for the agents action to disallow unnecessary actions.
/// </summary>
- private void SetMask()
+ void SetMask()
{
// Prevents the agent from picking an action that would make it collide with a wall
var positionX = (int)transform.position.x;
@@ -77,7 +77,7 @@ private void SetMask()
}
// to be implemented by the developer
- public override void AgentAction(float[] vectorAction, string textAction)
+ public override void AgentAction(float[] vectorAction)
{
AddReward(-0.01f);
var action = Mathf.FloorToInt(vectorAction[0]);
@@ -155,7 +155,7 @@ public void FixedUpdate()
WaitTimeInference();
}
- private void WaitTimeInference()
+ void WaitTimeInference()
{
if (renderCamera != null)
{
diff --git a/UnitySDK/Assets/ML-Agents/Examples/GridWorld/Scripts/GridArea.cs b/UnitySDK/Assets/ML-Agents/Examples/GridWorld/Scripts/GridArea.cs
index 4c19f1ba3d..a88e8cd0e8 100644
--- a/UnitySDK/Assets/ML-Agents/Examples/GridWorld/Scripts/GridArea.cs
+++ b/UnitySDK/Assets/ML-Agents/Examples/GridWorld/Scripts/GridArea.cs
@@ -13,7 +13,7 @@ public class GridArea : MonoBehaviour
public GameObject trueAgent;
- private ResetParameters m_ResetParameters;
+ ResetParameters m_ResetParameters;
Camera m_AgentCam;
@@ -27,7 +27,7 @@ public class GridArea : MonoBehaviour
GameObject m_Se;
GameObject m_Sw;
- private Vector3 m_InitialPosition;
+ Vector3 m_InitialPosition;
public void Awake()
{
diff --git a/UnitySDK/Assets/ML-Agents/Examples/GridWorld/TFModels/GridWorld.nn b/UnitySDK/Assets/ML-Agents/Examples/GridWorld/TFModels/GridWorld.nn
index 7a438f090e..7c8976e17a 100644
Binary files a/UnitySDK/Assets/ML-Agents/Examples/GridWorld/TFModels/GridWorld.nn and b/UnitySDK/Assets/ML-Agents/Examples/GridWorld/TFModels/GridWorld.nn differ
diff --git a/UnitySDK/Assets/ML-Agents/Examples/Hallway/Prefabs/SymbolFinderArea.prefab b/UnitySDK/Assets/ML-Agents/Examples/Hallway/Prefabs/SymbolFinderArea.prefab
index 8a3a024dba..9e6fe38845 100644
--- a/UnitySDK/Assets/ML-Agents/Examples/Hallway/Prefabs/SymbolFinderArea.prefab
+++ b/UnitySDK/Assets/ML-Agents/Examples/Hallway/Prefabs/SymbolFinderArea.prefab
@@ -248,7 +248,7 @@ GameObject:
- component: {fileID: 54112968250075710}
- component: {fileID: 114907778469006590}
- component: {fileID: 114286701363010626}
- - component: {fileID: 114569343444552314}
+ - component: {fileID: 114388598785529460}
m_Layer: 0
m_Name: Agent
m_TagString: agent
@@ -1585,7 +1585,7 @@ MonoBehaviour:
symbolO: {fileID: 1453690758295050}
symbolX: {fileID: 1915733999209864}
useVectorObs: 1
---- !u!114 &114569343444552314
+--- !u!114 &114388598785529460
MonoBehaviour:
m_ObjectHideFlags: 1
m_PrefabParentObject: {fileID: 0}
@@ -1593,9 +1593,26 @@ MonoBehaviour:
m_GameObject: {fileID: 1471560210313468}
m_Enabled: 1
m_EditorHideFlags: 0
- m_Script: {fileID: 11500000, guid: bb172294dbbcc408286b156a2c4b553c, type: 3}
+ m_Script: {fileID: 11500000, guid: 6bb6b867a41448888c1cd4f99643ad71, type: 3}
m_Name:
m_EditorClassIdentifier:
+ sensorName: RayPerceptionSensor
+ detectableTags:
+ - symbol_O_Goal
+ - symbol_X_Goal
+ - symbol_O
+ - symbol_X
+ - wall
+ raysPerDirection: 2
+ maxRayDegrees: 70
+ sphereCastRadius: 0.5
+ rayLength: 12
+ observationStacks: 3
+ rayHitColor: {r: 1, g: 0, b: 0, a: 1}
+ rayMissColor: {r: 1, g: 1, b: 1, a: 1}
+ useWorldPositions: 1
+ startVerticalOffset: 0
+ endVerticalOffset: 0
--- !u!114 &114907778469006590
MonoBehaviour:
m_ObjectHideFlags: 1
@@ -1608,7 +1625,7 @@ MonoBehaviour:
m_Name:
m_EditorClassIdentifier:
m_BrainParameters:
- vectorObservationSize: 36
+ vectorObservationSize: 1
numStackedVectorObservations: 3
vectorActionSize: 05000000
vectorActionDescriptions: []
diff --git a/UnitySDK/Assets/ML-Agents/Examples/Hallway/Scripts/HallwayAgent.cs b/UnitySDK/Assets/ML-Agents/Examples/Hallway/Scripts/HallwayAgent.cs
index 5912ce2521..90f1a55372 100644
--- a/UnitySDK/Assets/ML-Agents/Examples/Hallway/Scripts/HallwayAgent.cs
+++ b/UnitySDK/Assets/ML-Agents/Examples/Hallway/Scripts/HallwayAgent.cs
@@ -11,7 +11,6 @@ public class HallwayAgent : Agent
public GameObject symbolO;
public GameObject symbolX;
public bool useVectorObs;
- RayPerception m_RayPer;
Rigidbody m_AgentRb;
Material m_GroundMaterial;
Renderer m_GroundRenderer;
@@ -22,7 +21,6 @@ public override void InitializeAgent()
{
base.InitializeAgent();
m_Academy = FindObjectOfType<HallwayAcademy>();
- m_RayPer = GetComponent<RayPerception>();
m_AgentRb = GetComponent<Rigidbody>();
m_GroundRenderer = ground.GetComponent<Renderer>();
m_GroundMaterial = m_GroundRenderer.material;
@@ -32,11 +30,7 @@ public override void CollectObservations()
{
if (useVectorObs)
{
- var rayDistance = 12f;
- float[] rayAngles = { 20f, 60f, 90f, 120f, 160f };
- string[] detectableObjects = { "symbol_O_Goal", "symbol_X_Goal", "symbol_O", "symbol_X", "wall" };
AddVectorObs(GetStepCount() / (float)agentParameters.maxStep);
- AddVectorObs(m_RayPer.Perceive(rayDistance, rayAngles, detectableObjects, 0f, 0f));
}
}
@@ -72,7 +66,7 @@ public void MoveAgent(float[] act)
m_AgentRb.AddForce(dirToGo * m_Academy.agentRunSpeed, ForceMode.VelocityChange);
}
- public override void AgentAction(float[] vectorAction, string textAction)
+ public override void AgentAction(float[] vectorAction)
{
AddReward(-1f / agentParameters.maxStep);
MoveAgent(vectorAction);
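HallwayAgent no longer assembles its ray observation by hand: the prefab above gains a sensor with raysPerDirection: 2 and maxRayDegrees: 70, and the brain's vectorObservationSize drops from 36 (5 rays x (5 tags + 2) plus the step-count value) to 1, leaving only the step count in CollectObservations. The removed hand-picked angles and the new settings line up roughly as sketched below; the even-spacing formula is an assumption about how the sensor generates its rays, not something this diff shows.

// Assumption: the sensor spreads 2*raysPerDirection + 1 rays evenly over [90 - maxRayDegrees, 90 + maxRayDegrees].
int raysPerDirection = 2;
float maxRayDegrees = 70f;
int rayCount = 2 * raysPerDirection + 1;   // 5, same count as the removed { 20, 60, 90, 120, 160 }
var angles = new float[rayCount];
for (int i = 0; i < rayCount; i++)
{
    angles[i] = 90f + maxRayDegrees * (i - raysPerDirection) / (float)raysPerDirection;
}
// -> { 20, 55, 90, 125, 160 }: the end rays match exactly; the inner rays shift slightly from the old 60/120.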
diff --git a/UnitySDK/Assets/ML-Agents/Examples/Hallway/TFModels/Hallway.nn b/UnitySDK/Assets/ML-Agents/Examples/Hallway/TFModels/Hallway.nn
index c80aaa91b2..3147d7a24d 100644
Binary files a/UnitySDK/Assets/ML-Agents/Examples/Hallway/TFModels/Hallway.nn and b/UnitySDK/Assets/ML-Agents/Examples/Hallway/TFModels/Hallway.nn differ
diff --git a/UnitySDK/Assets/ML-Agents/Examples/PushBlock/Prefabs/PushBlockArea.prefab b/UnitySDK/Assets/ML-Agents/Examples/PushBlock/Prefabs/PushBlockArea.prefab
index fde9a4ad4c..9f009a424f 100644
--- a/UnitySDK/Assets/ML-Agents/Examples/PushBlock/Prefabs/PushBlockArea.prefab
+++ b/UnitySDK/Assets/ML-Agents/Examples/PushBlock/Prefabs/PushBlockArea.prefab
@@ -106,8 +106,9 @@ GameObject:
- component: {fileID: 54817351390947638}
- component: {fileID: 114306175693660464}
- component: {fileID: 114505490781873732}
- - component: {fileID: 114421647563711602}
- component: {fileID: 65880096262939968}
+ - component: {fileID: 114807072692257076}
+ - component: {fileID: 114451319691753174}
m_Layer: 0
m_Name: Agent
m_TagString: agent
@@ -946,8 +947,8 @@ MonoBehaviour:
m_Name:
m_EditorClassIdentifier:
m_BrainParameters:
- vectorObservationSize: 70
- numStackedVectorObservations: 3
+ vectorObservationSize: 0
+ numStackedVectorObservations: 2
vectorActionSize: 07000000
vectorActionDescriptions: []
vectorActionSpaceType: 0
@@ -955,7 +956,7 @@ MonoBehaviour:
m_InferenceDevice: 0
m_UseHeuristic: 0
m_BehaviorName: PushBlock
---- !u!114 &114421647563711602
+--- !u!114 &114451319691753174
MonoBehaviour:
m_ObjectHideFlags: 1
m_PrefabParentObject: {fileID: 0}
@@ -963,9 +964,24 @@ MonoBehaviour:
m_GameObject: {fileID: 1489716781518988}
m_Enabled: 1
m_EditorHideFlags: 0
- m_Script: {fileID: 11500000, guid: bb172294dbbcc408286b156a2c4b553c, type: 3}
+ m_Script: {fileID: 11500000, guid: 6bb6b867a41448888c1cd4f99643ad71, type: 3}
m_Name:
m_EditorClassIdentifier:
+ sensorName: OffsetRayPerceptionSensor
+ detectableTags:
+ - block
+ - goal
+ - wall
+ raysPerDirection: 3
+ maxRayDegrees: 90
+ sphereCastRadius: 0.5
+ rayLength: 12
+ observationStacks: 3
+ rayHitColor: {r: 1, g: 0, b: 0, a: 1}
+ rayMissColor: {r: 1, g: 1, b: 1, a: 1}
+ useWorldPositions: 1
+ startVerticalOffset: 1.5
+ endVerticalOffset: 1.5
--- !u!114 &114505490781873732
MonoBehaviour:
m_ObjectHideFlags: 1
@@ -991,3 +1007,29 @@ MonoBehaviour:
block: {fileID: 1831337770648600}
goalDetect: {fileID: 0}
useVectorObs: 1
+--- !u!114 &114807072692257076
+MonoBehaviour:
+ m_ObjectHideFlags: 1
+ m_PrefabParentObject: {fileID: 0}
+ m_PrefabInternal: {fileID: 100100000}
+ m_GameObject: {fileID: 1489716781518988}
+ m_Enabled: 1
+ m_EditorHideFlags: 0
+ m_Script: {fileID: 11500000, guid: 6bb6b867a41448888c1cd4f99643ad71, type: 3}
+ m_Name:
+ m_EditorClassIdentifier:
+ sensorName: RayPerceptionSensor
+ detectableTags:
+ - block
+ - goal
+ - wall
+ raysPerDirection: 3
+ maxRayDegrees: 90
+ sphereCastRadius: 0.5
+ rayLength: 12
+ observationStacks: 3
+ rayHitColor: {r: 1, g: 0, b: 0, a: 1}
+ rayMissColor: {r: 1, g: 1, b: 1, a: 1}
+ useWorldPositions: 1
+ startVerticalOffset: 0
+ endVerticalOffset: 0
diff --git a/UnitySDK/Assets/ML-Agents/Examples/PushBlock/Scripts/PushAgentBasic.cs b/UnitySDK/Assets/ML-Agents/Examples/PushBlock/Scripts/PushAgentBasic.cs
index 87d19f8ae1..89b1abd714 100644
--- a/UnitySDK/Assets/ML-Agents/Examples/PushBlock/Scripts/PushAgentBasic.cs
+++ b/UnitySDK/Assets/ML-Agents/Examples/PushBlock/Scripts/PushAgentBasic.cs
@@ -42,10 +42,6 @@ public class PushAgentBasic : Agent
Rigidbody m_BlockRb; //cached on initialization
Rigidbody m_AgentRb; //cached on initialization
Material m_GroundMaterial; //cached on Awake()
- RayPerception m_RayPer;
-
- float[] m_RayAngles = { 0f, 45f, 90f, 135f, 180f, 110f, 70f };
- string[] m_DetectableObjects = { "block", "goal", "wall" };
/// <summary>
/// We will be changing the ground material based on success/failure
@@ -62,7 +58,6 @@ public override void InitializeAgent()
base.InitializeAgent();
goalDetect = block.GetComponent<GoalDetect>();
goalDetect.agent = this;
- m_RayPer = GetComponent<RayPerception>();
// Cache the agent rigidbody
m_AgentRb = GetComponent<Rigidbody>();
@@ -78,17 +73,6 @@ public override void InitializeAgent()
SetResetParameters();
}
- public override void CollectObservations()
- {
- if (useVectorObs)
- {
- var rayDistance = 12f;
-
- AddVectorObs(m_RayPer.Perceive(rayDistance, m_RayAngles, m_DetectableObjects, 0f, 0f));
- AddVectorObs(m_RayPer.Perceive(rayDistance, m_RayAngles, m_DetectableObjects, 1.5f, 0f));
- }
- }
-
/// <summary>
/// Use the ground's bounds to pick a random spawn position.
/// </summary>
@@ -177,7 +161,7 @@ public void MoveAgent(float[] act)
/// <summary>
/// Called every step of the engine. Here the agent takes an action.
/// </summary>
- public override void AgentAction(float[] vectorAction, string textAction)
+ public override void AgentAction(float[] vectorAction)
{
// Move the agent using the action.
MoveAgent(vectorAction);
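The removed CollectObservations cast the same 7 rays twice, once at ground level and once 1.5 units up, which is why the prefab now carries two sensor components: RayPerceptionSensor with vertical offsets of 0 and OffsetRayPerceptionSensor with 1.5. The old brain size follows the same slots-per-ray layout as in the FoodCollector note above; illustrative arithmetic only.

// Illustrative arithmetic only; values are read off this diff.
int rays = 7, tags = 3;
int perPass = rays * (tags + 2);   // 35 values per Perceive() pass
int oldVectorObs = 2 * perPass;    // 70: the vectorObservationSize this patch zeroes out in the prefab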
diff --git a/UnitySDK/Assets/ML-Agents/Examples/PushBlock/TFModels/PushBlock.nn b/UnitySDK/Assets/ML-Agents/Examples/PushBlock/TFModels/PushBlock.nn
index 1598e66520..5cef55d048 100644
Binary files a/UnitySDK/Assets/ML-Agents/Examples/PushBlock/TFModels/PushBlock.nn and b/UnitySDK/Assets/ML-Agents/Examples/PushBlock/TFModels/PushBlock.nn differ
diff --git a/UnitySDK/Assets/ML-Agents/Examples/Pyramids/Scripts/PyramidAgent.cs b/UnitySDK/Assets/ML-Agents/Examples/Pyramids/Scripts/PyramidAgent.cs
index 9f0f1a3ebf..66179a2b45 100644
--- a/UnitySDK/Assets/ML-Agents/Examples/Pyramids/Scripts/PyramidAgent.cs
+++ b/UnitySDK/Assets/ML-Agents/Examples/Pyramids/Scripts/PyramidAgent.cs
@@ -7,10 +7,10 @@
public class PyramidAgent : Agent
{
public GameObject area;
- private PyramidArea m_MyArea;
- private Rigidbody m_AgentRb;
- private RayPerception m_RayPer;
- private PyramidSwitch m_SwitchLogic;
+ PyramidArea m_MyArea;
+ Rigidbody m_AgentRb;
+ RayPerception m_RayPer;
+ PyramidSwitch m_SwitchLogic;
public GameObject areaSwitch;
public bool useVectorObs;
@@ -33,7 +33,7 @@ public override void CollectObservations()
float[] rayAngles2 = { 15f, 85f, 155f, 40f, 130f, 65f, 105f };
string[] detectableObjects = { "block", "wall", "goal", "switchOff", "switchOn", "stone" };
- AddVectorObs(m_RayPer.Perceive(rayDistance, rayAngles, detectableObjects, 0f, 0f));
+ AddVectorObs(m_RayPer.Perceive(rayDistance, rayAngles, detectableObjects));
AddVectorObs(m_RayPer.Perceive(rayDistance, rayAngles1, detectableObjects, 0f, 5f));
AddVectorObs(m_RayPer.Perceive(rayDistance, rayAngles2, detectableObjects, 0f, 10f));
AddVectorObs(m_SwitchLogic.GetState());
@@ -66,7 +66,7 @@ public void MoveAgent(float[] act)
m_AgentRb.AddForce(dirToGo * 2f, ForceMode.VelocityChange);
}
- public override void AgentAction(float[] vectorAction, string textAction)
+ public override void AgentAction(float[] vectorAction)
{
AddReward(-1f / agentParameters.maxStep);
MoveAgent(vectorAction);
@@ -113,7 +113,7 @@ public override void AgentReset()
m_MyArea.CreateStonePyramid(1, items[8]);
}
- private void OnCollisionEnter(Collision collision)
+ void OnCollisionEnter(Collision collision)
{
if (collision.gameObject.CompareTag("goal"))
{
diff --git a/UnitySDK/Assets/ML-Agents/Examples/Pyramids/Scripts/PyramidArea.cs b/UnitySDK/Assets/ML-Agents/Examples/Pyramids/Scripts/PyramidArea.cs
index 688e4e6378..a2ac297727 100644
--- a/UnitySDK/Assets/ML-Agents/Examples/Pyramids/Scripts/PyramidArea.cs
+++ b/UnitySDK/Assets/ML-Agents/Examples/Pyramids/Scripts/PyramidArea.cs
@@ -19,7 +19,7 @@ public void CreateStonePyramid(int numObjects, int spawnAreaIndex)
CreateObject(numObjects, stonePyramid, spawnAreaIndex);
}
- private void CreateObject(int numObjects, GameObject desiredObject, int spawnAreaIndex)
+ void CreateObject(int numObjects, GameObject desiredObject, int spawnAreaIndex)
{
for (var i = 0; i < numObjects; i++)
{
diff --git a/UnitySDK/Assets/ML-Agents/Examples/Pyramids/Scripts/PyramidSwitch.cs b/UnitySDK/Assets/ML-Agents/Examples/Pyramids/Scripts/PyramidSwitch.cs
index 97d32c9c93..6f2b627398 100644
--- a/UnitySDK/Assets/ML-Agents/Examples/Pyramids/Scripts/PyramidSwitch.cs
+++ b/UnitySDK/Assets/ML-Agents/Examples/Pyramids/Scripts/PyramidSwitch.cs
@@ -5,17 +5,17 @@ public class PyramidSwitch : MonoBehaviour
public Material onMaterial;
public Material offMaterial;
public GameObject myButton;
- private bool m_State;
- private GameObject m_Area;
- private PyramidArea m_AreaComponent;
- private int m_PyramidIndex;
+ bool m_State;
+ GameObject m_Area;
+ PyramidArea m_AreaComponent;
+ int m_PyramidIndex;
public bool GetState()
{
return m_State;
}
- private void Start()
+ void Start()
{
m_Area = gameObject.transform.parent.gameObject;
m_AreaComponent = m_Area.GetComponent<PyramidArea>();
@@ -31,7 +31,7 @@ public void ResetSwitch(int spawnAreaIndex, int pyramidSpawnIndex)
myButton.GetComponent<Renderer>().material = offMaterial;
}
- private void OnCollisionEnter(Collision other)
+ void OnCollisionEnter(Collision other)
{
if (other.gameObject.CompareTag("agent") && m_State == false)
{
diff --git a/UnitySDK/Assets/ML-Agents/Examples/Pyramids/TFModels/Pyramids.nn b/UnitySDK/Assets/ML-Agents/Examples/Pyramids/TFModels/Pyramids.nn
index e2227dc86f..2fb0bc9b52 100644
Binary files a/UnitySDK/Assets/ML-Agents/Examples/Pyramids/TFModels/Pyramids.nn and b/UnitySDK/Assets/ML-Agents/Examples/Pyramids/TFModels/Pyramids.nn differ
diff --git a/UnitySDK/Assets/ML-Agents/Examples/Reacher/Scripts/ReacherAgent.cs b/UnitySDK/Assets/ML-Agents/Examples/Reacher/Scripts/ReacherAgent.cs
index cb45cd93a8..2c30d0c141 100644
--- a/UnitySDK/Assets/ML-Agents/Examples/Reacher/Scripts/ReacherAgent.cs
+++ b/UnitySDK/Assets/ML-Agents/Examples/Reacher/Scripts/ReacherAgent.cs
@@ -7,18 +7,18 @@ public class ReacherAgent : Agent
public GameObject pendulumB;
public GameObject hand;
public GameObject goal;
- private ReacherAcademy m_MyAcademy;
+ ReacherAcademy m_MyAcademy;
float m_GoalDegree;
- private Rigidbody m_RbA;
- private Rigidbody m_RbB;
+ Rigidbody m_RbA;
+ Rigidbody m_RbB;
// speed of the goal zone around the arm (in radians)
- private float m_GoalSpeed;
+ float m_GoalSpeed;
// radius of the goal zone
- private float m_GoalSize;
+ float m_GoalSize;
// Magnitude of sinusoidal (cosine) deviation of the goal along the vertical dimension
- private float m_Deviation;
+ float m_Deviation;
// Frequency of the cosine deviation of the goal along the vertical dimension
- private float m_DeviationFreq;
+ float m_DeviationFreq;
/// <summary>
/// Collect the rigidbodies of the reacher in order to reuse them for
@@ -58,7 +58,7 @@ public override void CollectObservations()
/// <summary>
/// The agent's four actions correspond to torques on each of the two joints.
/// </summary>
- public override void AgentAction(float[] vectorAction, string textAction)
+ public override void AgentAction(float[] vectorAction)
{
m_GoalDegree += m_GoalSpeed;
UpdateGoalPosition();
diff --git a/UnitySDK/Assets/ML-Agents/Examples/Reacher/Scripts/ReacherGoal.cs b/UnitySDK/Assets/ML-Agents/Examples/Reacher/Scripts/ReacherGoal.cs
index e2076b511a..a31cb6908b 100644
--- a/UnitySDK/Assets/ML-Agents/Examples/Reacher/Scripts/ReacherGoal.cs
+++ b/UnitySDK/Assets/ML-Agents/Examples/Reacher/Scripts/ReacherGoal.cs
@@ -6,7 +6,7 @@ public class ReacherGoal : MonoBehaviour
public GameObject hand;
public GameObject goalOn;
- private void OnTriggerEnter(Collider other)
+ void OnTriggerEnter(Collider other)
{
if (other.gameObject == hand)
{
@@ -14,7 +14,7 @@ private void OnTriggerEnter(Collider other)
}
}
- private void OnTriggerExit(Collider other)
+ void OnTriggerExit(Collider other)
{
if (other.gameObject == hand)
{
@@ -22,7 +22,7 @@ private void OnTriggerExit(Collider other)
}
}
- private void OnTriggerStay(Collider other)
+ void OnTriggerStay(Collider other)
{
if (other.gameObject == hand)
{
diff --git a/UnitySDK/Assets/ML-Agents/Examples/Reacher/TFModels/Reacher.nn b/UnitySDK/Assets/ML-Agents/Examples/Reacher/TFModels/Reacher.nn
index bcda57d751..d8fb548ffd 100644
Binary files a/UnitySDK/Assets/ML-Agents/Examples/Reacher/TFModels/Reacher.nn and b/UnitySDK/Assets/ML-Agents/Examples/Reacher/TFModels/Reacher.nn differ
diff --git a/UnitySDK/Assets/ML-Agents/Examples/SharedAssets/Scripts/FlyCamera.cs b/UnitySDK/Assets/ML-Agents/Examples/SharedAssets/Scripts/FlyCamera.cs
index 88f4093544..b646dd6bc5 100644
--- a/UnitySDK/Assets/ML-Agents/Examples/SharedAssets/Scripts/FlyCamera.cs
+++ b/UnitySDK/Assets/ML-Agents/Examples/SharedAssets/Scripts/FlyCamera.cs
@@ -17,12 +17,12 @@ public class FlyCamera : MonoBehaviour
public bool rotateOnlyIfMousedown = true;
public bool movementStaysFlat = true;
- private Vector3
+ Vector3
m_LastMouse =
new Vector3(255, 255,
255); // kind of in the middle of the screen, rather than at the top (play)
- private float m_TotalRun = 1.0f;
+ float m_TotalRun = 1.0f;
void Awake()
{
@@ -86,7 +86,7 @@ void Update()
}
}
- private Vector3 GetBaseInput()
+ Vector3 GetBaseInput()
{
// returns the basic values, if it's 0 then it's not active.
var pVelocity = new Vector3();
diff --git a/UnitySDK/Assets/ML-Agents/Examples/SharedAssets/Scripts/GroundContact.cs b/UnitySDK/Assets/ML-Agents/Examples/SharedAssets/Scripts/GroundContact.cs
index 35fff3eed0..9fb04fe6b2 100644
--- a/UnitySDK/Assets/ML-Agents/Examples/SharedAssets/Scripts/GroundContact.cs
+++ b/UnitySDK/Assets/ML-Agents/Examples/SharedAssets/Scripts/GroundContact.cs
@@ -16,7 +16,7 @@ public class GroundContact : MonoBehaviour
public bool penalizeGroundContact; // Whether to penalize on contact.
public float groundContactPenalty; // Penalty amount (ex: -1).
public bool touchingGround;
- private const string k_Ground = "ground"; // Tag of ground object.
+ const string k_Ground = "ground"; // Tag of ground object.
/// <summary>
/// Check for collision with ground, and optionally penalize agent.
diff --git a/UnitySDK/Assets/ML-Agents/Examples/SharedAssets/Scripts/RayPerception.cs b/UnitySDK/Assets/ML-Agents/Examples/SharedAssets/Scripts/RayPerception.cs
index a4b4fbc7bb..45744fa1fa 100644
--- a/UnitySDK/Assets/ML-Agents/Examples/SharedAssets/Scripts/RayPerception.cs
+++ b/UnitySDK/Assets/ML-Agents/Examples/SharedAssets/Scripts/RayPerception.cs
@@ -1,19 +1,15 @@
-using System.Collections.Generic;
+using System;
+using System.Collections.Generic;
using UnityEngine;
+[Obsolete]
public abstract class RayPerception : MonoBehaviour
{
- protected List<float> m_PerceptionBuffer = new List<float>();
+ protected float[] m_PerceptionBuffer;
- abstract public List<float> Perceive(float rayDistance,
+ abstract public IList<float> Perceive(float rayDistance,
float[] rayAngles, string[] detectableObjects,
float startOffset=0.0f, float endOffset=0.0f);
- /// <summary>
- /// Converts degrees to radians.
- /// </summary>
- public static float DegreeToRadian(float degree)
- {
- return degree * Mathf.PI / 180f;
- }
+
}
diff --git a/UnitySDK/Assets/ML-Agents/Examples/SharedAssets/Scripts/RayPerception2D.cs b/UnitySDK/Assets/ML-Agents/Examples/SharedAssets/Scripts/RayPerception2D.cs
index 6c669d5950..2d76754b19 100644
--- a/UnitySDK/Assets/ML-Agents/Examples/SharedAssets/Scripts/RayPerception2D.cs
+++ b/UnitySDK/Assets/ML-Agents/Examples/SharedAssets/Scripts/RayPerception2D.cs
@@ -1,5 +1,7 @@
-using System.Collections.Generic;
+using System;
+using System.Collections.Generic;
using UnityEngine;
+using MLAgents.Sensor;
namespace MLAgents
{
@@ -7,9 +9,9 @@ namespace MLAgents
/// Ray 2D perception component. Attach this to agents to enable "local perception"
/// via the use of ray casts directed outward from the agent.
/// </summary>
+ [Obsolete("The RayPerception MonoBehaviour is deprecated. Use the RayPerceptionSensorComponent instead")]
public class RayPerception2D : RayPerception
{
- Vector2 m_EndPosition;
RaycastHit2D m_Hit;
///
@@ -30,56 +32,25 @@ public class RayPerception2D : RayPerception
/// List of tags which correspond to object types agent can see
/// Unused
/// Unused
- public override List<float> Perceive(float rayDistance,
+ public override IList<float> Perceive(float rayDistance,
float[] rayAngles, string[] detectableObjects,
float startOffset=0.0f, float endOffset=0.0f)
{
- m_PerceptionBuffer.Clear();
- // For each ray sublist stores categorical information on detected object
- // along with object distance.
- foreach (var angle in rayAngles)
+ var perceptionSize = (detectableObjects.Length + 2) * rayAngles.Length;
+ if (m_PerceptionBuffer == null || m_PerceptionBuffer.Length != perceptionSize)
{
- m_EndPosition = transform.TransformDirection(
- PolarToCartesian(rayDistance, angle));
- if (Application.isEditor)
- {
- Debug.DrawRay(transform.position,
- m_EndPosition, Color.black, 0.01f, true);
- }
-
- var subList = new float[detectableObjects.Length + 2];
- m_Hit = Physics2D.CircleCast(transform.position, 0.5f, m_EndPosition, rayDistance);
- if (m_Hit)
- {
- for (var i = 0; i < detectableObjects.Length; i++)
- {
- if (m_Hit.collider.gameObject.CompareTag(detectableObjects[i]))
- {
- subList[i] = 1;
- subList[detectableObjects.Length + 1] = m_Hit.distance / rayDistance;
- break;
- }
- }
- }
- else
- {
- subList[detectableObjects.Length] = 1f;
- }
-
- m_PerceptionBuffer.AddRange(subList);
+ m_PerceptionBuffer = new float[perceptionSize];
}
+ const float castRadius = 0.5f;
+ const bool legacyHitFractionBehavior = true;
+ RayPerceptionSensor.PerceiveStatic(
+ rayDistance, rayAngles, detectableObjects, startOffset, endOffset, castRadius,
+ transform, RayPerceptionSensor.CastType.Cast3D, m_PerceptionBuffer, legacyHitFractionBehavior
+ );
+
return m_PerceptionBuffer;
}
- /// <summary>
- /// Converts polar coordinate to cartesian coordinate.
- /// </summary>
- public static Vector2 PolarToCartesian(float radius, float angle)
- {
- var x = radius * Mathf.Cos(DegreeToRadian(angle));
- var y = radius * Mathf.Sin(DegreeToRadian(angle));
- return new Vector2(x, y);
- }
}
}
diff --git a/UnitySDK/Assets/ML-Agents/Examples/SharedAssets/Scripts/RayPerception3D.cs b/UnitySDK/Assets/ML-Agents/Examples/SharedAssets/Scripts/RayPerception3D.cs
index 7b85add817..8054effe2d 100644
--- a/UnitySDK/Assets/ML-Agents/Examples/SharedAssets/Scripts/RayPerception3D.cs
+++ b/UnitySDK/Assets/ML-Agents/Examples/SharedAssets/Scripts/RayPerception3D.cs
@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using UnityEngine;
+using MLAgents.Sensor;
namespace MLAgents
{
@@ -8,12 +9,9 @@ namespace MLAgents
/// Ray perception component. Attach this to agents to enable "local perception"
/// via the use of ray casts directed outward from the agent.
/// </summary>
+ [Obsolete("The RayPerception MonoBehaviour is deprecated. Use the RayPerceptionSensorComponent instead")]
public class RayPerception3D : RayPerception
{
- Vector3 m_EndPosition;
- RaycastHit m_Hit;
- private float[] m_SubList;
-
/// <summary>
/// Creates perception vector to be used as part of an observation of an agent.
/// Each ray in the rayAngles array adds a sublist of data to the observation.
@@ -32,64 +30,26 @@ public class RayPerception3D : RayPerception
/// <param name="detectableObjects">List of tags which correspond to object types agent can see</param>
/// <param name="startOffset">Starting height offset of ray from center of agent.</param>
/// <param name="endOffset">Ending height offset of ray from center of agent.</param>
- public override List<float> Perceive(float rayDistance,
+ public override IList<float> Perceive(float rayDistance,
float[] rayAngles, string[] detectableObjects,
float startOffset=0.0f, float endOffset=0.0f)
{
- if (m_SubList == null || m_SubList.Length != detectableObjects.Length + 2)
- m_SubList = new float[detectableObjects.Length + 2];
-
- m_PerceptionBuffer.Clear();
- m_PerceptionBuffer.Capacity = m_SubList.Length * rayAngles.Length;
-
- // For each ray sublist stores categorical information on detected object
- // along with object distance.
- foreach (var angle in rayAngles)
+ var perceptionSize = (detectableObjects.Length + 2) * rayAngles.Length;
+ if (m_PerceptionBuffer == null || m_PerceptionBuffer.Length != perceptionSize)
{
- m_EndPosition = transform.TransformDirection(
- PolarToCartesian(rayDistance, angle));
- m_EndPosition.y = endOffset;
- if (Application.isEditor)
- {
- Debug.DrawRay(transform.position + new Vector3(0f, startOffset, 0f),
- m_EndPosition, Color.black, 0.01f, true);
- }
-
- Array.Clear(m_SubList, 0, m_SubList.Length);
-
- if (Physics.SphereCast(transform.position +
- new Vector3(0f, startOffset, 0f), 0.5f,
- m_EndPosition, out m_Hit, rayDistance))
- {
- for (var i = 0; i < detectableObjects.Length; i++)
- {
- if (m_Hit.collider.gameObject.CompareTag(detectableObjects[i]))
- {
- m_SubList[i] = 1;
- m_SubList[detectableObjects.Length + 1] = m_Hit.distance / rayDistance;
- break;
- }
- }
- }
- else
- {
- m_SubList[detectableObjects.Length] = 1f;
- }
-
- Utilities.AddRangeNoAlloc(m_PerceptionBuffer, m_SubList);
+ m_PerceptionBuffer = new float[perceptionSize];
}
+ const float castRadius = 0.5f;
+ const bool legacyHitFractionBehavior = true;
+ RayPerceptionSensor.PerceiveStatic(
+ rayDistance, rayAngles, detectableObjects, startOffset, endOffset, castRadius,
+ transform, RayPerceptionSensor.CastType.Cast3D, m_PerceptionBuffer, legacyHitFractionBehavior
+ );
+
return m_PerceptionBuffer;
}
- /// <summary>
- /// Converts polar coordinate to cartesian coordinate.
- /// </summary>
- public static Vector3 PolarToCartesian(float radius, float angle)
- {
- var x = radius * Mathf.Cos(DegreeToRadian(angle));
- var z = radius * Mathf.Sin(DegreeToRadian(angle));
- return new Vector3(x, 0f, z);
- }
+
}
}
diff --git a/UnitySDK/Assets/ML-Agents/Examples/SharedAssets/Scripts/TargetContact.cs b/UnitySDK/Assets/ML-Agents/Examples/SharedAssets/Scripts/TargetContact.cs
index 045c5e77bf..680190eb9c 100644
--- a/UnitySDK/Assets/ML-Agents/Examples/SharedAssets/Scripts/TargetContact.cs
+++ b/UnitySDK/Assets/ML-Agents/Examples/SharedAssets/Scripts/TargetContact.cs
@@ -11,7 +11,7 @@ namespace MLAgents
public class TargetContact : MonoBehaviour
{
[Header("Detect Targets")] public bool touchingTarget;
- private const string k_Target = "target"; // Tag on target object.
+ const string k_Target = "target"; // Tag on target object.
/// <summary>
/// Check for collision with a target.
diff --git a/UnitySDK/Assets/ML-Agents/Examples/Soccer/Prefabs/SoccerFieldTwos.prefab b/UnitySDK/Assets/ML-Agents/Examples/Soccer/Prefabs/SoccerFieldTwos.prefab
index bacae626c9..5bad4718ca 100644
--- a/UnitySDK/Assets/ML-Agents/Examples/Soccer/Prefabs/SoccerFieldTwos.prefab
+++ b/UnitySDK/Assets/ML-Agents/Examples/Soccer/Prefabs/SoccerFieldTwos.prefab
@@ -89,7 +89,8 @@ GameObject:
- component: {fileID: 135232974003521068}
- component: {fileID: 114734187185382186}
- component: {fileID: 114492261207303438}
- - component: {fileID: 114692966630797794}
+ - component: {fileID: 114320493772006642}
+ - component: {fileID: 114413496910417180}
m_Layer: 13
m_Name: PurpleStriker
m_TagString: purpleAgent
@@ -180,7 +181,8 @@ GameObject:
- component: {fileID: 135154818167532598}
- component: {fileID: 114105115387635628}
- component: {fileID: 114698199869072806}
- - component: {fileID: 114381244552195858}
+ - component: {fileID: 114402225209785518}
+ - component: {fileID: 114691053776668376}
m_Layer: 11
m_Name: BlueGoalie
m_TagString: blueAgent
@@ -200,7 +202,8 @@ GameObject:
- component: {fileID: 135208952479003512}
- component: {fileID: 114387866097048300}
- component: {fileID: 114850431417842684}
- - component: {fileID: 114965771318032104}
+ - component: {fileID: 114516244030127556}
+ - component: {fileID: 114736358897902410}
m_Layer: 13
m_Name: BlueStriker
m_TagString: blueAgent
@@ -766,7 +769,8 @@ GameObject:
- component: {fileID: 135133947297127334}
- component: {fileID: 114529615399004778}
- component: {fileID: 114284769194328828}
- - component: {fileID: 114724674330921748}
+ - component: {fileID: 114742734491650780}
+ - component: {fileID: 114206319503468014}
m_Layer: 11
m_Name: PurpleGoalie
m_TagString: purpleAgent
@@ -3521,7 +3525,7 @@ MonoBehaviour:
m_Name:
m_EditorClassIdentifier:
m_BrainParameters:
- vectorObservationSize: 112
+ vectorObservationSize: 0
numStackedVectorObservations: 1
vectorActionSize: 05000000
vectorActionDescriptions: []
@@ -3530,6 +3534,35 @@ MonoBehaviour:
m_InferenceDevice: 0
m_UseHeuristic: 0
m_BehaviorName: Goalie
+--- !u!114 &114206319503468014
+MonoBehaviour:
+ m_ObjectHideFlags: 1
+ m_PrefabParentObject: {fileID: 0}
+ m_PrefabInternal: {fileID: 100100000}
+ m_GameObject: {fileID: 1890219402901316}
+ m_Enabled: 1
+ m_EditorHideFlags: 0
+ m_Script: {fileID: 11500000, guid: 6bb6b867a41448888c1cd4f99643ad71, type: 3}
+ m_Name:
+ m_EditorClassIdentifier:
+ sensorName: PurpleOffsetRayPerceptionSensor
+ detectableTags:
+ - ball
+ - purpleGoal
+ - blueGoal
+ - wall
+ - purpleAgent
+ - blueAgent
+ raysPerDirection: 3
+ maxRayDegrees: 90
+ sphereCastRadius: 0.5
+ rayLength: 20
+ observationStacks: 1
+ rayHitColor: {r: 1, g: 0, b: 0, a: 1}
+ rayMissColor: {r: 1, g: 1, b: 1, a: 1}
+ useWorldPositions: 1
+ startVerticalOffset: 1
+ endVerticalOffset: 1
--- !u!114 &114273807544954564
MonoBehaviour:
m_ObjectHideFlags: 1
@@ -3566,17 +3599,35 @@ MonoBehaviour:
agentRole: 1
area: {fileID: 114559182131992928}
agentRb: {fileID: 0}
---- !u!114 &114381244552195858
+--- !u!114 &114320493772006642
MonoBehaviour:
m_ObjectHideFlags: 1
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 100100000}
- m_GameObject: {fileID: 1124213441168130}
+ m_GameObject: {fileID: 1095606497496374}
m_Enabled: 1
m_EditorHideFlags: 0
- m_Script: {fileID: 11500000, guid: bb172294dbbcc408286b156a2c4b553c, type: 3}
+ m_Script: {fileID: 11500000, guid: 6bb6b867a41448888c1cd4f99643ad71, type: 3}
m_Name:
m_EditorClassIdentifier:
+ sensorName: PurpleRayPerceptionSensor
+ detectableTags:
+ - ball
+ - purpleGoal
+ - blueGoal
+ - wall
+ - purpleAgent
+ - blueAgent
+ raysPerDirection: 3
+ maxRayDegrees: 90
+ sphereCastRadius: 0.5
+ rayLength: 20
+ observationStacks: 1
+ rayHitColor: {r: 1, g: 0, b: 0, a: 1}
+ rayMissColor: {r: 1, g: 1, b: 1, a: 1}
+ useWorldPositions: 1
+ startVerticalOffset: 0
+ endVerticalOffset: 0
--- !u!114 &114387866097048300
MonoBehaviour:
m_ObjectHideFlags: 1
@@ -3589,7 +3640,7 @@ MonoBehaviour:
m_Name:
m_EditorClassIdentifier:
m_BrainParameters:
- vectorObservationSize: 112
+ vectorObservationSize: 0
numStackedVectorObservations: 1
vectorActionSize: 07000000
vectorActionDescriptions: []
@@ -3598,6 +3649,64 @@ MonoBehaviour:
m_InferenceDevice: 0
m_UseHeuristic: 0
m_BehaviorName: Striker
+--- !u!114 &114402225209785518
+MonoBehaviour:
+ m_ObjectHideFlags: 1
+ m_PrefabParentObject: {fileID: 0}
+ m_PrefabInternal: {fileID: 100100000}
+ m_GameObject: {fileID: 1124213441168130}
+ m_Enabled: 1
+ m_EditorHideFlags: 0
+ m_Script: {fileID: 11500000, guid: 6bb6b867a41448888c1cd4f99643ad71, type: 3}
+ m_Name:
+ m_EditorClassIdentifier:
+ sensorName: BlueRayPerceptionSensor
+ detectableTags:
+ - ball
+ - blueGoal
+ - purpleGoal
+ - wall
+ - blueAgent
+ - purpleAgent
+ raysPerDirection: 3
+ maxRayDegrees: 90
+ sphereCastRadius: 0.5
+ rayLength: 20
+ observationStacks: 1
+ rayHitColor: {r: 1, g: 0, b: 0, a: 1}
+ rayMissColor: {r: 1, g: 1, b: 1, a: 1}
+ useWorldPositions: 1
+ startVerticalOffset: 0
+ endVerticalOffset: 0
+--- !u!114 &114413496910417180
+MonoBehaviour:
+ m_ObjectHideFlags: 1
+ m_PrefabParentObject: {fileID: 0}
+ m_PrefabInternal: {fileID: 100100000}
+ m_GameObject: {fileID: 1095606497496374}
+ m_Enabled: 1
+ m_EditorHideFlags: 0
+ m_Script: {fileID: 11500000, guid: 6bb6b867a41448888c1cd4f99643ad71, type: 3}
+ m_Name:
+ m_EditorClassIdentifier:
+ sensorName: PurpleOffsetRayPerceptionSensor
+ detectableTags:
+ - ball
+ - purpleGoal
+ - blueGoal
+ - wall
+ - purpleAgent
+ - blueAgent
+ raysPerDirection: 3
+ maxRayDegrees: 90
+ sphereCastRadius: 0.5
+ rayLength: 20
+ observationStacks: 1
+ rayHitColor: {r: 1, g: 0, b: 0, a: 1}
+ rayMissColor: {r: 1, g: 1, b: 1, a: 1}
+ useWorldPositions: 1
+ startVerticalOffset: 1
+ endVerticalOffset: 1
--- !u!114 &114492261207303438
MonoBehaviour:
m_ObjectHideFlags: 1
@@ -3618,6 +3727,35 @@ MonoBehaviour:
agentRole: 0
area: {fileID: 114559182131992928}
agentRb: {fileID: 0}
+--- !u!114 &114516244030127556
+MonoBehaviour:
+ m_ObjectHideFlags: 1
+ m_PrefabParentObject: {fileID: 0}
+ m_PrefabInternal: {fileID: 100100000}
+ m_GameObject: {fileID: 1131626411948014}
+ m_Enabled: 1
+ m_EditorHideFlags: 0
+ m_Script: {fileID: 11500000, guid: 6bb6b867a41448888c1cd4f99643ad71, type: 3}
+ m_Name:
+ m_EditorClassIdentifier:
+ sensorName: BlueRayPerceptionSensor
+ detectableTags:
+ - ball
+ - blueGoal
+ - purpleGoal
+ - wall
+ - blueAgent
+ - purpleAgent
+ raysPerDirection: 3
+ maxRayDegrees: 90
+ sphereCastRadius: 0.5
+ rayLength: 20
+ observationStacks: 1
+ rayHitColor: {r: 1, g: 0, b: 0, a: 1}
+ rayMissColor: {r: 1, g: 1, b: 1, a: 1}
+ useWorldPositions: 1
+ startVerticalOffset: 0
+ endVerticalOffset: 0
--- !u!114 &114529615399004778
MonoBehaviour:
m_ObjectHideFlags: 1
@@ -3630,7 +3768,7 @@ MonoBehaviour:
m_Name:
m_EditorClassIdentifier:
m_BrainParameters:
- vectorObservationSize: 112
+ vectorObservationSize: 0
numStackedVectorObservations: 1
vectorActionSize: 05000000
vectorActionDescriptions: []
@@ -3658,17 +3796,35 @@ MonoBehaviour:
ballStartingPos: {x: 0, y: 0, z: 0}
goalTextUI: {fileID: 0}
canResetBall: 0
---- !u!114 &114692966630797794
+--- !u!114 &114691053776668376
MonoBehaviour:
m_ObjectHideFlags: 1
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 100100000}
- m_GameObject: {fileID: 1095606497496374}
+ m_GameObject: {fileID: 1124213441168130}
m_Enabled: 1
m_EditorHideFlags: 0
- m_Script: {fileID: 11500000, guid: bb172294dbbcc408286b156a2c4b553c, type: 3}
+ m_Script: {fileID: 11500000, guid: 6bb6b867a41448888c1cd4f99643ad71, type: 3}
m_Name:
m_EditorClassIdentifier:
+ sensorName: BlueOffsetRayPerceptionSensor
+ detectableTags:
+ - ball
+ - blueGoal
+ - purpleGoal
+ - wall
+ - blueAgent
+ - purpleAgent
+ raysPerDirection: 3
+ maxRayDegrees: 90
+ sphereCastRadius: 0.5
+ rayLength: 20
+ observationStacks: 1
+ rayHitColor: {r: 1, g: 0, b: 0, a: 1}
+ rayMissColor: {r: 1, g: 1, b: 1, a: 1}
+ useWorldPositions: 1
+ startVerticalOffset: 1
+ endVerticalOffset: 1
--- !u!114 &114698199869072806
MonoBehaviour:
m_ObjectHideFlags: 1
@@ -3689,17 +3845,6 @@ MonoBehaviour:
agentRole: 1
area: {fileID: 114559182131992928}
agentRb: {fileID: 0}
---- !u!114 &114724674330921748
-MonoBehaviour:
- m_ObjectHideFlags: 1
- m_PrefabParentObject: {fileID: 0}
- m_PrefabInternal: {fileID: 100100000}
- m_GameObject: {fileID: 1890219402901316}
- m_Enabled: 1
- m_EditorHideFlags: 0
- m_Script: {fileID: 11500000, guid: bb172294dbbcc408286b156a2c4b553c, type: 3}
- m_Name:
- m_EditorClassIdentifier:
--- !u!114 &114734187185382186
MonoBehaviour:
m_ObjectHideFlags: 1
@@ -3712,7 +3857,7 @@ MonoBehaviour:
m_Name:
m_EditorClassIdentifier:
m_BrainParameters:
- vectorObservationSize: 112
+ vectorObservationSize: 0
numStackedVectorObservations: 1
vectorActionSize: 07000000
vectorActionDescriptions: []
@@ -3721,6 +3866,64 @@ MonoBehaviour:
m_InferenceDevice: 0
m_UseHeuristic: 0
m_BehaviorName: Striker
+--- !u!114 &114736358897902410
+MonoBehaviour:
+ m_ObjectHideFlags: 1
+ m_PrefabParentObject: {fileID: 0}
+ m_PrefabInternal: {fileID: 100100000}
+ m_GameObject: {fileID: 1131626411948014}
+ m_Enabled: 1
+ m_EditorHideFlags: 0
+ m_Script: {fileID: 11500000, guid: 6bb6b867a41448888c1cd4f99643ad71, type: 3}
+ m_Name:
+ m_EditorClassIdentifier:
+ sensorName: BlueOffsetRayPerceptionSensor
+ detectableTags:
+ - ball
+ - blueGoal
+ - purpleGoal
+ - wall
+ - blueAgent
+ - purpleAgent
+ raysPerDirection: 3
+ maxRayDegrees: 90
+ sphereCastRadius: 0.5
+ rayLength: 20
+ observationStacks: 1
+ rayHitColor: {r: 1, g: 0, b: 0, a: 1}
+ rayMissColor: {r: 1, g: 1, b: 1, a: 1}
+ useWorldPositions: 1
+ startVerticalOffset: 1
+ endVerticalOffset: 1
+--- !u!114 &114742734491650780
+MonoBehaviour:
+ m_ObjectHideFlags: 1
+ m_PrefabParentObject: {fileID: 0}
+ m_PrefabInternal: {fileID: 100100000}
+ m_GameObject: {fileID: 1890219402901316}
+ m_Enabled: 1
+ m_EditorHideFlags: 0
+ m_Script: {fileID: 11500000, guid: 6bb6b867a41448888c1cd4f99643ad71, type: 3}
+ m_Name:
+ m_EditorClassIdentifier:
+ sensorName: PurpleRayPerceptionSensor
+ detectableTags:
+ - ball
+ - purpleGoal
+ - blueGoal
+ - wall
+ - purpleAgent
+ - blueAgent
+ raysPerDirection: 3
+ maxRayDegrees: 90
+ sphereCastRadius: 0.5
+ rayLength: 20
+ observationStacks: 1
+ rayHitColor: {r: 1, g: 0, b: 0, a: 1}
+ rayMissColor: {r: 1, g: 1, b: 1, a: 1}
+ useWorldPositions: 1
+ startVerticalOffset: 0
+ endVerticalOffset: 0
--- !u!114 &114850431417842684
MonoBehaviour:
m_ObjectHideFlags: 1
@@ -3741,17 +3944,6 @@ MonoBehaviour:
agentRole: 0
area: {fileID: 114559182131992928}
agentRb: {fileID: 0}
---- !u!114 &114965771318032104
-MonoBehaviour:
- m_ObjectHideFlags: 1
- m_PrefabParentObject: {fileID: 0}
- m_PrefabInternal: {fileID: 100100000}
- m_GameObject: {fileID: 1131626411948014}
- m_Enabled: 1
- m_EditorHideFlags: 0
- m_Script: {fileID: 11500000, guid: bb172294dbbcc408286b156a2c4b553c, type: 3}
- m_Name:
- m_EditorClassIdentifier:
--- !u!135 &135133947297127334
SphereCollider:
m_ObjectHideFlags: 1
diff --git a/UnitySDK/Assets/ML-Agents/Examples/Soccer/Scripts/AgentSoccer.cs b/UnitySDK/Assets/ML-Agents/Examples/Soccer/Scripts/AgentSoccer.cs
index 2c0ec23162..7c579a92a0 100644
--- a/UnitySDK/Assets/ML-Agents/Examples/Soccer/Scripts/AgentSoccer.cs
+++ b/UnitySDK/Assets/ML-Agents/Examples/Soccer/Scripts/AgentSoccer.cs
@@ -3,6 +3,13 @@
public class AgentSoccer : Agent
{
+ // Note that the detectable tags are different for the blue and purple teams. The order is
+ // * ball
+ // * own goal
+ // * opposing goal
+ // * wall
+ // * own teammate
+ // * opposing player
public enum Team
{
Purple,
@@ -24,13 +31,6 @@ public enum AgentRole
public Rigidbody agentRb;
SoccerAcademy m_Academy;
Renderer m_AgentRenderer;
- RayPerception m_RayPer;
-
- float[] m_RayAngles = { 0f, 45f, 90f, 135f, 180f, 110f, 70f };
- string[] m_DetectableObjectsPurple = { "ball", "purpleGoal", "blueGoal",
- "wall", "purpleAgent", "blueAgent" };
- string[] m_DetectableObjectsBlue = { "ball", "blueGoal", "purpleGoal",
- "wall", "blueAgent", "purpleAgent" };
public void ChooseRandomTeam()
{
@@ -65,7 +65,6 @@ public override void InitializeAgent()
{
base.InitializeAgent();
m_AgentRenderer = GetComponentInChildren<Renderer>();
- m_RayPer = GetComponent<RayPerception>();
m_Academy = FindObjectOfType<SoccerAcademy>();
agentRb = GetComponent<Rigidbody>();
agentRb.maxAngularVelocity = 500;
@@ -81,22 +80,6 @@ public override void InitializeAgent()
playerState.playerIndex = m_PlayerIndex;
}
- public override void CollectObservations()
- {
- var rayDistance = 20f;
- string[] detectableObjects;
- if (team == Team.Purple)
- {
- detectableObjects = m_DetectableObjectsPurple;
- }
- else
- {
- detectableObjects = m_DetectableObjectsBlue;
- }
- AddVectorObs(m_RayPer.Perceive(rayDistance, m_RayAngles, detectableObjects, 0f, 0f));
- AddVectorObs(m_RayPer.Perceive(rayDistance, m_RayAngles, detectableObjects, 1f, 0f));
- }
-
public void MoveAgent(float[] act)
{
var dirToGo = Vector3.zero;
@@ -156,7 +139,7 @@ public void MoveAgent(float[] act)
ForceMode.VelocityChange);
}
- public override void AgentAction(float[] vectorAction, string textAction)
+ public override void AgentAction(float[] vectorAction)
{
// Existential penalty for strikers.
if (agentRole == AgentRole.Striker)
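The new comment at the top of AgentSoccer.cs is the key to the prefab's two differently ordered detectableTags lists: each team's sensors list its own goal and teammates before the opponent's, exactly as the removed m_DetectableObjectsPurple / m_DetectableObjectsBlue arrays did. Gathered in one place for reference; the arrays are copied from the removed code, and the size check uses the same slots-per-ray layout noted earlier.

// Copied from the arrays this patch removes; the prefab's detectableTags lists follow the same order.
string[] purpleTags = { "ball", "purpleGoal", "blueGoal", "wall", "purpleAgent", "blueAgent" };
string[] blueTags   = { "ball", "blueGoal", "purpleGoal", "wall", "blueAgent", "purpleAgent" };
// Old vector size: 2 passes x 7 rays x (6 tags + 2) = 112, which the prefab's brains now drop to 0.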
diff --git a/UnitySDK/Assets/ML-Agents/Examples/Template/Scripts/TemplateAgent.cs b/UnitySDK/Assets/ML-Agents/Examples/Template/Scripts/TemplateAgent.cs
index 591d7a7377..8060e88542 100644
--- a/UnitySDK/Assets/ML-Agents/Examples/Template/Scripts/TemplateAgent.cs
+++ b/UnitySDK/Assets/ML-Agents/Examples/Template/Scripts/TemplateAgent.cs
@@ -7,7 +7,7 @@ public override void CollectObservations()
{
}
- public override void AgentAction(float[] vectorAction, string textAction)
+ public override void AgentAction(float[] vectorAction)
{
}
diff --git a/UnitySDK/Assets/ML-Agents/Examples/Tennis/Scripts/HitWall.cs b/UnitySDK/Assets/ML-Agents/Examples/Tennis/Scripts/HitWall.cs
index bf9f03696e..d4914a83bc 100644
--- a/UnitySDK/Assets/ML-Agents/Examples/Tennis/Scripts/HitWall.cs
+++ b/UnitySDK/Assets/ML-Agents/Examples/Tennis/Scripts/HitWall.cs
@@ -5,9 +5,9 @@ public class HitWall : MonoBehaviour
public GameObject areaObject;
public int lastAgentHit;
- private TennisArea m_Area;
- private TennisAgent m_AgentA;
- private TennisAgent m_AgentB;
+ TennisArea m_Area;
+ TennisAgent m_AgentA;
+ TennisAgent m_AgentB;
// Use this for initialization
void Start()
@@ -17,7 +17,7 @@ void Start()
m_AgentB = m_Area.agentB.GetComponent<TennisAgent>();
}
- private void OnTriggerExit(Collider other)
+ void OnTriggerExit(Collider other)
{
if (other.name == "over")
{
@@ -33,7 +33,7 @@ private void OnTriggerExit(Collider other)
}
}
- private void OnCollisionEnter(Collision collision)
+ void OnCollisionEnter(Collision collision)
{
if (collision.gameObject.CompareTag("iWall"))
{
diff --git a/UnitySDK/Assets/ML-Agents/Examples/Tennis/Scripts/TennisAgent.cs b/UnitySDK/Assets/ML-Agents/Examples/Tennis/Scripts/TennisAgent.cs
index 206a5e83fe..bcff20486d 100644
--- a/UnitySDK/Assets/ML-Agents/Examples/Tennis/Scripts/TennisAgent.cs
+++ b/UnitySDK/Assets/ML-Agents/Examples/Tennis/Scripts/TennisAgent.cs
@@ -12,17 +12,17 @@ public class TennisAgent : Agent
public float angle;
public float scale;
- private Text m_TextComponent;
- private Rigidbody m_AgentRb;
- private Rigidbody m_BallRb;
- private float m_InvertMult;
- private ResetParameters m_ResetParams;
+ Text m_TextComponent;
+ Rigidbody m_AgentRb;
+ Rigidbody m_BallRb;
+ float m_InvertMult;
+ ResetParameters m_ResetParams;
// Looks for the scoreboard based on the name of the gameObjects.
// Do not modify the names of the Score GameObjects
- private const string k_CanvasName = "Canvas";
- private const string k_ScoreBoardAName = "ScoreA";
- private const string k_ScoreBoardBName = "ScoreB";
+ const string k_CanvasName = "Canvas";
+ const string k_ScoreBoardAName = "ScoreA";
+ const string k_ScoreBoardBName = "ScoreB";
public override void InitializeAgent()
{
@@ -57,7 +57,7 @@ public override void CollectObservations()
AddVectorObs(m_BallRb.velocity.y);
}
- public override void AgentAction(float[] vectorAction, string textAction)
+ public override void AgentAction(float[] vectorAction)
{
var moveX = Mathf.Clamp(vectorAction[0], -1f, 1f) * m_InvertMult;
var moveY = Mathf.Clamp(vectorAction[1], -1f, 1f);
diff --git a/UnitySDK/Assets/ML-Agents/Examples/Tennis/Scripts/TennisArea.cs b/UnitySDK/Assets/ML-Agents/Examples/Tennis/Scripts/TennisArea.cs
index 2cff0f6317..7526513a1d 100644
--- a/UnitySDK/Assets/ML-Agents/Examples/Tennis/Scripts/TennisArea.cs
+++ b/UnitySDK/Assets/ML-Agents/Examples/Tennis/Scripts/TennisArea.cs
@@ -5,7 +5,7 @@ public class TennisArea : MonoBehaviour
public GameObject ball;
public GameObject agentA;
public GameObject agentB;
- private Rigidbody m_BallRb;
+ Rigidbody m_BallRb;
// Use this for initialization
void Start()
diff --git a/UnitySDK/Assets/ML-Agents/Examples/Tennis/TFModels/Tennis.nn b/UnitySDK/Assets/ML-Agents/Examples/Tennis/TFModels/Tennis.nn
index 53807f719a..59cd524359 100644
Binary files a/UnitySDK/Assets/ML-Agents/Examples/Tennis/TFModels/Tennis.nn and b/UnitySDK/Assets/ML-Agents/Examples/Tennis/TFModels/Tennis.nn differ
diff --git a/UnitySDK/Assets/ML-Agents/Examples/Walker/Scripts/WalkerAgent.cs b/UnitySDK/Assets/ML-Agents/Examples/Walker/Scripts/WalkerAgent.cs
index 231d7eb402..02cd5d4439 100644
--- a/UnitySDK/Assets/ML-Agents/Examples/Walker/Scripts/WalkerAgent.cs
+++ b/UnitySDK/Assets/ML-Agents/Examples/Walker/Scripts/WalkerAgent.cs
@@ -27,11 +27,11 @@ public class WalkerAgent : Agent
bool m_IsNewDecisionStep;
int m_CurrentDecisionStep;
- private Rigidbody m_HipsRb;
- private Rigidbody m_ChestRb;
- private Rigidbody m_SpineRb;
+ Rigidbody m_HipsRb;
+ Rigidbody m_ChestRb;
+ Rigidbody m_SpineRb;
- private ResetParameters m_ResetParams;
+ ResetParameters m_ResetParams;
public override void InitializeAgent()
{
@@ -103,7 +103,7 @@ public override void CollectObservations()
}
}
- public override void AgentAction(float[] vectorAction, string textAction)
+ public override void AgentAction(float[] vectorAction)
{
m_DirToTarget = target.position - m_JdController.bodyPartsDict[hips].rb.position;
diff --git a/UnitySDK/Assets/ML-Agents/Examples/Walker/TFModels/Walker.nn b/UnitySDK/Assets/ML-Agents/Examples/Walker/TFModels/Walker.nn
index 2d115a1d71..9419f7b9ce 100644
Binary files a/UnitySDK/Assets/ML-Agents/Examples/Walker/TFModels/Walker.nn and b/UnitySDK/Assets/ML-Agents/Examples/Walker/TFModels/Walker.nn differ
diff --git a/UnitySDK/Assets/ML-Agents/Examples/WallJump/Prefabs/WallJumpArea.prefab b/UnitySDK/Assets/ML-Agents/Examples/WallJump/Prefabs/WallJumpArea.prefab
index 3521ab9154..3d4d3d8790 100644
--- a/UnitySDK/Assets/ML-Agents/Examples/WallJump/Prefabs/WallJumpArea.prefab
+++ b/UnitySDK/Assets/ML-Agents/Examples/WallJump/Prefabs/WallJumpArea.prefab
@@ -39,7 +39,8 @@ GameObject:
- component: {fileID: 54678503543725326}
- component: {fileID: 114898893333200490}
- component: {fileID: 114925928594762506}
- - component: {fileID: 114092229367912210}
+ - component: {fileID: 114458838850320084}
+ - component: {fileID: 114227939525648256}
m_Layer: 0
m_Name: Agent
m_TagString: agent
@@ -1052,7 +1053,7 @@ BoxCollider:
serializedVersion: 2
m_Size: {x: 1, y: 1, z: 1}
m_Center: {x: 0, y: 0, z: 0}
---- !u!114 &114092229367912210
+--- !u!114 &114227939525648256
MonoBehaviour:
m_ObjectHideFlags: 1
m_PrefabParentObject: {fileID: 0}
@@ -1060,9 +1061,50 @@ MonoBehaviour:
m_GameObject: {fileID: 1195095783991828}
m_Enabled: 1
m_EditorHideFlags: 0
- m_Script: {fileID: 11500000, guid: bb172294dbbcc408286b156a2c4b553c, type: 3}
+ m_Script: {fileID: 11500000, guid: 6bb6b867a41448888c1cd4f99643ad71, type: 3}
m_Name:
m_EditorClassIdentifier:
+ sensorName: OffsetRayPerceptionSensor
+ detectableTags:
+ - wall
+ - goal
+ - block
+ raysPerDirection: 3
+ maxRayDegrees: 90
+ startVerticalOffset: 2.5
+ endVerticalOffset: 5
+ sphereCastRadius: 0.5
+ rayLength: 20
+ observationStacks: 6
+ rayHitColor: {r: 1, g: 0, b: 0, a: 1}
+ rayMissColor: {r: 1, g: 1, b: 1, a: 1}
+ useWorldPositions: 1
+--- !u!114 &114458838850320084
+MonoBehaviour:
+ m_ObjectHideFlags: 1
+ m_PrefabParentObject: {fileID: 0}
+ m_PrefabInternal: {fileID: 100100000}
+ m_GameObject: {fileID: 1195095783991828}
+ m_Enabled: 1
+ m_EditorHideFlags: 0
+ m_Script: {fileID: 11500000, guid: 6bb6b867a41448888c1cd4f99643ad71, type: 3}
+ m_Name:
+ m_EditorClassIdentifier:
+ sensorName: RayPerceptionSensor
+ detectableTags:
+ - wall
+ - goal
+ - block
+ raysPerDirection: 3
+ maxRayDegrees: 90
+ startVerticalOffset: 0
+ endVerticalOffset: 0
+ sphereCastRadius: 0.5
+ rayLength: 20
+ observationStacks: 6
+ rayHitColor: {r: 1, g: 0, b: 0, a: 1}
+ rayMissColor: {r: 1, g: 1, b: 1, a: 1}
+ useWorldPositions: 1
--- !u!114 &114898893333200490
MonoBehaviour:
m_ObjectHideFlags: 1
@@ -1075,12 +1117,12 @@ MonoBehaviour:
m_Name:
m_EditorClassIdentifier:
m_BrainParameters:
- vectorObservationSize: 74
+ vectorObservationSize: 4
numStackedVectorObservations: 6
vectorActionSize: 03000000030000000300000002000000
vectorActionDescriptions: []
vectorActionSpaceType: 0
- m_Model: {fileID: 11400000, guid: fb2ce36eb40b6480e94ea0b5d7573e47, type: 3}
+ m_Model: {fileID: 11400000, guid: 0468bf44b1efd4992b6bf22cadb50d89, type: 3}
m_InferenceDevice: 0
m_UseHeuristic: 0
m_BehaviorName: SmallWallJump
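The drop in `vectorObservationSize` from 74 to 4 in the hunk above is consistent with the rays moving into the two `RayPerceptionSensor` components added to this prefab: assuming the legacy `RayPerception` layout of `(detectableObjects.Length + 2)` floats per ray angle, the two `Perceive()` calls deleted from `WallJumpAgent.cs` below accounted for the other 70 entries. A rough consistency check (the constants are read off the deleted agent code, not off this prefab):

```C#
// Back-of-the-envelope check, not code from this changeset. Assumes the legacy
// RayPerception encoding of (detectableObjects.Length + 2) floats per ray angle.
static class WallJumpObsCheck
{
    const int PerceiveCalls = 2;      // the two deleted Perceive() calls
    const int RayAngles = 7;          // { 0, 45, 90, 135, 180, 110, 70 }
    const int FloatsPerRay = 3 + 2;   // "wall", "goal", "block" tags + 2 extra entries

    public const int LegacyRayObs = PerceiveCalls * RayAngles * FloatsPerRay; // 70
    public const int RemainingVectorObs = 74 - LegacyRayObs;                  // 4
}
```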
diff --git a/UnitySDK/Assets/ML-Agents/Examples/WallJump/Scenes/WallJump.unity b/UnitySDK/Assets/ML-Agents/Examples/WallJump/Scenes/WallJump.unity
index 59461a669b..04a09e7383 100644
--- a/UnitySDK/Assets/ML-Agents/Examples/WallJump/Scenes/WallJump.unity
+++ b/UnitySDK/Assets/ML-Agents/Examples/WallJump/Scenes/WallJump.unity
@@ -38,7 +38,7 @@ RenderSettings:
m_ReflectionIntensity: 1
m_CustomReflection: {fileID: 0}
m_Sun: {fileID: 0}
- m_IndirectSpecularColor: {r: 0.44971442, g: 0.499779, b: 0.5756377, a: 1}
+ m_IndirectSpecularColor: {r: 0.44971484, g: 0.49977958, b: 0.5756385, a: 1}
--- !u!157 &3
LightmapSettings:
m_ObjectHideFlags: 0
diff --git a/UnitySDK/Assets/ML-Agents/Examples/WallJump/Scripts/WallJumpAgent.cs b/UnitySDK/Assets/ML-Agents/Examples/WallJump/Scripts/WallJumpAgent.cs
index 131e6d780f..6b385c142f 100644
--- a/UnitySDK/Assets/ML-Agents/Examples/WallJump/Scripts/WallJumpAgent.cs
+++ b/UnitySDK/Assets/ML-Agents/Examples/WallJump/Scripts/WallJumpAgent.cs
@@ -30,7 +30,6 @@ public class WallJumpAgent : Agent
Material m_GroundMaterial;
Renderer m_GroundRenderer;
WallJumpAcademy m_Academy;
- RayPerception m_RayPer;
public float jumpingTime;
public float jumpTime;
@@ -42,14 +41,10 @@ public class WallJumpAgent : Agent
Vector3 m_JumpTargetPos;
Vector3 m_JumpStartingPos;
- string[] m_DetectableObjects;
-
public override void InitializeAgent()
{
m_Academy = FindObjectOfType<WallJumpAcademy>();
- m_RayPer = GetComponent<RayPerception>();
m_Configuration = Random.Range(0, 5);
- m_DetectableObjects = new[] { "wall", "goal", "block" };
m_AgentRb = GetComponent<Rigidbody>();
m_ShortBlockRb = shortBlock.GetComponent<Rigidbody>();
@@ -139,12 +134,6 @@ void MoveTowards(
public override void CollectObservations()
{
- var rayDistance = 20f;
- float[] rayAngles = { 0f, 45f, 90f, 135f, 180f, 110f, 70f };
- AddVectorObs(m_RayPer.Perceive(
- rayDistance, rayAngles, m_DetectableObjects, 0f, 0f));
- AddVectorObs(m_RayPer.Perceive(
- rayDistance, rayAngles, m_DetectableObjects, 2.5f, 2.5f));
var agentPos = m_AgentRb.position - ground.transform.position;
AddVectorObs(agentPos / 20f);
@@ -233,7 +222,7 @@ public void MoveAgent(float[] act)
jumpingTime -= Time.fixedDeltaTime;
}
- public override void AgentAction(float[] vectorAction, string textAction)
+ public override void AgentAction(float[] vectorAction)
{
MoveAgent(vectorAction);
if ((!Physics.Raycast(m_AgentRb.position, Vector3.down, 20))
@@ -299,7 +288,7 @@ public override void AgentReset()
m_AgentRb.velocity = default(Vector3);
}
- private void FixedUpdate()
+ void FixedUpdate()
{
if (m_Configuration != -1)
{
diff --git a/UnitySDK/Assets/ML-Agents/Examples/WallJump/TFModels/BigWallJump.nn b/UnitySDK/Assets/ML-Agents/Examples/WallJump/TFModels/BigWallJump.nn
index 7113fb3f04..b2f3307f0d 100644
Binary files a/UnitySDK/Assets/ML-Agents/Examples/WallJump/TFModels/BigWallJump.nn and b/UnitySDK/Assets/ML-Agents/Examples/WallJump/TFModels/BigWallJump.nn differ
diff --git a/UnitySDK/Assets/ML-Agents/Examples/WallJump/TFModels/SmallWallJump.nn b/UnitySDK/Assets/ML-Agents/Examples/WallJump/TFModels/SmallWallJump.nn
index 30c34e1294..ffaad860c9 100644
Binary files a/UnitySDK/Assets/ML-Agents/Examples/WallJump/TFModels/SmallWallJump.nn and b/UnitySDK/Assets/ML-Agents/Examples/WallJump/TFModels/SmallWallJump.nn differ
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda.md b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda.md
deleted file mode 100644
index 26c5a638cb..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda.md
+++ /dev/null
@@ -1,257 +0,0 @@
-
-
-# Barracuda
-
-**Barracuda** is a lightweight and **cross-platform** Neural Net **inference library for Unity**. Barracuda can execute both on GPU and CPU. Currently Barracuda is in the early development stage, so adventures are expected.
-
-## Using Barracuda
-Typically the following steps are needed to use Barracuda in application:
-1. load model,
-2. create inference engine (the worker),
-3. execute model and
-4. fetch results.
-
-But first you have to convert your TensorFlow (or ONNX) model to Barracuda format with python scripts. Example usage:
-```bash
-python onnx_to_barracuda.py Models/mnist/model.onnx Destination/mnist.bytes
-```
-See _Converting models to Barracuda_ paragraph below for more information.
-
-### Load Model into Barracuda
-Once you have your TensorFlow (or ONNX) model converted, you can load resulting Barracuda file via `ModelLoader`:
-```C#
-var model = ModelLoader.LoadFromStreamingAssets(modelName + ".nn");
-```
-Another option is to use editor model importer. Just add public `NNModel` field to your C# script and assing ``.nn`` model file via editor UI:
-```C#
-public NNModel modelSource;
-<..>
-var model = ModelLoader.Load(modelSource);
-```
-
-### Create inference engine (Worker)
-Inference engine in Barracuda is called Worker. Worker is responsible for converting model into executable tasks and scheduling them on GPU or CPU.
-```C#
-var worker = BarracudaWorkerFactory.CreateWorker(BarracudaWorkerFactory.Type.ComputePrecompiled, model)
-```
-
-### Execute the model
-Inputs can be provided both as sole `Tensor` object (assuming Model has only one input) or as a dictionary of name and `Tensor` pairs.
-
-```C#
-var inputs = new Dictionary<string, Tensor>();
-inputs[name1] = new Tensor(...);
-inputs[name2] = new Tensor(...);
-worker.Execute(inputs);
-```
-Execution is asynchronous for GPU backends. Currently implementation is synchronous for CPU backends, however it is good to assume that execution will be async for all backends in the future.
-
-### Fetch outputs
-If model has only single output, then simple `worker.Peek()` can be used, otherwise output names should be provided.
-```C#
-var O = worker.Peek(outputName);
-```
-_Note:_ ``Peek()`` does not take ownership of the tensor. If you expect to keep tensor for longer time use ``Fetch()``
-
-### Cleanup
-As a Barracuda client you are responsible to `Dispose` _worker_, _inputs_ and _outputs_ you fetched. This is necessary to properly free GPU resources.
-```C#
-O.Dispose();
-worker.Dispose();
-```
-
-## Working with data
-
-### Tensor
-Barracuda stores data in `batch`,`height`,`width`,`channels` also known as _NHWC_ or _channels-last_ format. You can interact with `Tensor` data via multi-dimensional array operators:
-```C#
-var tensor = new Tensor(batchCount, height, width, channelCount);
-tensor[n, y, x, c] = 1.0f; // as N batches of 3 dimensional data: N x {X, Y, C}
-tensor[n, c] = 2.0f; // as N batches of 1 dimensional data: N x {C}
-tensor[ i] = 3.0f; // as flat array
-```
-
-There are number of `Tensor` constructors that cover variety of scenarios. By default tensors are initialized with `0` upon construction, unless intialization `Array` is provided.
-```C#
-tensor = new Tensor(batchCount, height, width, channelCount); // batch of 3 dimensional data, 0 initialized: batchCount x {height, width, channelCount}
-tensor = new Tensor(batchCount, elementCount); // batch of 1 dimensional data, 0 initialized: batchCount x {elementCount}
-
-var stridedArray = new float[batchCount * elementCount] { ... };
-tensor = new Tensor(batchCount, elementCount, stridedArray); // batch of 1 dimensional data, initialized from strided array
-
-var jaggedArray = new float[batchCount][elementCount] { ... };
-tensor = new Tensor(batchCount, elementCount, jaggedArray); // batch of 1 dimensional data, initialized from jagged array
-
-Texture2D texture = ...;
-tensor = new Tensor(texture); // tensor initialized with texture data: 1 x { texture.width, texture.height, 3}
-```
-
-You can query shape of the `Tensor` object, but you can not change it. Shape of the `Tensor` is immutable. If you want to have different shape of `Tensor`, you have to construct the new instance of `Tensor` object.
-```C#
-var shape = tensor.shape;
-Debug.Log(shape + " or " + shape.batch + shape.height + shape.width + shape.channels);
-```
-
-### Texture as input
-You can directly pass `Texture2D`, `Texture2DArray`, `Texture3D` or `RenderTexture` to Barracuda without accessing individual pixels on CPU:
-```C#
-var channelCount = 3; // you can treat input pixels as 1 (grayscale), 3 (color) or 4 (color with alpha) channels
-var tensor = new Tensor(texture, channelCount);
-```
-You can batch multiple textures into the single `Tensor` object:
-```C#
-var textures = new [] { texture0, texture1, texture2, texture3 }; // these textures will form a batch
-var tensor = new Tensor(textures, channelCount);
-```
-Note that to form a batch all textures must have the same width and height dimensions.
-
-### Texture as output
-If you want to use Barracuda execution results further in the graphics pipeline, you can copy data from `Tensor` into `RenderTexture` without stalling CPU or GPU:
-```C#
- var tensor = worker.Peek();
- var texture = BarracudaTextureUtils.TensorToRenderTexture(tensor);
-```
-If you wish, you can reuse the same `RenderTexture` multiple times:
-```C#
- var texture = new RenderTexture(width, height, 0);
- // ...
- var tensor = worker.Peek();
- BarracudaTextureUtils.TensorToRenderTexture(tensor, texture);
-```
-
-## Introspecting Barracuda models
-Barracuda model has very simple memory representation. Once model is loaded you can query for inputs and outputs:
-```C#
-string[] inputNames = model.inputs; // query model inputs
-string[] outputNames = model.outputs; // query model outputs
-```
-Or you can directly iterate through the layers and investigate what model is going to do:
-```C#
-foreach (var layer in model.layers)
- Debug.Log(layer.name + " does " + layer.type);
-```
-
-## Verbose mode
-You can turn on verbose mode for different parts of Barracuda:
-```C#
-bool verbose = true;
-var model = ModelLoader.LoadFromStreamingAssets(modelName + ".bytes", verbose); // verbose loader
-var worker = BarracudaWorkerFactory.CreateWorker(BarracudaWorkerFactory.Type.ComputeFast, model, verbose); // verbose execution
-```
-
-## Converting TensorFlow and ONNX models to Barracuda format
-Barracuda comes with dedicated python scripts to convert pre-trained TensorFlow and ONNX models to Barracuda format.
-
-Convert from TensorFlow:
-```bash
-python tensorflow_to_barracuda.py Models/3DBall-tf-model.pb Destination/3DBall-bc.nn
-```
-
-Convert from ONNX:
-```bash
-python onnx_to_barracuda.py Models/mnist/model.onnx Destination/mnist-bc.nn
-```
-
-If network has multiple outputs, but you need only particular ones during the inference, there is an optional `-trim` flag to remove unused outputs and calculations.
-For example:
-```bash
-python tensorflow_to_barracuda.py Models/3DBall-tf-model.pb Destination/3DBall-bc.bytes -trim action$
-```
-Trim will first remove outputs that do not match regular expression from the graph. In this case only output that ends with `action` will be left.
-Next trim will strip all nodes that do not participate in the evaluation of the output.
-
-You could pass `--print-supported-ops` to get approximate list of supported operations/activations for specific converter.
-
-## Approximate list of supported layers/operations for TensorFlow converter
-```
-Activation
-Add
-AvgPool
-BatchNormalization
-BatchNormalizationRuntime
-BiasAdd
-Concat
-Conv2D
-Conv2DBackpropInput
-Dense
-DepthwiseConv2dNative
-Flatten
-FusedBatchNorm
-GlobalAveragePool
-GlobalAvgPool
-InstanceNormalization
-LRN
-MatMul
-Max
-MaxPool
-Maximum
-Mean
-Min
-Minimum
-Mul
-Multinomial
-Nop
-OneHot
-Pad
-Pow
-Prod
-RandomStandardNormal
-RandomUniform
-RealDiv
-Reshape
-ResizeBicubic
-ResizeBilinear
-ResizeNearestNeighbor
-StridedSlice
-Sub
-Sum
-
-```
-
-## Approximate list of supported activations for TensorFlow converter
-```
-Abs
-Acos
-Acosh
-Asin
-Asinh
-Atan
-Atanh
-Ceil
-Cos
-Cosh
-Elu
-Exp
-Floor
-LeakyRelu
-Linear
-Log
-LogSoftmax
-Neg
-Relu
-Relu6
-Selu
-Sigmoid
-Sin
-Sinh
-Softmax
-Softplus
-Softsign
-Sqrt
-Swish
-Tan
-Tanh
-```
-
-P.S. some of these operations are under limited support and not all configurations are properly supported
-
-P.P.S. Python 3.5 or 3.6 is recommended
-
-P.P.P.S. We plan to migrate Tensorflow and ONNX converters from Python to C# in the future.
-
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda.md.meta b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda.md.meta
deleted file mode 100644
index 4a967c3801..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda.md.meta
+++ /dev/null
@@ -1,7 +0,0 @@
-fileFormatVersion: 2
-guid: 3cf2bcd7dcfe144bebf6cf271e7dfbe0
-TextScriptImporter:
- externalObjects: {}
- userData:
- assetBundleName:
- assetBundleVariant:
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda.meta b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda.meta
deleted file mode 100644
index c142006d60..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda.meta
+++ /dev/null
@@ -1,8 +0,0 @@
-fileFormatVersion: 2
-guid: 4d59cec597ba94288831c0cade38b14e
-folderAsset: yes
-DefaultImporter:
- externalObjects: {}
- userData:
- assetBundleName:
- assetBundleVariant:
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Barracuda.dll b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Barracuda.dll
deleted file mode 100644
index a9f15d0a01..0000000000
Binary files a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Barracuda.dll and /dev/null differ
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Barracuda.dll.meta b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Barracuda.dll.meta
deleted file mode 100644
index 3e4f56da40..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Barracuda.dll.meta
+++ /dev/null
@@ -1,30 +0,0 @@
-fileFormatVersion: 2
-guid: de59cc66e5e394f93b2a692e50bce97f
-PluginImporter:
- externalObjects: {}
- serializedVersion: 2
- iconMap: {}
- executionOrder: {}
- isPreloaded: 0
- isOverridable: 0
- platformData:
- - first:
- Any:
- second:
- enabled: 1
- settings: {}
- - first:
- Editor: Editor
- second:
- enabled: 0
- settings:
- DefaultValueInitialized: true
- - first:
- Windows Store Apps: WindowsStoreApps
- second:
- enabled: 0
- settings:
- CPU: AnyCPU
- userData:
- assetBundleName:
- assetBundleVariant:
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins.meta b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins.meta
deleted file mode 100644
index d253192dc0..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins.meta
+++ /dev/null
@@ -1,8 +0,0 @@
-fileFormatVersion: 2
-guid: a7bba248e968b476a875260a8127a595
-folderAsset: yes
-DefaultImporter:
- externalObjects: {}
- userData:
- assetBundleName:
- assetBundleVariant:
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/Editor.meta b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/Editor.meta
deleted file mode 100644
index 5e8284ace0..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/Editor.meta
+++ /dev/null
@@ -1,8 +0,0 @@
-fileFormatVersion: 2
-guid: 4b10c58689ee84c2abe895327686f532
-folderAsset: yes
-DefaultImporter:
- externalObjects: {}
- userData:
- assetBundleName:
- assetBundleVariant:
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/Editor/BarracudaEditor.meta b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/Editor/BarracudaEditor.meta
deleted file mode 100644
index 4b0693ff39..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/Editor/BarracudaEditor.meta
+++ /dev/null
@@ -1,8 +0,0 @@
-fileFormatVersion: 2
-guid: e192a80b369ad4683a329432eeb5ec20
-folderAsset: yes
-DefaultImporter:
- externalObjects: {}
- userData:
- assetBundleName:
- assetBundleVariant:
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/Editor/BarracudaEditor/Barracuda-editor.asmdef b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/Editor/BarracudaEditor/Barracuda-editor.asmdef
deleted file mode 100644
index b10599b462..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/Editor/BarracudaEditor/Barracuda-editor.asmdef
+++ /dev/null
@@ -1,8 +0,0 @@
-{
- "name": "Barracuda-editor",
- "references": [],
- "includePlatforms": [
- "Editor"
- ],
- "excludePlatforms": []
-}
\ No newline at end of file
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/Editor/BarracudaEditor/Barracuda-editor.asmdef.meta b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/Editor/BarracudaEditor/Barracuda-editor.asmdef.meta
deleted file mode 100644
index 7f0c301a87..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/Editor/BarracudaEditor/Barracuda-editor.asmdef.meta
+++ /dev/null
@@ -1,7 +0,0 @@
-fileFormatVersion: 2
-guid: 9f1e7d835703842dda0e25142ed6c3c9
-AssemblyDefinitionImporter:
- externalObjects: {}
- userData:
- assetBundleName:
- assetBundleVariant:
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/Editor/BarracudaEditor/NNModelIcon.png b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/Editor/BarracudaEditor/NNModelIcon.png
deleted file mode 100644
index 10434c2792..0000000000
Binary files a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/Editor/BarracudaEditor/NNModelIcon.png and /dev/null differ
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/Editor/BarracudaEditor/NNModelIcon.png.meta b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/Editor/BarracudaEditor/NNModelIcon.png.meta
deleted file mode 100644
index 9a88c6d197..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/Editor/BarracudaEditor/NNModelIcon.png.meta
+++ /dev/null
@@ -1,106 +0,0 @@
-fileFormatVersion: 2
-guid: 8682ff569c4c7457a8a8e3a527aad537
-TextureImporter:
- fileIDToRecycleName: {}
- externalObjects: {}
- serializedVersion: 4
- mipmaps:
- mipMapMode: 0
- enableMipMap: 0
- sRGBTexture: 0
- linearTexture: 0
- fadeOut: 0
- borderMipMap: 0
- mipMapsPreserveCoverage: 0
- alphaTestReferenceValue: 0.5
- mipMapFadeDistanceStart: 1
- mipMapFadeDistanceEnd: 3
- bumpmap:
- convertToNormalMap: 0
- externalNormalMap: 0
- heightScale: 0.25
- normalMapFilter: 0
- isReadable: 0
- grayScaleToAlpha: 0
- generateCubemap: 6
- cubemapConvolution: 0
- seamlessCubemap: 0
- textureFormat: 1
- maxTextureSize: 2048
- textureSettings:
- serializedVersion: 2
- filterMode: -1
- aniso: 1
- mipBias: -1
- wrapU: 1
- wrapV: 1
- wrapW: -1
- nPOTScale: 0
- lightmap: 0
- compressionQuality: 50
- spriteMode: 0
- spriteExtrude: 1
- spriteMeshType: 1
- alignment: 0
- spritePivot: {x: 0.5, y: 0.5}
- spritePixelsToUnits: 100
- spriteBorder: {x: 0, y: 0, z: 0, w: 0}
- spriteGenerateFallbackPhysicsShape: 1
- alphaUsage: 1
- alphaIsTransparency: 1
- spriteTessellationDetail: -1
- textureType: 2
- textureShape: 1
- maxTextureSizeSet: 0
- compressionQualitySet: 0
- textureFormatSet: 0
- platformSettings:
- - buildTarget: DefaultTexturePlatform
- maxTextureSize: 2048
- resizeAlgorithm: 0
- textureFormat: -1
- textureCompression: 1
- compressionQuality: 50
- crunchedCompression: 0
- allowsAlphaSplitting: 0
- overridden: 0
- androidETC2FallbackOverride: 0
- - buildTarget: Standalone
- maxTextureSize: 2048
- resizeAlgorithm: 0
- textureFormat: -1
- textureCompression: 1
- compressionQuality: 50
- crunchedCompression: 0
- allowsAlphaSplitting: 0
- overridden: 0
- androidETC2FallbackOverride: 0
- - buildTarget: iPhone
- maxTextureSize: 2048
- resizeAlgorithm: 0
- textureFormat: -1
- textureCompression: 1
- compressionQuality: 50
- crunchedCompression: 0
- allowsAlphaSplitting: 0
- overridden: 0
- androidETC2FallbackOverride: 0
- - buildTarget: Android
- maxTextureSize: 2048
- resizeAlgorithm: 0
- textureFormat: -1
- textureCompression: 1
- compressionQuality: 50
- crunchedCompression: 0
- allowsAlphaSplitting: 0
- overridden: 0
- androidETC2FallbackOverride: 0
- spriteSheet:
- serializedVersion: 2
- sprites: []
- outline: []
- physicsShape: []
- spritePackingTag:
- userData:
- assetBundleName:
- assetBundleVariant:
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/Editor/BarracudaEditor/NNModelImporter.cs b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/Editor/BarracudaEditor/NNModelImporter.cs
deleted file mode 100644
index e6317a9232..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/Editor/BarracudaEditor/NNModelImporter.cs
+++ /dev/null
@@ -1,42 +0,0 @@
-using System.IO;
-using UnityEditor;
-using UnityEngine;
-using UnityEditor.Experimental.AssetImporters;
-
-namespace Barracuda
-{
- /// <summary>
- /// Asset Importer of barracuda models.
- /// </summary>
- [ScriptedImporter(1, new[] {"nn"})]
- public class NNModelImporter : ScriptedImporter
- {
- private const string k_IconName = "NNModelIcon";
-
- private Texture2D m_IconTexture;
-
- public override void OnImportAsset(AssetImportContext ctx)
- {
- var model = File.ReadAllBytes(ctx.assetPath);
- var asset = ScriptableObject.CreateInstance<NNModel>();
- asset.Value = model;
-
- ctx.AddObjectToAsset("main obj", asset, LoadIconTexture());
- ctx.SetMainObject(asset);
- }
-
- private Texture2D LoadIconTexture()
- {
- if (m_IconTexture == null)
- {
- var allCandidates = AssetDatabase.FindAssets(k_IconName);
-
- if (allCandidates.Length > 0)
- {
- m_IconTexture = AssetDatabase.LoadAssetAtPath(AssetDatabase.GUIDToAssetPath(allCandidates[0]), typeof(Texture2D)) as Texture2D;
- }
- }
- return m_IconTexture;
- }
- }
-}
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/OSX.meta b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/OSX.meta
deleted file mode 100644
index ecc28271e7..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/OSX.meta
+++ /dev/null
@@ -1,8 +0,0 @@
-fileFormatVersion: 2
-guid: 5087a463bec2b4b76808e7307a94887f
-folderAsset: yes
-DefaultImporter:
- externalObjects: {}
- userData:
- assetBundleName:
- assetBundleVariant:
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/MacBLAS.asmdef b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/MacBLAS.asmdef
deleted file mode 100644
index 9d6f291afe..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/MacBLAS.asmdef
+++ /dev/null
@@ -1,11 +0,0 @@
-{
- "name": "MacBLAS",
- "references": [],
- "optionalUnityReferences": [],
- "includePlatforms": [
- "Editor",
- "macOSStandalone"
- ],
- "excludePlatforms": [],
- "allowUnsafeCode": true
-}
\ No newline at end of file
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/MacBLAS.asmdef.meta b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/MacBLAS.asmdef.meta
deleted file mode 100644
index 4a3cefc87a..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/MacBLAS.asmdef.meta
+++ /dev/null
@@ -1,7 +0,0 @@
-fileFormatVersion: 2
-guid: 53fc9961397934ed38a573ce1392c80c
-AssemblyDefinitionImporter:
- externalObjects: {}
- userData:
- assetBundleName:
- assetBundleVariant:
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/MacBLAS.cs b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/MacBLAS.cs
deleted file mode 100644
index fdd22b3647..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/MacBLAS.cs
+++ /dev/null
@@ -1,29 +0,0 @@
-#if UNITY_STANDALONE_OSX || UNITY_EDITOR_OSX
-using System.Runtime.InteropServices;
-using Barracuda;
-using UnityEngine;
-using UnityEngine.Scripting;
-
-
-[Preserve]
-public class MacBLAS : BLASPlugin
-{
- [DllImport("macblas")]
- static extern unsafe void macsgemm(float* ap, int an, int am,
- float* bp, int bn, int bm,
- float* cp, int cn, int cm,
- int bs, bool transposeA, bool transposeB);
-
- public bool IsCurrentPlatformSupported()
- {
- return Application.platform == RuntimePlatform.OSXEditor ||
- Application.platform == RuntimePlatform.OSXPlayer;
- }
-
- public unsafe void SGEMM(float* ap, int an, int am, float* bp, int bn, int bm, float* cp, int cn, int cm, int bs,
- bool transposeA = false, bool transposeB = false)
- {
- macsgemm(ap, an, am, bp, bn, bm, cp, cn, cm, bs, transposeA, transposeB);
- }
-}
-#endif // UNITY_OSX
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle.meta b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle.meta
deleted file mode 100644
index c73e210085..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle.meta
+++ /dev/null
@@ -1,40 +0,0 @@
-fileFormatVersion: 2
-guid: 6633afded85ec4f00a4cc653053461bb
-folderAsset: yes
-PluginImporter:
- externalObjects: {}
- serializedVersion: 2
- iconMap: {}
- executionOrder: {}
- isPreloaded: 0
- isOverridable: 0
- platformData:
- - first:
- '': OSXIntel
- second:
- enabled: 1
- settings: {}
- - first:
- '': OSXIntel64
- second:
- enabled: 1
- settings: {}
- - first:
- Any:
- second:
- enabled: 0
- settings: {}
- - first:
- Editor: Editor
- second:
- enabled: 1
- settings:
- DefaultValueInitialized: true
- - first:
- Standalone: OSXUniversal
- second:
- enabled: 1
- settings: {}
- userData:
- assetBundleName:
- assetBundleVariant:
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle/Contents.meta b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle/Contents.meta
deleted file mode 100644
index a0a3fd804a..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle/Contents.meta
+++ /dev/null
@@ -1,8 +0,0 @@
-fileFormatVersion: 2
-guid: 5de42c62131964fc999e1dc3d292cc31
-folderAsset: yes
-DefaultImporter:
- externalObjects: {}
- userData:
- assetBundleName:
- assetBundleVariant:
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle/Contents/Info.plist b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle/Contents/Info.plist
deleted file mode 100644
index 22d6943a39..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle/Contents/Info.plist
+++ /dev/null
@@ -1,40 +0,0 @@
-<?xml version="1.0" encoding="UTF-8"?>
-<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
-<plist version="1.0">
-<dict>
-    <key>BuildMachineOSBuild</key>
-    <string>14F27</string>
-    <key>CFBundleDevelopmentRegion</key>
-    <string>en</string>
-    <key>CFBundleExecutable</key>
-    <string>macblas</string>
-    <key>CFBundleIdentifier</key>
-    <string>com.unity3d.macblas</string>
-    <key>CFBundleInfoDictionaryVersion</key>
-    <string>6.0</string>
-    <key>CFBundleName</key>
-    <string>macblas</string>
-    <key>CFBundlePackageType</key>
-    <string>BNDL</string>
-    <key>CFBundleShortVersionString</key>
-    <string>0.1.4</string>
-    <key>CFBundleVersion</key>
-    <string>1</string>
-    <key>DTCompiler</key>
-    <string>com.apple.compilers.llvm.clang.1_0</string>
-    <key>DTPlatformBuild</key>
-    <string>6A1052d</string>
-    <key>DTPlatformVersion</key>
-    <string>GM</string>
-    <key>DTSDKBuild</key>
-    <string>14A382</string>
-    <key>DTSDKName</key>
-    <string>macosx10.10</string>
-    <key>DTXcode</key>
-    <string>0610</string>
-    <key>DTXcodeBuild</key>
-    <string>6A1052d</string>
-    <key>NSHumanReadableCopyright</key>
-    <string>Copyright © 2018 Unity Technologies. All rights reserved.</string>
-</dict>
-</plist>
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle/Contents/Info.plist.meta b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle/Contents/Info.plist.meta
deleted file mode 100644
index 2a9aa9e42d..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle/Contents/Info.plist.meta
+++ /dev/null
@@ -1,7 +0,0 @@
-fileFormatVersion: 2
-guid: 844f003f25d444aafad9fb1fcea17bbc
-DefaultImporter:
- externalObjects: {}
- userData:
- assetBundleName:
- assetBundleVariant:
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle/Contents/MacOS.meta b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle/Contents/MacOS.meta
deleted file mode 100644
index dc277cfa8d..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle/Contents/MacOS.meta
+++ /dev/null
@@ -1,8 +0,0 @@
-fileFormatVersion: 2
-guid: 0620b207d80004fe595413acf79f2f66
-folderAsset: yes
-DefaultImporter:
- externalObjects: {}
- userData:
- assetBundleName:
- assetBundleVariant:
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle/Contents/MacOS/macblas b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle/Contents/MacOS/macblas
deleted file mode 100755
index e3f52632bb..0000000000
Binary files a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle/Contents/MacOS/macblas and /dev/null differ
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle/Contents/MacOS/macblas.meta b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle/Contents/MacOS/macblas.meta
deleted file mode 100644
index 7077e86696..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle/Contents/MacOS/macblas.meta
+++ /dev/null
@@ -1,7 +0,0 @@
-fileFormatVersion: 2
-guid: e9ef2c9e25cad478aa1220d6cf68a2ed
-DefaultImporter:
- externalObjects: {}
- userData:
- assetBundleName:
- assetBundleVariant:
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle/Contents/_CodeSignature.meta b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle/Contents/_CodeSignature.meta
deleted file mode 100644
index 2a52881cc5..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle/Contents/_CodeSignature.meta
+++ /dev/null
@@ -1,8 +0,0 @@
-fileFormatVersion: 2
-guid: 93038b433855548879a151644d2354c1
-folderAsset: yes
-DefaultImporter:
- externalObjects: {}
- userData:
- assetBundleName:
- assetBundleVariant:
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle/Contents/_CodeSignature/CodeResources b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle/Contents/_CodeSignature/CodeResources
deleted file mode 100644
index 0710b40083..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle/Contents/_CodeSignature/CodeResources
+++ /dev/null
@@ -1,105 +0,0 @@
-
-
-
-
- files
-
- files2
-
- rules
-
- ^Resources/
-
- ^Resources/.*\.lproj/
-
- optional
-
- weight
- 1000
-
- ^Resources/.*\.lproj/locversion.plist$
-
- omit
-
- weight
- 1100
-
- ^version.plist$
-
-
- rules2
-
- .*\.dSYM($|/)
-
- weight
- 11
-
- ^(.*/)?\.DS_Store$
-
- omit
-
- weight
- 2000
-
- ^(Frameworks|SharedFrameworks|PlugIns|Plug-ins|XPCServices|Helpers|MacOS|Library/(Automator|Spotlight|LoginItems))/
-
- nested
-
- weight
- 10
-
- ^.*
-
- ^Info\.plist$
-
- omit
-
- weight
- 20
-
- ^PkgInfo$
-
- omit
-
- weight
- 20
-
- ^Resources/
-
- weight
- 20
-
- ^Resources/.*\.lproj/
-
- optional
-
- weight
- 1000
-
- ^Resources/.*\.lproj/locversion.plist$
-
- omit
-
- weight
- 1100
-
- ^[^/]+$
-
- nested
-
- weight
- 10
-
- ^embedded\.provisionprofile$
-
- weight
- 20
-
- ^version\.plist$
-
- weight
- 20
-
-
-
-
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle/Contents/_CodeSignature/CodeResources.meta b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle/Contents/_CodeSignature/CodeResources.meta
deleted file mode 100644
index 87c151ef59..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/macblas.bundle/Contents/_CodeSignature/CodeResources.meta
+++ /dev/null
@@ -1,7 +0,0 @@
-fileFormatVersion: 2
-guid: 523ab7e7760c743a9977ecfedabe1691
-DefaultImporter:
- externalObjects: {}
- userData:
- assetBundleName:
- assetBundleVariant:
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/iOS.meta b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/iOS.meta
deleted file mode 100644
index 0d588e91b6..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/iOS.meta
+++ /dev/null
@@ -1,8 +0,0 @@
-fileFormatVersion: 2
-guid: 256085e1b062345239f3d7d88741f96c
-folderAsset: yes
-DefaultImporter:
- externalObjects: {}
- userData:
- assetBundleName:
- assetBundleVariant:
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/iOS/iOSBLAS.asmdef b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/iOS/iOSBLAS.asmdef
deleted file mode 100644
index ba5816655e..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/iOS/iOSBLAS.asmdef
+++ /dev/null
@@ -1,11 +0,0 @@
-{
- "name": "iOSBLAS",
- "references": [],
- "optionalUnityReferences": [],
- "includePlatforms": [
- "Editor",
- "iOS"
- ],
- "excludePlatforms": [],
- "allowUnsafeCode": true
-}
\ No newline at end of file
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/iOS/iOSBLAS.asmdef.meta b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/iOS/iOSBLAS.asmdef.meta
deleted file mode 100644
index 5b93d7691a..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/iOS/iOSBLAS.asmdef.meta
+++ /dev/null
@@ -1,7 +0,0 @@
-fileFormatVersion: 2
-guid: 005937e819cd540429ad05eabcfb642f
-AssemblyDefinitionImporter:
- externalObjects: {}
- userData:
- assetBundleName:
- assetBundleVariant:
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/iOS/iOSBLAS.cs b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/iOS/iOSBLAS.cs
deleted file mode 100644
index f6f66f2995..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/iOS/iOSBLAS.cs
+++ /dev/null
@@ -1,27 +0,0 @@
-#if UNITY_IOS
-using System.Runtime.InteropServices;
-using Barracuda;
-using UnityEngine;
-using UnityEngine.Scripting;
-
-[Preserve]
-public class iOSBLAS : BLASPlugin
-{
- [DllImport("__Internal")]
- static extern unsafe void iossgemm(float* Ap, int AN, int AM,
- float* Bp, int BN, int BM,
- float* Cp, int CN, int CM,
- int bs, bool transposeA, bool transposeB);
-
- public bool IsCurrentPlatformSupported()
- {
- return Application.platform == RuntimePlatform.IPhonePlayer;
- }
-
- public unsafe void SGEMM(float* Ap, int AN, int AM, float* Bp, int BN, int BM, float* Cp, int CN, int CM, int bs,
- bool transposeA = false, bool transposeB = false)
- {
- iossgemm(Ap, AN, AM, Bp, BN, BM, Cp, CN, CM, bs, transposeA, transposeB);
- }
-}
-#endif // UNITY_IOS
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/iOS/iOSBLAS.mm b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/iOS/iOSBLAS.mm
deleted file mode 100644
index 15cbe6c76d..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/iOS/iOSBLAS.mm
+++ /dev/null
@@ -1,15 +0,0 @@
-#import <Accelerate/Accelerate.h>
-
-extern "C"
-{
-void iossgemm(float* Ap, int AN, int AM,
- float* Bp, int BN, int BM,
- float* Cp, int CN, int CM,
- int bs, bool transposeA, bool transposeB)
- {
- cblas_sgemm(CblasRowMajor, transposeA ? CblasTrans : CblasNoTrans,
- transposeB ? CblasTrans : CblasNoTrans,
- AN, BM, BN, 1.0f, Ap, AM, Bp, BM, 1.0f, Cp, CM);
- }
-
-}
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/iOS/iOSBLAS.mm.meta b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/iOS/iOSBLAS.mm.meta
deleted file mode 100644
index 2fa3f6de9b..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/iOS/iOSBLAS.mm.meta
+++ /dev/null
@@ -1,102 +0,0 @@
-fileFormatVersion: 2
-guid: 100b08f95d9f349118f287b0170140d4
-PluginImporter:
- externalObjects: {}
- serializedVersion: 2
- iconMap: {}
- executionOrder: {}
- isPreloaded: 0
- isOverridable: 0
- platformData:
- - first:
- '': Any
- second:
- enabled: 0
- settings:
- Exclude Android: 1
- Exclude Editor: 1
- Exclude Linux: 1
- Exclude Linux64: 1
- Exclude LinuxUniversal: 1
- Exclude OSXUniversal: 1
- Exclude WebGL: 1
- Exclude Win: 1
- Exclude Win64: 1
- Exclude iOS: 0
- - first:
- Android: Android
- second:
- enabled: 0
- settings:
- CPU: ARMv7
- - first:
- Any:
- second:
- enabled: 0
- settings: {}
- - first:
- Editor: Editor
- second:
- enabled: 0
- settings:
- CPU: AnyCPU
- DefaultValueInitialized: true
- OS: AnyOS
- - first:
- Facebook: Win
- second:
- enabled: 0
- settings:
- CPU: AnyCPU
- - first:
- Facebook: Win64
- second:
- enabled: 0
- settings:
- CPU: AnyCPU
- - first:
- Standalone: Linux
- second:
- enabled: 0
- settings:
- CPU: x86
- - first:
- Standalone: Linux64
- second:
- enabled: 0
- settings:
- CPU: x86_64
- - first:
- Standalone: OSXUniversal
- second:
- enabled: 0
- settings:
- CPU: AnyCPU
- - first:
- Standalone: Win
- second:
- enabled: 0
- settings:
- CPU: AnyCPU
- - first:
- Standalone: Win64
- second:
- enabled: 0
- settings:
- CPU: AnyCPU
- - first:
- iPhone: iOS
- second:
- enabled: 1
- settings:
- AddToEmbeddedBinaries: false
- CompileFlags:
- FrameworkDependencies: Accelerate;
- - first:
- tvOS: tvOS
- second:
- enabled: 1
- settings: {}
- userData:
- assetBundleName:
- assetBundleVariant:
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources.meta b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources.meta
deleted file mode 100644
index da72593ca5..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources.meta
+++ /dev/null
@@ -1,8 +0,0 @@
-fileFormatVersion: 2
-guid: 264a957219ea041c58af860601fe1881
-folderAsset: yes
-DefaultImporter:
- externalObjects: {}
- userData:
- assetBundleName:
- assetBundleVariant:
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/Activation.compute b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/Activation.compute
deleted file mode 100644
index 35fc553fad..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/Activation.compute
+++ /dev/null
@@ -1,843 +0,0 @@
-#pragma kernel Relu_Flat
-#pragma kernel Relu_Loop
-#pragma kernel Relu6_Flat
-#pragma kernel Relu6_Loop
-#pragma kernel Tanh_Flat
-#pragma kernel Tanh_Loop
-#pragma kernel Swish_Flat
-#pragma kernel Swish_Loop
-#pragma kernel Sigmoid_Flat
-#pragma kernel Sigmoid_Loop
-#pragma kernel Elu_Flat
-#pragma kernel Elu_Loop
-#pragma kernel LeakyRelu_Flat
-#pragma kernel LeakyRelu_Loop
-#pragma kernel Exp_Flat
-#pragma kernel Exp_Loop
-#pragma kernel Log_Flat
-#pragma kernel Log_Loop
-#pragma kernel Pow_Flat
-#pragma kernel Pow_Loop
-
-/*
-Relu_Flat (NEW) vs Relu_Nyxc+Relu_CNyx+Relu
-Compute Precompiled
-
-VGG@1
-*/
-
-#define FLAT_ACTIVATION(name, op_name) \
-void name##_Flat (uint3 dispatchThreadID : SV_DispatchThreadID)\
-{\
- DISPATCH_ARGS(O.length, 1, 1)\
- TENSOR_ARGS2(X, O);\
-\
- uint i = dispatchThreadID.x;\
- if (i >= O.GetLength()) return;\
-\
- float v = X.Get(i);\
- v = op_name (v);\
- O.Set(i, v);\
-}
-
-#define LOOP_ACTIVATION(name, op_name) \
-void name##_Loop (uint3 dispatchThreadID : SV_DispatchThreadID)\
-{\
- DISPATCH_ARGS(O.length, 1, 1)\
- TENSOR_ARGS2(X, O);\
-\
- uint i = dispatchThreadID.x;\
- uint len = O.GetLength();\
-\
- while (i < len) {\
- float v = X.Get(i); \
- v = op_name (v); \
- O.Set(i, v); \
- i += _LoopStride; \
- }\
-}
-
-#define ACTIVATION(name, op_name) \
-NUMTHREADS((512,1,1), (128,1,1), (64,1,1))\
-FLAT_ACTIVATION(name, op_name)\
-NUMTHREADS((512,1,1), (128,1,1), (64,1,1))\
-LOOP_ACTIVATION(name, op_name)
-
-float relu(float v)
-{
- return 0.5f * (v + abs(v));
-}
-
-float relu6(float v)
-{
- return min(max(0, v), 6);
-}
-
-float swish(float v)
-{
- return v / (1.f + exp(-v));
-}
-
-float sigmoid(float v)
-{
- return 1.f / (1.f + exp(-v));
-}
-
-float elu(float v)
-{
- if (v <= 0)
- v = _Alpha * (exp(v) - 1);
- return v;
-}
-
-float lrelu(float v)
-{
- return max(v, _Alpha * v);
-}
-
-float signed_pow(float f)
-{
- float e = _Alpha;
-
- // handle negative f
- float v = pow(abs(f), e);
- float s = (e % 2 == 1) ?
- sign(f): // exponent is odd => sign(f) * pow(abs(f), e)
- 1; // exponent is even => pow(abs(f), e)
- return v * s;
-}
-
-ACTIVATION(Relu, relu)
-ACTIVATION(Relu6, relu6)
-ACTIVATION(Tanh, tanh)
-ACTIVATION(Sigmoid, sigmoid)
-ACTIVATION(Swish, swish)
-ACTIVATION(Elu, elu)
-ACTIVATION(LeakyRelu, lrelu)
-ACTIVATION(Exp, exp)
-ACTIVATION(Log, log)
-ACTIVATION(Pow, signed_pow)
-
-// -------------------
-
-NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
-void Relu(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, O.width, O.height);
- TENSOR_ARGS2(X, O);
-
- uint c = dispatchThreadID.x;
- uint x = dispatchThreadID.y;
- uint y = dispatchThreadID.z;
-
- if (c >= O.channels) return;
- if (x >= O.width) return;
- if (y >= O.height) return;
-
- for (uint n = 0; n < X.batch; ++n)
- {
- float v = X.Get(n, y, x, c);
- v = relu(v);
- O.Set(n, y, x, c, v);
- }
-}
-
-NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
-void Relu6(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, O.width, O.height);
- TENSOR_ARGS2(X, O);
-
- uint c = dispatchThreadID.x;
- uint x = dispatchThreadID.y;
- uint y = dispatchThreadID.z;
-
- if (c >= O.channels) return;
- if (x >= O.width) return;
- if (y >= O.height) return;
-
- for (uint n = 0; n < X.batch; ++n)
- {
- float v = X.Get(n, y, x, c);
- v = relu6(v);
- O.Set(n, y, x, c, v);
- }
-}
-
-NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
-void Tanh(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, O.width, O.height);
- TENSOR_ARGS2(X, O);
-
- uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
- if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
-
- for (uint n = 0; n < X.batch; ++n)
- {
- float v = X.Get(n, y, x, c);
- v = tanh(v);
- O.Set(n, y, x, c, v);
- }
-}
-
-NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
- void Sigmoid(uint3 dispatchThreadID : SV_DispatchThreadID)
- {
- DISPATCH_ARGS(O.channels, O.width, O.height);
- TENSOR_ARGS2(X, O);
-
- uint c = dispatchThreadID.x;
- uint x = dispatchThreadID.y;
- uint y = dispatchThreadID.z;
-
- if (c >= O.channels) return;
- if (x >= O.width) return;
- if (y >= O.height) return;
-
- for (uint n = 0; n < X.batch; ++n)
- {
- float v = X.Get(n, y, x, c);
- v = sigmoid(v);
- O.Set(n, y, x, c, v);
- }
- }
-
- NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
-void Swish(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, O.width, O.height);
- TENSOR_ARGS2(X, O);
-
- uint c = dispatchThreadID.x;
- uint x = dispatchThreadID.y;
- uint y = dispatchThreadID.z;
-
- if (c >= O.channels) return;
- if (x >= O.width) return;
- if (y >= O.height) return;
-
- for (uint n = 0; n < X.batch; ++n)
- {
- float v = X.Get(n, y, x, c);
- v = swish(v);
- O.Set(n, y, x, c, v);
- }
-}
-
-NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
-void Elu(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, O.width, O.height);
- TENSOR_ARGS2(X, O);
-
- uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
- if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
-
- for (uint n = 0; n < X.batch; ++n)
- {
- float v = X.Get(n, y, x, c);
- v = elu(v);
- O.Set(n, y, x, c, v);
- }
-}
-
-NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
-void LeakyRelu(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, O.width, O.height);
- TENSOR_ARGS2(X, O);
-
- uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
- if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
-
- for (uint n = 0; n < X.batch; ++n)
- {
- float v = X.Get(n, y, x, c);
- v = lrelu(v);
- O.Set(n, y, x, c, v);
- }
-}
-
-NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
-void Exp(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, O.width, O.height);
- TENSOR_ARGS2(X, O);
-
- uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
- if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
-
- for (uint n = 0; n < X.batch; ++n)
- {
- float v = X.Get(n, y, x, c);
- v = exp(v);
- O.Set(n, y, x, c, v);
- }
-}
-
-NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
-void Log(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, O.width, O.height);
- TENSOR_ARGS2(X, O);
-
- uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
- if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
-
- for (uint n = 0; n < X.batch; ++n)
- {
- float v = X.Get(n, y, x, c);
- v = log(v);
- O.Set(n, y, x, c, v);
- }
-}
-
-NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
-void Pow(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, O.width, O.height);
- TENSOR_ARGS2(X, O);
-
- uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
- if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
-
- for (uint n = 0; n < X.batch; ++n)
- {
- float v = X.Get(n, y, x, c);
- v = signed_pow(v);
- O.Set(n, y, x, c, v);
- }
-}
-
-NUMTHREADS((16,16,1), (16,8,1), (16,4,1))
-void Relu_CNyx(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, O.batch * O.height * O.width, 1);
- TENSOR_ARGS2(X, O);
-
- uint c = dispatchThreadID.x;
- uint nyx = dispatchThreadID.y;
-
- uint x = nyx % X.width;
- uint ny = nyx / X.width;
- uint y = ny % X.height;
- uint n = ny / X.height;
-
- if (c >= X.channels) return;
- if (n >= X.batch) return;
-
- float v = X.Get(n, y, x, c);
- v = relu(v);
- O.Set(n, y, x, c, v);
-}
-
-NUMTHREADS((512,1,1), (128,1,1), (64,1,1))
-void Relu_Nyxc(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.batch * O.height * O.width * O.channels, 1, 1)
- TENSOR_ARGS2(X, O);
-
- uint nyxc = dispatchThreadID.x;
-
- uint c = nyxc % X.channels;
- uint nyx = nyxc / X.channels;
- uint x = nyx % X.width;
- uint ny = nyx / X.width;
- uint y = ny % X.height;
- uint n = ny / X.height;
-
- if (n >= X.batch) return;
-
- float v = X.Get(n, y, x, c);
- v = relu(v);
- O.Set(n, y, x, c, v);
-}
-
-
-NUMTHREADS((16,16,1), (16,8,1), (16,4,1))
-void Relu6_CNyx(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, O.batch * O.height * O.width, 1);
- TENSOR_ARGS2(X, O);
-
- uint c = dispatchThreadID.x;
- uint nyx = dispatchThreadID.y;
-
- uint x = nyx % X.width;
- uint ny = nyx / X.width;
- uint y = ny % X.height;
- uint n = ny / X.height;
-
- if (c >= X.channels) return;
- if (n >= X.batch) return;
-
- float v = X.Get(n, y, x, c);
- v = relu6(v);
- O.Set(n, y, x, c, v);
-}
-
-NUMTHREADS((512,1,1), (128,1,1), (64,1,1))
-void Relu6_Nyxc(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.batch * O.height * O.width * O.channels, 1, 1)
- TENSOR_ARGS2(X, O);
-
- uint nyxc = dispatchThreadID.x;
-
- uint c = nyxc % X.channels;
- uint nyx = nyxc / X.channels;
- uint x = nyx % X.width;
- uint ny = nyx / X.width;
- uint y = ny % X.height;
- uint n = ny / X.height;
-
- if (n >= X.batch) return;
-
- float v = X.Get(n, y, x, c);
- v = relu6(v);
- O.Set(n, y, x, c, v);
-}
-
-NUMTHREADS((16,16,1), (16,8,1), (16,4,1))
-void Tanh_CNyx(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, O.batch * O.height * O.width, 1);
- TENSOR_ARGS2(X, O);
-
- uint c = dispatchThreadID.x;
- uint nyx = dispatchThreadID.y;
-
- uint x = nyx % X.width;
- uint ny = nyx / X.width;
- uint y = ny % X.height;
- uint n = ny / X.height;
-
- if (c >= X.channels) return;
- if (n >= X.batch) return;
-
- float v = X.Get(n, y, x, c);
- v = tanh(v);
- O.Set(n, y, x, c, v);
-}
-
-NUMTHREADS((512,1,1), (128,1,1), (64,1,1))
-void Tanh_Nyxc(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.batch * O.height * O.width * O.channels, 1, 1)
- TENSOR_ARGS2(X, O);
-
- uint nyxc = dispatchThreadID.x;
-
- uint c = nyxc % X.channels;
- uint nyx = nyxc / X.channels;
- uint x = nyx % X.width;
- uint ny = nyx / X.width;
- uint y = ny % X.height;
- uint n = ny / X.height;
-
- if (n >= X.batch) return;
-
- float v = X.Get(n, y, x, c);
- v = tanh(v);
- O.Set(n, y, x, c, v);
-}
-
-NUMTHREADS((16,16,1), (16,8,1), (16,4,1))
-void Sigmoid_CNyx(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, O.batch * O.height * O.width, 1);
- TENSOR_ARGS2(X, O);
-
- uint c = dispatchThreadID.x;
- uint nyx = dispatchThreadID.y;
-
- uint x = nyx % X.width;
- uint ny = nyx / X.width;
- uint y = ny % X.height;
- uint n = ny / X.height;
-
- if (c >= X.channels) return;
- if (n >= X.batch) return;
-
- float v = X.Get(n, y, x, c);
- v = sigmoid(v);
- O.Set(n, y, x, c, v);
-}
-
-NUMTHREADS((512,1,1), (128,1,1), (64,1,1))
-void Sigmoid_Nyxc(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.batch * O.height * O.width * O.channels, 1, 1)
- TENSOR_ARGS2(X, O);
-
- uint nyxc = dispatchThreadID.x;
-
- uint c = nyxc % X.channels;
- uint nyx = nyxc / X.channels;
- uint x = nyx % X.width;
- uint ny = nyx / X.width;
- uint y = ny % X.height;
- uint n = ny / X.height;
-
- if (n >= X.batch) return;
-
- float v = X.Get(n, y, x, c);
- v = sigmoid(v);
- O.Set(n, y, x, c, v);
-}
-
-NUMTHREADS((16,16,1), (16,8,1), (16,4,1))
-void Swish_CNyx(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, O.batch * O.height * O.width, 1);
- TENSOR_ARGS2(X, O);
-
- uint c = dispatchThreadID.x;
- uint nyx = dispatchThreadID.y;
-
- uint x = nyx % X.width;
- uint ny = nyx / X.width;
- uint y = ny % X.height;
- uint n = ny / X.height;
-
- if (c >= X.channels) return;
- if (n >= X.batch) return;
-
- float v = X.Get(n, y, x, c);
- v = swish(v);
- O.Set(n, y, x, c, v);
-}
-
-NUMTHREADS((512,1,1), (128,1,1), (64,1,1))
-void Swish_Nyxc(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.batch * O.height * O.width * O.channels, 1, 1)
- TENSOR_ARGS2(X, O);
-
- uint nyxc = dispatchThreadID.x;
-
- uint c = nyxc % X.channels;
- uint nyx = nyxc / X.channels;
- uint x = nyx % X.width;
- uint ny = nyx / X.width;
- uint y = ny % X.height;
- uint n = ny / X.height;
-
- if (n >= X.batch) return;
-
- float v = X.Get(n, y, x, c);
- v = swish(v);
- O.Set(n, y, x, c, v);
-}
-
-NUMTHREADS((16,16,1), (16,8,1), (16,4,1))
-void Elu_CNyx(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, O.batch * O.height * O.width, 1);
- TENSOR_ARGS2(X, O);
-
- uint c = dispatchThreadID.x;
- uint nyx = dispatchThreadID.y;
-
- uint x = nyx % X.width;
- uint ny = nyx / X.width;
- uint y = ny % X.height;
- uint n = ny / X.height;
-
- if (c >= X.channels) return;
- if (n >= X.batch) return;
-
- float v = X.Get(n, y, x, c);
- v = elu(v);
- O.Set(n, y, x, c, v);
-}
-
-NUMTHREADS((512,1,1), (128,1,1), (64,1,1))
-void Elu_Nyxc(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.batch * O.height * O.width * O.channels, 1, 1)
- TENSOR_ARGS2(X, O);
-
- uint nyxc = dispatchThreadID.x;
-
- uint c = nyxc % X.channels;
- uint nyx = nyxc / X.channels;
- uint x = nyx % X.width;
- uint ny = nyx / X.width;
- uint y = ny % X.height;
- uint n = ny / X.height;
-
- if (n >= X.batch) return;
-
- float v = X.Get(n, y, x, c);
- v = elu(v);
- O.Set(n, y, x, c, v);
-}
-
-NUMTHREADS((16,16,1), (16,8,1), (16,4,1))
-void LeakyRelu_CNyx(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, O.batch * O.height * O.width, 1);
- TENSOR_ARGS2(X, O);
-
- uint c = dispatchThreadID.x;
- uint nyx = dispatchThreadID.y;
-
- uint x = nyx % X.width;
- uint ny = nyx / X.width;
- uint y = ny % X.height;
- uint n = ny / X.height;
-
- if (c >= X.channels) return;
- if (n >= X.batch) return;
-
- float v = X.Get(n, y, x, c);
- v = lrelu(v);
- O.Set(n, y, x, c, v);
-}
-
-NUMTHREADS((512,1,1), (128,1,1), (64,1,1))
-void LeakyRelu_Nyxc(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.batch * O.height * O.width * O.channels, 1, 1)
- TENSOR_ARGS2(X, O);
-
- uint nyxc = dispatchThreadID.x;
-
- uint c = nyxc % X.channels;
- uint nyx = nyxc / X.channels;
- uint x = nyx % X.width;
- uint ny = nyx / X.width;
- uint y = ny % X.height;
- uint n = ny / X.height;
-
- if (n >= X.batch) return;
-
- float v = X.Get(n, y, x, c);
- v = lrelu(v);
- O.Set(n, y, x, c, v);
-}
-
-NUMTHREADS((16,16,1), (16,8,1), (16,4,1))
-void Exp_CNyx(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, O.batch * O.height * O.width, 1);
- TENSOR_ARGS2(X, O);
-
- uint c = dispatchThreadID.x;
- uint nyx = dispatchThreadID.y;
-
- uint x = nyx % X.width;
- uint ny = nyx / X.width;
- uint y = ny % X.height;
- uint n = ny / X.height;
-
- if (c >= X.channels) return;
- if (n >= X.batch) return;
-
- float v = X.Get(n, y, x, c);
- v = exp(v);
- O.Set(n, y, x, c, v);
-}
-
-NUMTHREADS((512,1,1), (128,1,1), (64,1,1))
-void Exp_Nyxc(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.batch * O.height * O.width * O.channels, 1, 1)
- TENSOR_ARGS2(X, O);
-
- uint nyxc = dispatchThreadID.x;
-
- uint c = nyxc % X.channels;
- uint nyx = nyxc / X.channels;
- uint x = nyx % X.width;
- uint ny = nyx / X.width;
- uint y = ny % X.height;
- uint n = ny / X.height;
-
- if (n >= X.batch) return;
-
- float v = X.Get(n, y, x, c);
- v = exp(v);
- O.Set(n, y, x, c, v);
-}
-
-NUMTHREADS((16,16,1), (16,8,1), (16,4,1))
-void Log_CNyx(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, O.batch * O.height * O.width, 1);
- TENSOR_ARGS2(X, O);
-
- uint c = dispatchThreadID.x;
- uint nyx = dispatchThreadID.y;
-
- uint x = nyx % X.width;
- uint ny = nyx / X.width;
- uint y = ny % X.height;
- uint n = ny / X.height;
-
- if (c >= X.channels) return;
- if (n >= X.batch) return;
-
- float v = X.Get(n, y, x, c);
- v = log(v);
- O.Set(n, y, x, c, v);
-}
-
-NUMTHREADS((512,1,1), (128,1,1), (64,1,1))
-void Log_Nyxc(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.batch * O.height * O.width * O.channels, 1, 1)
- TENSOR_ARGS2(X, O);
-
- uint nyxc = dispatchThreadID.x;
-
- uint c = nyxc % X.channels;
- uint nyx = nyxc / X.channels;
- uint x = nyx % X.width;
- uint ny = nyx / X.width;
- uint y = ny % X.height;
- uint n = ny / X.height;
-
- if (n >= X.batch) return;
-
- float v = X.Get(n, y, x, c);
- v = log(v);
- O.Set(n, y, x, c, v);
-}
-
-NUMTHREADS((16,16,1), (16,8,1), (16,4,1))
-void Pow_CNyx(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, O.batch * O.height * O.width, 1);
- TENSOR_ARGS2(X, O);
-
- uint c = dispatchThreadID.x;
- uint nyx = dispatchThreadID.y;
-
- uint x = nyx % X.width;
- uint ny = nyx / X.width;
- uint y = ny % X.height;
- uint n = ny / X.height;
-
- if (c >= X.channels) return;
- if (n >= X.batch) return;
-
- float v = X.Get(n, y, x, c);
- v = signed_pow(v);
- O.Set(n, y, x, c, v);
-}
-
-NUMTHREADS((512,1,1), (128,1,1), (64,1,1))
-void Pow_Nyxc(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.batch * O.height * O.width * O.channels, 1, 1)
- TENSOR_ARGS2(X, O);
-
- uint nyxc = dispatchThreadID.x;
-
- uint c = nyxc % X.channels;
- uint nyx = nyxc / X.channels;
- uint x = nyx % X.width;
- uint ny = nyx / X.width;
- uint y = ny % X.height;
- uint n = ny / X.height;
-
- if (n >= X.batch) return;
-
- float v = X.Get(n, y, x, c);
- v = signed_pow(v);
- O.Set(n, y, x, c, v);
-}
-
-
-NUMTHREADS((64,4,1), (64,2,1), (64,1,1))
-void Softmax(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.flatWidth, O.flatHeight, 1);
- TENSOR_ARGS2(X, O);
-
- uint x = dispatchThreadID.x;
- uint y = dispatchThreadID.y;
-
- if (x >= O.GetFlatWidth()) return;
- if (y >= O.GetFlatHeight()) return;
-
- float maxV = -FLT_MAX;
- for (uint i = 0; i < X.GetFlatWidth(); ++i)
- {
- float v = X.Get(y, i);
- if (v > maxV)
- maxV = v;
- }
-
- float acc = 0.0f;
- for (i = 0; i < X.GetFlatWidth(); ++i)
- {
- float v = X.Get(y, i);
- acc += exp(v - maxV);
- }
-
- float v = X.Get(y, x);
- v = exp(v - maxV) / acc;
- O.Set(y, x, v);
-}
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/Activation.compute.meta b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/Activation.compute.meta
deleted file mode 100644
index 1c31e43523..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/Activation.compute.meta
+++ /dev/null
@@ -1,9 +0,0 @@
-fileFormatVersion: 2
-guid: fdc94044b2f234c0fa80ada3771a2ae7
-timeCreated: 1495527718
-licenseType: Pro
-ComputeShaderImporter:
- currentAPIMask: 196608
- userData:
- assetBundleName:
- assetBundleVariant:
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/BarracudaReferenceImpl.compute b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/BarracudaReferenceImpl.compute
deleted file mode 100644
index 76eb7f8237..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/BarracudaReferenceImpl.compute
+++ /dev/null
@@ -1,1012 +0,0 @@
-#pragma kernel Dense
-#pragma kernel Conv2D
-#pragma kernel DepthwiseConv2D
-#pragma kernel Conv2DTrans
-#pragma kernel Upsample2D
-#pragma kernel Unstride2D
-#pragma kernel MaxPool2D
-#pragma kernel AvgPool2D
-#pragma kernel GlobalMaxPool2D
-#pragma kernel GlobalAvgPool2D
-#pragma kernel ScaleBias
-#pragma kernel InstanceNorm
-#pragma kernel Dropout
-#pragma kernel Relu
-#pragma kernel Swish
-#pragma kernel Softmax
-#pragma kernel Tanh
-#pragma kernel Sigmoid
-#pragma kernel Relu6
-#pragma kernel Elu
-#pragma kernel LeakyRelu
-#pragma kernel Exp
-#pragma kernel Log
-#pragma kernel Pow
-#pragma kernel Copy
-#pragma kernel BroadcastAdd
-#pragma kernel BroadcastSub
-#pragma kernel BroadcastMul
-#pragma kernel BroadcastDiv
-#pragma kernel BroadcastPow
-#pragma kernel BroadcastMin
-#pragma kernel BroadcastMax
-#pragma kernel ReduceMin
-#pragma kernel ReduceMax
-#pragma kernel ReduceSum
-#pragma kernel ReduceMean
-#pragma kernel ReduceProd
-#pragma kernel TextureToTensor
-#pragma kernel TensorToTexture
-
-#include "Tensor.cginc"
-#include "Random.cginc"
-
-TENSOR_DECL(X)
-TENSOR_DECL(W)
-TENSOR_DECL(K)
-TENSOR_DECL(B)
-TENSOR_DECL_RW(O)
-
-uint4 _Pad;
-uint4 _Pool;
-uint4 _Stride;
-float _Alpha;
-float _Beta;
-float _Seed;
-
-[numthreads(8,8,1)]
-void Dense(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.flatWidth, O.flatHeight, 1);
- TENSOR_ARGS4(X, W, B, O);
-
- uint x = dispatchThreadID.x;
- uint y = dispatchThreadID.y;
-
- if (x >= O.GetFlatWidth()) return;
- if (y >= O.GetFlatHeight()) return;
-
- float acc = B.Get(x);
- for (uint i = 0; i < X.GetFlatWidth(); ++i)
- acc += X.Get(y, i) * W.Get(i, x);
-
- O.Set(y, x, acc);
-}
-
-[numthreads(4,4,4)]
-void Relu(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, O.width, O.height);
- TENSOR_ARGS2(X, O);
-
- uint c = dispatchThreadID.x;
- uint x = dispatchThreadID.y;
- uint y = dispatchThreadID.z;
-
- if (c >= O.channels) return;
- if (x >= O.width) return;
- if (y >= O.height) return;
-
- for (uint n = 0; n < X.batch; ++n)
- {
- float v = X.Get(n, y, x, c);
- v = 0.5f * (v + abs(v));
-
- O.Set(n, y, x, c, v);
- }
-}
-
-[numthreads(4,4,4)]
-void Swish(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, O.width, O.height);
- TENSOR_ARGS2(X, O);
-
- uint c = dispatchThreadID.x;
- uint x = dispatchThreadID.y;
- uint y = dispatchThreadID.z;
-
- if (c >= O.channels) return;
- if (x >= O.width) return;
- if (y >= O.height) return;
-
- for (uint n = 0; n < X.batch; ++n)
- {
- float v = X.Get(n, y, x, c);
- v = v / (1 + exp(-v));
- O.Set(n, y, x, c, v);
- }
-}
-
-[numthreads(4,4,4)]
-void Tanh(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, O.width, O.height);
- TENSOR_ARGS2(X, O);
-
- uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
- if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
-
- for (uint n = 0; n < X.batch; ++n)
- {
- float v = X.Get(n, y, x, c);
- v = tanh(v);
- O.Set(n, y, x, c, v);
- }
-}
-
-[numthreads(4,4,4)]
-void Sigmoid(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, O.width, O.height);
- TENSOR_ARGS2(X, O);
-
- uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
- if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
-
- for (uint n = 0; n < X.batch; ++n)
- {
- float v = X.Get(n, y, x, c);
- v = 1 / (1 + exp(-v));
- O.Set(n, y, x, c, v);
- }
-}
-
-[numthreads(4,4,4)]
-void Relu6(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, O.width, O.height);
- TENSOR_ARGS2(X, O);
-
- uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
- if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
-
- for (uint n = 0; n < X.batch; ++n)
- {
- float v = X.Get(n, y, x, c);
- v = min(max(0, v), 6);
- O.Set(n, y, x, c, v);
- }
-}
-
-[numthreads(4,4,4)]
-void Elu(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, O.width, O.height);
- TENSOR_ARGS2(X, O);
-
- uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
- if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
-
- for (uint n = 0; n < X.batch; ++n)
- {
- float v = X.Get(n, y, x, c);
- if (v <= 0)
- v = _Alpha * (exp(v) - 1);
- O.Set(n, y, x, c, v);
- }
-}
-
-[numthreads(4,4,4)]
-void LeakyRelu(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, O.width, O.height);
- TENSOR_ARGS2(X, O);
-
- uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
- if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
-
- for (uint n = 0; n < X.batch; ++n)
- {
- float v = X.Get(n, y, x, c);
- v = max(v, _Alpha * v);
- O.Set(n, y, x, c, v);
- }
-}
-
-[numthreads(4,4,4)]
-void Exp(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, O.width, O.height);
- TENSOR_ARGS2(X, O);
-
- uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
- if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
-
- for (uint n = 0; n < X.batch; ++n)
- {
- float v = X.Get(n, y, x, c);
- v = exp(v);
- O.Set(n, y, x, c, v);
- }
-}
-
-[numthreads(4,4,4)]
-void Log(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, O.width, O.height);
- TENSOR_ARGS2(X, O);
-
- uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
- if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
-
- for (uint n = 0; n < X.batch; ++n)
- {
- float v = X.Get(n, y, x, c);
- v = log(v);
- O.Set(n, y, x, c, v);
- }
-}
-
-float signed_pow(float f, float e)
-{
- // handle negative f
- float v = pow(abs(f), e);
- float s = (e % 2 == 1) ?
- sign(f): // exponent is odd => sign(f) * pow(abs(f), e)
- 1; // exponent is even => pow(abs(f), e)
- return v * s;
-}
-
-[numthreads(4,4,4)]
-void Pow(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, O.width, O.height);
- TENSOR_ARGS2(X, O);
-
- uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
- if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
-
- for (uint n = 0; n < X.batch; ++n)
- {
- float v = X.Get(n, y, x, c);
- v = signed_pow(v, _Alpha);
- O.Set(n, y, x, c, v);
- }
-}
-
-[numthreads(4,4,4)]
-void BroadcastAdd(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, O.width, O.height);
- TENSOR_ARGS3(X, B, O);
-
- uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
- if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
-
- for (uint n = 0; n < X.batch; ++n)
- {
- float v =
- X.BroadcastGet(n, y, x, c) +
- B.BroadcastGet(n, y, x, c);
- O.Set(n, y, x, c, v);
- }
-}
-
-[numthreads(4,4,4)]
-void BroadcastSub(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, O.width, O.height);
- TENSOR_ARGS3(X, B, O);
-
- uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
- if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
-
- for (uint n = 0; n < X.batch; ++n)
- {
- float v =
- X.BroadcastGet(n, y, x, c) -
- B.BroadcastGet(n, y, x, c);
- O.Set(n, y, x, c, v);
- }
-}
-
-[numthreads(4,4,4)]
-void BroadcastMul(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, O.width, O.height);
- TENSOR_ARGS3(X, B, O);
-
- uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
- if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
-
- for (uint n = 0; n < O.batch; ++n)
- {
- float v =
- X.BroadcastGet(n, y, x, c) *
- B.BroadcastGet(n, y, x, c);
- O.Set(n, y, x, c, v);
- }
-}
-
-[numthreads(4,4,4)]
-void BroadcastDiv(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, O.width, O.height);
- TENSOR_ARGS3(X, B, O);
-
- uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
- if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
-
- for (uint n = 0; n < X.batch; ++n)
- {
- float v =
- X.BroadcastGet(n, y, x, c) /
- B.BroadcastGet(n, y, x, c);
- O.Set(n, y, x, c, v);
- }
-}
-
-[numthreads(4,4,4)]
-void BroadcastPow(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, O.width, O.height);
- TENSOR_ARGS3(X, B, O);
-
- uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
- if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
-
- for (uint n = 0; n < X.batch; ++n)
- {
- float v = signed_pow(
- X.BroadcastGet(n, y, x, c),
- B.BroadcastGet(n, y, x, c));
- O.Set(n, y, x, c, v);
- }
-}
-
-[numthreads(4,4,4)]
-void BroadcastMin(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, O.width, O.height);
- TENSOR_ARGS3(X, B, O);
-
- uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
- if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
-
- for (uint n = 0; n < X.batch; ++n)
- {
- float v = min(
- X.BroadcastGet(n, y, x, c),
- B.BroadcastGet(n, y, x, c));
- O.Set(n, y, x, c, v);
- }
-}
-
-[numthreads(4,4,4)]
-void BroadcastMax(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, O.width, O.height);
- TENSOR_ARGS3(X, B, O);
-
- uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
- if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
-
- for (uint n = 0; n < X.batch; ++n)
- {
- float v = max(
- X.BroadcastGet(n, y, x, c),
- B.BroadcastGet(n, y, x, c));
- O.Set(n, y, x, c, v);
- }
-}
-
-[numthreads(4,4,1)]
-void ReduceMin(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.width, O.height, 1);
- TENSOR_ARGS3(X, B, O);
-
- uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
- if (x >= O.width) return; if (y >= O.height) return;
-
- for (uint n = 0; n < X.batch; ++n)
- {
- float minV = FLT_MAX;
- for (uint c = 0; c < X.channels; ++c)
- {
- float v = X.Get(n, y, x, c);
- minV = min(v, minV);
- }
- O.Set(n, y, x, 0, minV);
- }
-}
-
-[numthreads(4,4,1)]
-void ReduceMax(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.width, O.height, 1);
- TENSOR_ARGS3(X, B, O);
-
- uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
- if (x >= O.width) return; if (y >= O.height) return;
-
- for (uint n = 0; n < X.batch; ++n)
- {
- float maxV = -FLT_MAX;
- for (uint c = 0; c < X.channels; ++c)
- {
- float v = X.Get(n, y, x, c);
- maxV = max(v, maxV);
- }
- O.Set(n, y, x, 0, maxV);
- }
-}
-
-[numthreads(4,4,1)]
-void ReduceSum(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.width, O.height, 1);
- TENSOR_ARGS3(X, B, O);
-
- uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
- if (x >= O.width) return; if (y >= O.height) return;
-
- for (uint n = 0; n < X.batch; ++n)
- {
- float v = 0;
- for (uint c = 0; c < X.channels; ++c)
- v += X.Get(n, y, x, c);
- O.Set(n, y, x, 0, v);
- }
-}
-
-[numthreads(4,4,1)]
-void ReduceMean(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.width, O.height, 1);
- TENSOR_ARGS3(X, B, O);
-
- uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
- if (x >= O.width) return; if (y >= O.height) return;
-
- for (uint n = 0; n < X.batch; ++n)
- {
- float v = 0;
- for (uint c = 0; c < X.channels; ++c)
- v += X.Get(n, y, x, c);
-
- v /= X.channels;
- O.Set(n, y, x, 0, v);
- }
-}
-
-[numthreads(4,4,1)]
-void ReduceProd(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.width, O.height, 1);
- TENSOR_ARGS3(X, B, O);
-
- uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
- if (x >= O.width) return; if (y >= O.height) return;
-
- for (uint n = 0; n < X.batch; ++n)
- {
- float v = 1;
- for (uint c = 0; c < X.channels; ++c)
- v *= X.Get(n, y, x, c);
- O.Set(n, y, x, 0, v);
- }
-}
-
-[numthreads(4,4,4)]
-void Copy(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- // NOTE: dispatched over X (not O)
- DISPATCH_ARGS(X.channels, X.width, X.height);
- TENSOR_ARGS2(X, O);
-
- uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
- if (c >= X.channels) return; if (x >= X.width) return; if (y >= X.height) return;
-
- for (uint n = 0; n < X.batch; ++n)
- {
- float v = X.Get(n, y, x, c);
- O.Set(n + _Pad[0], y + _Pad[1], x + _Pad[2], c + _Pad[3], v);
- }
-}
-
-[numthreads(4,4,4)]
-void Dropout(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, O.width, O.height);
- TENSOR_ARGS2(X, O);
-
- uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
- if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
-
- for (uint n = 0; n < O.batch; ++n)
- {
- float4 seed = float4(n / O.batch, y / O.height, x / O.width, c / O.channels);
- seed = frac(seed + _Seed);
-
- float v = X.Get(n, y, x, c);
- v *= Bernoulli(seed, 1 - _Alpha) / (1 - _Alpha);
- O.Set(n, y, x, c, v);
- }
-}
-
-[numthreads(4,4,4)]
-void ScaleBias(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, O.width, O.height);
- TENSOR_ARGS4(X, W, B, O);
-
- uint c = dispatchThreadID.x;
- uint x = dispatchThreadID.y;
- uint y = dispatchThreadID.z;
-
- if (c >= O.channels) return;
- if (x >= O.width) return;
- if (y >= O.height) return;
-
- float scale = W.Get(0, 0, 0, c);
- float bias = B.Get(0, 0, 0, c);
-
- for (uint n = 0; n < X.batch; ++n)
- {
- float v = X.Get(n, y, x, c);
- v = v * scale + bias;
- O.Set(n, y, x, c, v);
- }
-}
-
-[numthreads(16,4,1)]
-void Softmax(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.flatWidth, O.flatHeight, 1);
- TENSOR_ARGS2(X, O);
-
- uint x = dispatchThreadID.x;
- uint y = dispatchThreadID.y;
-
- if (x >= O.GetFlatWidth()) return;
- if (y >= O.GetFlatHeight()) return;
-
- float maxV = -FLT_MAX;
- for (uint i = 0; i < X.GetFlatWidth(); ++i)
- {
- float v = X.Get(y, i);
- if (v > maxV)
- maxV = v;
- }
-
- float acc = 0.0f;
- for (i = 0; i < X.GetFlatWidth(); ++i)
- {
- float v = X.Get(y, i);
- acc += exp(v - maxV);
- }
-
- float v = X.Get(y, x);
- v = exp(v - maxV) / acc;
- O.Set(y, x, v);
-}
-
-[numthreads(4,4,4)]
-void Upsample2D(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- // NOTE: dispatched over X (not O)
- DISPATCH_ARGS(X.channels, X.width, X.height);
- TENSOR_ARGS2(X, O);
-
- uint c = dispatchThreadID.x;
- uint x = dispatchThreadID.y;
- uint y = dispatchThreadID.z;
-
- if (c >= X.channels) return;
- if (x >= X.width) return;
- if (y >= X.height) return;
-
- for (uint n = 0; n < O.batch; ++n)
- {
- float v = X.Get(n, y, x, c);
-
- for (uint dy = 0; dy < _Pool.y; ++dy)
- for (uint dx = 0; dx < _Pool.x; ++dx)
- {
- uint oy = y * _Pool.y + dy;
- uint ox = x * _Pool.x + dx;
- O.Set(n, oy, ox, c, v);
- }
- }
-}
-
-[numthreads(4,4,4)]
-void MaxPool2D(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, O.width, O.height);
- TENSOR_ARGS2(X, O);
-
- uint c = dispatchThreadID.x;
- uint x = dispatchThreadID.y;
- uint y = dispatchThreadID.z;
-
- if (c >= O.channels) return;
- if (x >= O.width) return;
- if (y >= O.height) return;
-
- for (uint n = 0; n < X.batch; ++n)
- {
- float maxV = -FLT_MAX;
- for (uint dy = 0; dy < _Pool.y; ++dy)
- for (uint dx = 0; dx < _Pool.x; ++dx)
- {
- uint2 pos = uint2(x, y) * _Stride.xy + uint2(dx, dy);
- float v = X.SafeGet(n, pos, c, _Pad.xy);
- maxV = max(v, maxV);
- }
-
- O.Set(n, y, x, c, maxV);
- }
-}
-
-[numthreads(4,4,4)]
-void AvgPool2D(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, O.width, O.height);
- TENSOR_ARGS2(X, O);
-
- uint c = dispatchThreadID.x;
- uint x = dispatchThreadID.y;
- uint y = dispatchThreadID.z;
-
- if (c >= O.channels) return;
- if (x >= O.width) return;
- if (y >= O.height) return;
-
- uint2 leftCorner = _Pad.xy;
- uint2 rightCorner = uint2(X.width, X.height) + _Pad.xy;
- for (uint n = 0; n < X.batch; ++n)
- {
- float acc = 0;
- float counter = 0;
- for (uint dy = 0; dy < _Pool.y; ++dy)
- for (uint dx = 0; dx < _Pool.x; ++dx)
- {
- uint2 pos = uint2(x, y) * _Stride.xy + uint2(dx, dy);
-
- bool mask = all(pos >= leftCorner) && all(pos < rightCorner);
- acc += (mask)? X.Get(n, pos.y - leftCorner.y, pos.x - leftCorner.x, c): 0;
- counter += (mask)? 1: 0;
- }
-
- acc /= counter;
- O.Set(n, y, x, c, acc);
- }
-}
-
-[numthreads(32,1,1)]
-void GlobalMaxPool2D(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, 1, 1);
- TENSOR_ARGS2(X, O);
-
- uint c = dispatchThreadID.x;
- if (c >= O.channels) return;
- //ASSERT(X.batch == O.batch)
-
- for (uint n = 0; n < X.batch; ++n)
- {
- float maxV = -FLT_MAX;
- for (uint y = 0; y < X.height; ++y)
- for (uint x = 0; x < X.width; ++x)
- {
- float v = X.Get(n, y, x, c);
- maxV = max(v, maxV);
- }
-
- O.Set(n, 0, 0, c, maxV);
- }
-}
-
-[numthreads(32,1,1)]
-void GlobalAvgPool2D(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, 1, 1);
- TENSOR_ARGS2(X, O);
-
- uint c = dispatchThreadID.x;
- if (c >= O.channels) return;
- //ASSERT(X.batch == O.batch)
-
- for (uint n = 0; n < X.batch; ++n)
- {
- float v = 0;
- for (uint y = 0; y < X.height; ++y)
- for (uint x = 0; x < X.width; ++x)
- v += X.Get(n, y, x, c);
-
- v /= (X.height * X.width);
- O.Set(n, 0, 0, c, v);
- }
-}
-
-[numthreads(32,1,1)]
-void InstanceNorm(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, 1, 1);
- TENSOR_ARGS4(X, W, B, O);
-
- uint c = dispatchThreadID.x;
- if (c >= O.channels) return;
- //ASSERT(X.shape == O.shape)
-
- float gamma = W.Get(0, 0, 0, c);
- float beta = B.Get(0, 0, 0, c);
-
- for (uint n = 0; n < O.batch; ++n)
- {
- uint x, y;
- // calc mean
- float acc = 0;
- for (y = 0; y < O.height; ++y)
- for (x = 0; x < O.width; ++x)
- acc += X.Get(n, y, x, c);
- float mean = acc / (O.width * O.height);
-
- // calc variance
- acc = 0;
- for (y = 0; y < O.height; ++y)
- for (x = 0; x < O.width; ++x)
- {
- float delta = X.Get(n, y, x, c) - mean;
- acc += delta * delta;
- }
- float var = acc / (O.width * O.height);
-
- // normalization factor
- float invNormFactor = 1 / sqrt(var + FLT_EPSILON);
-
- float scale = gamma * invNormFactor;
- float bias = beta - gamma * mean * invNormFactor;
-
- // apply normalization
- for (y = 0; y < O.height; ++y)
- for (x = 0; x < O.width; ++x)
- {
- float v = X.Get(n, y, x, c);
- v = v * scale + bias;
- O.Set(n, y, x, c, v);
- }
- }
-}
-
-[numthreads(4,4,4)]
-void Conv2D(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(K.kernelCount, O.width, O.height);
- TENSOR_ARGS4(X, K, B, O);
-
- uint k = dispatchThreadID.x;
- uint x = dispatchThreadID.y;
- uint y = dispatchThreadID.z;
-
- if (k >= K.channels) return;
- if (x >= O.width) return;
- if (y >= O.height) return;
-
- for (uint n = 0; n < O.batch; ++n)
- {
- float acc = B.Get(k);
- for (uint dy = 0; dy < K.GetKernelHeight(); ++dy)
- {
- for (uint dx = 0; dx < K.GetKernelWidth(); ++dx)
- {
- uint2 pos = uint2(x, y) * _Stride.xy + uint2(dx, dy);
- for (uint c = 0; c < X.channels; ++c)
- {
- float v = X.SafeGet(n, pos, c, _Pad.xy);
- acc += v * K.Get(dy, dx, c, k);
- }
- }
- }
-
- O.Set(n, y, x, k, acc);
- }
-}
-
-NUMTHREADS((16,4,4), (8,4,4), (4,4,4))
-void DepthwiseConv2D(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(K.kernelCount, O.width, O.height);
- TENSOR_ARGS4(X, K, B, O);
-
- uint k = dispatchThreadID.x;
- uint x = dispatchThreadID.y;
- uint y = dispatchThreadID.z;
-
- if (k >= K.channels) return;
- if (x >= O.width) return;
- if (y >= O.height) return;
-
- for (uint n = 0; n < O.batch; ++n)
- {
- float acc = B.Get(k);
- for (uint dy = 0; dy < K.GetKernelHeight(); ++dy)
- for (uint dx = 0; dx < K.GetKernelWidth(); ++dx)
- {
- uint2 pos = uint2(x, y) * _Stride.xy + uint2(dx, dy);
- float v = X.SafeGet(n, pos, k, _Pad.xy);
- acc += v * K.Get(dy, dx, 0, k);
- }
-
- O.Set(n, y, x, k, acc);
- }
-}
-
-[numthreads(4,4,4)]
-void Unstride2D(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, O.width, O.height);
- TENSOR_ARGS2(X, O);
-
- uint c = dispatchThreadID.x;
- uint x = dispatchThreadID.y;
- uint y = dispatchThreadID.z;
-
- if (c >= O.channels) return;
- if (x >= O.width) return;
- if (y >= O.height) return;
-
- for (uint n = 0; n < O.batch; ++n)
- {
- int xx = (int)x - (int)_Pad.x;
- int yy = (int)y - (int)_Pad.y;
-
- int my = yy % _Stride.y;
- int mx = xx % _Stride.x;
-
- int oy = yy / _Stride.y;
- int ox = xx / _Stride.x;
-
- bool mask = ox >= 0 && oy >= 0 && ox < (int)X.width && oy < (int)X.height &&
- my == 0 && mx == 0;
-
- float v = mask ? X.Get(n, (uint)oy, (uint)ox, c) : 0;
- O.Set(n, y, x, c, v);
- }
-}
-
-[numthreads(4,4,4)]
-void Conv2DTrans(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(K.kernelCount, O.width, O.height);
- TENSOR_ARGS4(X, K, B, O);
-
- uint k = dispatchThreadID.x;
- uint x = dispatchThreadID.y;
- uint y = dispatchThreadID.z;
-
- if (k >= K.channels) return;
- if (x >= O.width) return;
- if (y >= O.height) return;
-
- uint2 strideMask = _Stride.xy - 1;
-
- for (uint n = 0; n < O.batch; ++n)
- {
- float acc = B.Get(k);
- for (uint dy = y & strideMask.y; dy < K.GetKernelHeight(); dy += _Stride.y)
- {
- for (uint dx = x & strideMask.x; dx < K.GetKernelWidth(); dx += _Stride.x)
- {
- for (uint c = 0; c < X.channels; ++c)
- {
- uint xx = x + dx;
- uint yy = y + dy;
-
- uint oy = (yy - _Pad.y) / _Stride.y;
- uint ox = (xx - _Pad.x) / _Stride.x;
-
- bool mask = xx >= _Pad.x && yy >= _Pad.y && ox < X.width && oy < X.height;
-
- float v = (mask)? X.Get(n, oy, ox, c): 0;
- acc += v * K.Get(K.GetKernelHeight() - 1 - dy, K.GetKernelWidth() - 1 - dx, c, k);
- }
- }
- }
-
- O.Set(n, y, x, k, acc);
- }
-}
-
-
-Texture2D Xtex2D;
-Texture3D Xtex3D;
-Texture2DArray Xtex2DArray;
-SamplerState samplerXtex2D { Filter = MIN_MAG_LINEAR_MIP_POINT; AddressU = Clamp; AddressV = Clamp; };
-SamplerState samplerXtex3D { Filter = MIN_MAG_LINEAR_MIP_POINT; AddressU = Clamp; AddressV = Clamp; AddressW = Clamp; };
-SamplerState samplerXtex2DArray { Filter = MIN_MAG_LINEAR_MIP_POINT; AddressU = Clamp; AddressV = Clamp; };
-
-RWTexture2D Otex2D;
-RWTexture3D Otex3D;
-RWTexture2DArray Otex2DArray;
-
-bool _FlipY;
-
-// TODO: call TextureToTensor(v, dispatchThreadID) from Tex2DToTensor() { v = Xtex2D.SampleLevel }
-[numthreads(8,8,1)]
-void TextureToTensor(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- TENSOR_ARG_RW(O);
-
- uint b = _Pad.x;
- uint x = dispatchThreadID.x + _Pad.y;
- uint y = dispatchThreadID.y + _Pad.z;
- uint c = dispatchThreadID.z + _Pad.w;
-
- if (y >= O.height || x >= O.width)
- return;
-
- // calculate texture coordinates:
- // offset by 0.5 to get texel centers
- // divide by texture resolution (_Pool)
- float3 uvw = (float3)dispatchThreadID + float3(0.5f, 0.5f, 0);
- uvw /= (float3)_Pool.xyz;
- if (_FlipY)
- uvw.y = 1 - uvw.y;
-
- float4 v = Xtex2D.SampleLevel(samplerXtex2D, uvw.xy, 0);
- //texArray.SampleLevel(smpArray, loc, 0);
-
- if (_Stride.w == 1)
- {
- // TODO: interpret color as
- O.Set(b, y, x, c+0, (v.r + v.g + v.b) / 3.0f);
- }
- else if (_Stride.w == 3)
- {
- O.Set(b, y, x, c+0, v.r);
- O.Set(b, y, x, c+1, v.g);
- O.Set(b, y, x, c+2, v.b);
- }
- else if (_Stride.w == 4)
- {
- O.Set(b, y, x, c+0, v.r);
- O.Set(b, y, x, c+1, v.g);
- O.Set(b, y, x, c+2, v.b);
- O.Set(b, y, x, c+3, v.a);
- }
-}
-
-[numthreads(8,8,1)]
-void TensorToTexture(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- TENSOR_ARG(X);
-
- uint b = _Pad.x;
- uint x = dispatchThreadID.x + _Pad.y;
- uint y = dispatchThreadID.y + _Pad.z;
- uint c = dispatchThreadID.z + _Pad.w;
-
- if (y >= X.height || x >= X.width)
- return;
-
- if (_FlipY)
- y = X.height - 1 - y;
-
- float4 v = 0;
-
- if (X.channels - c == 1)
- {
- // broadcast to all channels
- v = _Alpha * X.Get(b, y, x, c) + _Beta;
- }
- else if (X.channels - c == 3)
- {
- v.r = _Alpha * X.Get(b, y, x, c+0) + _Beta;
- v.g = _Alpha * X.Get(b, y, x, c+1) + _Beta;
- v.b = _Alpha * X.Get(b, y, x, c+2) + _Beta;
- v.a = 1;
- }
- else if (X.channels - c >= 4)
- {
- v.r = _Alpha * X.Get(b, y, x, c+0) + _Beta;
- v.g = _Alpha * X.Get(b, y, x, c+1) + _Beta;
- v.b = _Alpha * X.Get(b, y, x, c+2) + _Beta;
- v.a = _Alpha * X.Get(b, y, x, c+3) + _Beta;
- }
-
- Otex2D[dispatchThreadID.xy] = v;
-}
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/BarracudaReferenceImpl.compute.meta b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/BarracudaReferenceImpl.compute.meta
deleted file mode 100644
index e8147972a1..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/BarracudaReferenceImpl.compute.meta
+++ /dev/null
@@ -1,9 +0,0 @@
-fileFormatVersion: 2
-guid: b4b1b304aae6c404cb0cdab46b8fa084
-timeCreated: 1495527718
-licenseType: Pro
-ComputeShaderImporter:
- currentAPIMask: 196608
- userData:
- assetBundleName:
- assetBundleVariant:
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/Broadcast.compute b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/Broadcast.compute
deleted file mode 100644
index 3e31a66b95..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/Broadcast.compute
+++ /dev/null
@@ -1,149 +0,0 @@
-#pragma kernel BroadcastAdd
-#pragma kernel BroadcastSub
-#pragma kernel BroadcastMul
-#pragma kernel BroadcastDiv
-#pragma kernel BroadcastPow
-#pragma kernel BroadcastMin
-#pragma kernel BroadcastMax
-
-#include "Tensor.cginc"
-
-TENSOR_DECL(X)
-TENSOR_DECL(B)
-TENSOR_DECL_RW(O)
-
-NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
-void BroadcastAdd(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, O.width, O.height);
- TENSOR_ARGS3(X, B, O);
-
- uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
- if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
-
- for (uint n = 0; n < X.batch; ++n)
- {
- float v =
- X.BroadcastGet(n, y, x, c) +
- B.BroadcastGet(n, y, x, c);
- O.Set(n, y, x, c, v);
- }
-}
-
-NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
-void BroadcastSub(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, O.width, O.height);
- TENSOR_ARGS3(X, B, O);
-
- uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
- if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
-
- for (uint n = 0; n < X.batch; ++n)
- {
- float v =
- X.BroadcastGet(n, y, x, c) -
- B.BroadcastGet(n, y, x, c);
- O.Set(n, y, x, c, v);
- }
-}
-
-NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
-void BroadcastMul(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, O.width, O.height);
- TENSOR_ARGS3(X, B, O);
-
- uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
- if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
-
- for (uint n = 0; n < O.batch; ++n)
- {
- float v =
- X.BroadcastGet(n, y, x, c) *
- B.BroadcastGet(n, y, x, c);
- O.Set(n, y, x, c, v);
- }
-}
-
-NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
-void BroadcastDiv(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, O.width, O.height);
- TENSOR_ARGS3(X, B, O);
-
- uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
- if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
-
- for (uint n = 0; n < X.batch; ++n)
- {
- float v =
- X.BroadcastGet(n, y, x, c) /
- B.BroadcastGet(n, y, x, c);
- O.Set(n, y, x, c, v);
- }
-}
-
-float signed_pow(float f, float e)
-{
- // handle negative f
- float v = pow(abs(f), e);
- float s = (e % 2 == 1) ?
- sign(f): // exponent is odd => sign(f) * pow(abs(f), e)
- 1; // exponent is even => pow(abs(f), e)
- return v * s;
-}
-
-NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
-void BroadcastPow(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, O.width, O.height);
- TENSOR_ARGS3(X, B, O);
-
- uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
- if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
-
- for (uint n = 0; n < X.batch; ++n)
- {
- float v = signed_pow(
- X.BroadcastGet(n, y, x, c),
- B.BroadcastGet(n, y, x, c));
- O.Set(n, y, x, c, v);
- }
-}
-
-NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
-void BroadcastMin(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, O.width, O.height);
- TENSOR_ARGS3(X, B, O);
-
- uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
- if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
-
- for (uint n = 0; n < X.batch; ++n)
- {
- float v = min(
- X.BroadcastGet(n, y, x, c),
- B.BroadcastGet(n, y, x, c));
- O.Set(n, y, x, c, v);
- }
-}
-
-NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
-void BroadcastMax(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, O.width, O.height);
- TENSOR_ARGS3(X, B, O);
-
- uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
- if (c >= O.channels) return; if (x >= O.width) return; if (y >= O.height) return;
-
- for (uint n = 0; n < X.batch; ++n)
- {
- float v = max(
- X.BroadcastGet(n, y, x, c),
- B.BroadcastGet(n, y, x, c));
- O.Set(n, y, x, c, v);
- }
-}
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/Broadcast.compute.meta b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/Broadcast.compute.meta
deleted file mode 100644
index 70f38084a0..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/Broadcast.compute.meta
+++ /dev/null
@@ -1,8 +0,0 @@
-fileFormatVersion: 2
-guid: 72dd00e416ab94bd79e7264a1fadef9d
-ComputeShaderImporter:
- externalObjects: {}
- currentAPIMask: 65536
- userData:
- assetBundleName:
- assetBundleVariant:
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/Conv.compute b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/Conv.compute
deleted file mode 100644
index 89ba4ed07a..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/Conv.compute
+++ /dev/null
@@ -1,638 +0,0 @@
-#pragma kernel Conv2D
-#pragma kernel Conv2D_RegisterBlock4x2
-//#pragma kernel Conv2D_L1Cached64_RegisterBlock4x4
-//#pragma kernel Conv2D_L1Cached32_RegisterBlock4x4
-#pragma kernel Conv2DKernelKxK_T16x16_R4x4 BLOCK_SIZE=4 SUFFIX=KernelKxK_T16x16_R
-#pragma kernel Conv2DKernelKxK_StrictC16K64_T16x16_R4x4 BLOCK_SIZE=4 STRICT_CHANNELS=1 SUFFIX=KernelKxK_StrictC16K64_T16x16_R
-#pragma kernel Conv2DKernel1x1_StrictC16K64_T16x16_R4x4 BLOCK_SIZE=4 KERNEL_1x1=1 STRICT_CHANNELS=1 SUFFIX=Kernel1x1_StrictC16K64_T16x16_R
-
-#pragma kernel DepthwiseConv2D
-
-#pragma kernel Conv2DTrans
-#pragma kernel Conv2DTrans_L1Cached64_RegisterBlock2x2
-
-#include "Tensor.cginc"
-
-TENSOR_DECL(X)
-TENSOR_DECL(K)
-TENSOR_DECL(B)
-TENSOR_DECL(WBK)
-TENSOR_DECL_RW(O)
-
-uint4 _Pad;
-uint4 _Stride;
-
-#define DEBUG_CHECK_BOUNDS 0
-
-// Conv2DBlock64x64_4x4 + index optimizations
-// T
-// -1|0 -1|0
-// 16: 142|142ms 144|155ms
-
-float ffma(float a, float b, float c) { return dot(float2(a,c), float2(b,1)); }
-#define FUNC_NAME(KERNEL, SUFFIX, SIZE) KERNEL##SUFFIX##SIZE##x##SIZE
-#define CACHE_NAME(KERNEL, SUFFIX, SIZE, TENSOR) KERNEL##SUFFIX##SIZE##x##SIZE##_Cache_##TENSOR
-
-#define KERNEL_NAME Conv2D
-
-#if BLOCK_SIZE == 4
-#define TRANSPOSED_X 0
-#define BUF_OFFSET 0
-#define CACHE_DEPTH 16
-groupshared float CACHE_NAME(KERNEL_NAME, SUFFIX, BLOCK_SIZE, X)[CACHE_DEPTH*16*BLOCK_SIZE+(1-TRANSPOSED_X)*CACHE_DEPTH];
-groupshared float CACHE_NAME(KERNEL_NAME, SUFFIX, BLOCK_SIZE, W)[CACHE_DEPTH*16*BLOCK_SIZE];
-[numthreads(16,16,1)]
-void FUNC_NAME(KERNEL_NAME, SUFFIX, BLOCK_SIZE)(uint3 dispatchThreadID : SV_DispatchThreadID, uint3 groupThreadID : SV_GroupThreadID, uint threadIndex : SV_GroupIndex)
-{
- DISPATCH_ARGS(K.kernelCount, O.width * O.height * O.batch, 1);
- TENSOR_SHARED2_ARGS4(X, K, B, WBK, O);
-
- // [W*H, Ky*Kx*In] * [Ky*Kx*In, Out] => [W*H, Out]
-
- #define X_ CACHE_NAME(KERNEL_NAME, SUFFIX, BLOCK_SIZE, X)
- #define W_ CACHE_NAME(KERNEL_NAME, SUFFIX, BLOCK_SIZE, W)
-
- int x = (int)dispatchThreadID.x * BLOCK_SIZE; // output_channels
- int y = (int)dispatchThreadID.y * BLOCK_SIZE; // batch*width*height
- int tx = (int)groupThreadID.x;
- int ty = (int)groupThreadID.y;
- int bx = ((int)dispatchThreadID.x - (int)groupThreadID.x) * BLOCK_SIZE;
- int by = ((int)dispatchThreadID.y - (int)groupThreadID.y) * BLOCK_SIZE;
- int ti = (int)threadIndex;
- uint w = O.width;
- uint h = O.height;
- int channels = X.channels;
- int widthX = X.width;
- int heightX = X.height;
- int strideX = X.channels;
- int strideK = K.channels;
- int strideO = O.channels;
- int offsetX = BUF_OFFSET;
- int offsetK = BUF_OFFSET;
- int offsetO = BUF_OFFSET;
-
- float4 dstA[4];
- dstA[0].x = B.Get(x+0); dstA[0].y = B.Get(x+1); dstA[0].z = B.Get(x+2); dstA[0].w = B.Get(x+3);
- dstA[1].x = B.Get(x+0); dstA[1].y = B.Get(x+1); dstA[1].z = B.Get(x+2); dstA[1].w = B.Get(x+3);
- dstA[2].x = B.Get(x+0); dstA[2].y = B.Get(x+1); dstA[2].z = B.Get(x+2); dstA[2].w = B.Get(x+3);
- dstA[3].x = B.Get(x+0); dstA[3].y = B.Get(x+1); dstA[3].z = B.Get(x+2); dstA[3].w = B.Get(x+3);
-
- int readK = strideK * (ti>>6) + bx + (ti&63) + offsetK;
- #if STRICT_CHANNELS == 1
- #else
- bool maskK = (bx + (ti&63)) < strideK;
- #endif
-
-#if TRANSPOSED_X == 1
- uint centroidId = by + (ti&63);
- #if KERNEL_1x1 == 1
- int readX = strideX * (ti>>6) + centroidId;
- #else
- int batch = centroidId / w / h;
- int topY = (centroidId / w % h) * _Stride.y - _Pad.y;
- int leftX = (centroidId % w) * _Stride.x - _Pad.x;
- int cornerId = batch * heightX * widthX + topY * widthX + leftX;
- int readX = strideX * (ti>>6) + cornerId;
- bool mask;
- #endif
-#else
- uint4 centroidId = uint4(
- (by + (ti>>4) + 0),
- (by + (ti>>4) + 16),
- (by + (ti>>4) + 32),
- (by + (ti>>4) + 48));
- #if KERNEL_1x1 == 1
- int4 readX = strideX * centroidId + (ti&15);
- #else
- int4 batch = centroidId / w / h;
- int4 topY = (centroidId / w % h) * _Stride.y - _Pad.y;
- int4 leftX = (centroidId % w) * _Stride.x - _Pad.x;
- int4 cornerId = batch * heightX * widthX + topY * widthX + leftX;
- int4 readX = strideX * cornerId + (ti&15);
- bool4 mask;
- #endif
-#endif
-
-#if KERNEL_1x1 == 1
- {
- {
-#else
- for (int dy = 0; dy < (int)K.GetKernelHeight(); dy++)
- {
- for (int dx = 0; dx < (int)K.GetKernelWidth(); dx++)
- {
- int kernelOffsetX = (dy * widthX + dx) * strideX;
- mask =
- topY + dy >= 0 &&
- topY + dy < heightX &&
- leftX + dx >= 0 &&
- leftX + dx < widthX;
-#endif // KERNEL_1x1
- for (int i = 0; i < channels; i += CACHE_DEPTH)
- {
- #if STRICT_CHANNELS == 1
- #else
- if (i + CACHE_DEPTH > channels)
- {
- int channelRemainder = channels - i;
- [unroll] for (int j = 0; j < 4; ++j)
- {
- bool maskChannelsK = ti < 64 * (channelRemainder - j * 4);
- bool maskChannelsX =
- #if TRANSPOSED_X == 1
- maskChannelsK;
- #else
- (ti&15) < channelRemainder;
- #endif
-
- W_[((ti>>6)<<6) + ((ti&3)<<4) + ((ti&63)>>2) + 256*j] =
- (maskK & maskChannelsK) ? K.data[readK] : 0;
- readK += strideK * max(0, min(channelRemainder - j * 4, 4));
-
- #if TRANSPOSED_X == 1
- X_[ti + 256*j] =
- #if KERNEL_1x1 == 1
- maskChannelsX ? X.data[readX + strideX * (i + j * 4) + offsetX]: 0;
- #else
- (mask && maskChannelsX) ? X.data[readX + strideX * (i + j * 4) + kernelOffsetX + offsetX]: 0;
- #endif
- #else
- X_[(ti>>4) + 65*(ti&15) + 16*j] =
- #if KERNEL_1x1 == 1
- maskChannelsX ? X.data[readX[j] + i + offsetX]: 0;
- #else
- (mask[j] && maskChannelsX) ? X.data[readX[j] + i + kernelOffsetX + offsetX]: 0;
- #endif
- #endif
- }
- }
- else
- #endif
- [unroll] for (int j = 0; j < 4; ++j)
- {
- W_[((ti>>6)<<6) + ((ti&3)<<4) + ((ti&63)>>2) + 256*j] =
- #if STRICT_CHANNELS == 1
- K.data[readK];
- #else
- maskK ? K.data[readK]: 0;
- #endif
- readK += strideK * 4;
-
- #if TRANSPOSED_X == 1
- X_[ti + 256*j] =
- #if KERNEL_1x1 == 1
- X.data[readX + strideX * (i + j * 4) + offsetX];
- #else
- mask ? X.data[readX + strideX * (i + j * 4) + kernelOffsetX + offsetX]: 0;
- #endif
- #else
- X_[(ti>>4) + 65*(ti&15) + 16*j] =
- #if KERNEL_1x1 == 1
- X.data[readX[j] + i + offsetX];
- #else
- mask[j] ? X.data[readX[j] + i + kernelOffsetX + offsetX]: 0;
- #endif
- #endif
-
- #if DEBUG_CHECK_BOUNDS && KERNEL_1x1 == 0
- if (mask[j] && readX[j] + i + kernelOffsetX < 0)
- X_[(ti>>4) + 65*(ti&15) + 16*j] = -1;
- if (mask[j] && readX[j] + i + kernelOffsetX >= X.GetLength())
- X_[(ti>>4) + 65*(ti&15) + 16*j] = -1;
- #endif
- }
-
- GroupMemoryBarrierWithGroupSync();
-
- int4 idX = int4(0,1,2,3);
- int4 idW = int4(0,16,32,48);
- int incX = 64 + (1-TRANSPOSED_X);
- int incW = 64;
-
- for (int di = 0; di < CACHE_DEPTH; di++)
- {
- float4 srcX = float4(
- X_[idX.x + ty*4],
- X_[idX.y + ty*4],
- X_[idX.z + ty*4],
- X_[idX.w + ty*4]);
- float4 srcW = float4(
- W_[idW.x + tx],
- W_[idW.y + tx],
- W_[idW.z + tx],
- W_[idW.w + tx]
- );
- idX += incX;
- idW += incW;
-
- dstA[0].x = ffma(srcX.x, srcW.x, dstA[0].x);
- dstA[0].y = ffma(srcX.x, srcW.y, dstA[0].y);
- dstA[0].z = ffma(srcX.x, srcW.z, dstA[0].z);
- dstA[0].w = ffma(srcX.x, srcW.w, dstA[0].w);
-
- dstA[1].x = ffma(srcX.y, srcW.x, dstA[1].x);
- dstA[1].y = ffma(srcX.y, srcW.y, dstA[1].y);
- dstA[1].z = ffma(srcX.y, srcW.z, dstA[1].z);
- dstA[1].w = ffma(srcX.y, srcW.w, dstA[1].w);
-
- dstA[2].x = ffma(srcX.z, srcW.x, dstA[2].x);
- dstA[2].y = ffma(srcX.z, srcW.y, dstA[2].y);
- dstA[2].z = ffma(srcX.z, srcW.z, dstA[2].z);
- dstA[2].w = ffma(srcX.z, srcW.w, dstA[2].w);
-
- dstA[3].x = ffma(srcX.w, srcW.x, dstA[3].x);
- dstA[3].y = ffma(srcX.w, srcW.y, dstA[3].y);
- dstA[3].z = ffma(srcX.w, srcW.z, dstA[3].z);
- dstA[3].w = ffma(srcX.w, srcW.w, dstA[3].w);
- }
-
- GroupMemoryBarrierWithGroupSync();
- }
- }
- }
-
- [unroll] for (int sy = 0; sy < 4 && y+sy < (int)w * (int)h * (int)O.batch; ++sy)
- [unroll] for (int sx = 0; sx < 4 && x+sx < strideO; ++sx)
- O.data[strideO * (y+sy) + x+sx + offsetO] = dstA[sy][sx];
-
- #undef X_
- #undef W_
-}
-#else
-#endif
-#undef TRANSPOSED_X
-#undef CACHE_DEPTH
-#undef BUF_OFFSET
-#undef KERNEL_NAME
-
-NUMTHREADS((16,4,4), (8,4,4), (4,4,4))
-void Conv2D(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(K.kernelCount, O.width, O.height);
- TENSOR_SHARED2_ARGS4(X, K, B, WBK, O);
-
- uint k = dispatchThreadID.x;
- uint x = dispatchThreadID.y;
- uint y = dispatchThreadID.z;
-
- if (k >= K.channels) return;
- if (x >= O.width) return;
- if (y >= O.height) return;
-
- uint2 leftCorner = _Pad.xy;
- uint2 rightCorner = uint2(X.width, X.height) + _Pad.xy;
- for (uint n = 0; n < O.batch; ++n)
- {
- float acc = B.Get(k);
- for (uint dy = 0; dy < K.GetKernelHeight(); ++dy)
- {
- for (uint dx = 0; dx < K.GetKernelWidth(); ++dx)
- {
- uint2 pos = uint2(x, y) * _Stride.xy + uint2(dx, dy);
- // @TODO: investigate
- // WARNING: had to move both y check into the loop (as opposed to checking y in parent loop) - due to potential bug in Metal compiler
- if (any(pos < leftCorner)) continue;
- if (any(pos >= rightCorner)) continue;
-
- for (uint c = 0; c < X.channels; ++c)
- acc = fastfma(X.Get(n, pos.y - leftCorner.y, pos.x - leftCorner.x, c), K.Get(dy, dx, c, k), acc);
- }
- }
-
- O.Set(n, y, x, k, acc);
- }
-}
-
-
-#define SIZE_W 4
-#define SIZE_H 2
-NUMTHREADS((64, 2, 2), (32, 2, 2), (16, 2, 2))
-void Conv2D_RegisterBlock4x2(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(K.kernelCount, O.width, O.height);
- TENSOR_SHARED2_ARGS4(X, K, B, WBK, O);
-
- uint k = dispatchThreadID.x;
- uint x = dispatchThreadID.y;
- uint y = dispatchThreadID.z;
-
- if (k >= K.channels) return;
- if (x*SIZE_W >= O.width) return;
- if (y*SIZE_H >= O.height) return;
-
- uint2 leftCorner = _Pad.xy;
- uint2 rightCorner = uint2(X.width, X.height) + _Pad.xy;
- for (uint n = 0; n < O.batch; ++n)
- {
- float acc[SIZE_H*SIZE_W];
- [unroll]
- for (uint q = 0; q < SIZE_H*SIZE_W; ++q)
- acc[q] = B.Get(k);
- for (uint dy = 0; dy < K.GetKernelHeight(); ++dy)
- {
- for (uint dx = 0; dx < K.GetKernelWidth(); ++dx)
- {
- uint2 pos[SIZE_H*SIZE_W];
- [unroll]
- for (uint q = 0; q < SIZE_H*SIZE_W; ++q)
- pos[q] = uint2(x*SIZE_W+(q%SIZE_W), y*SIZE_H+(q/SIZE_W)) * _Stride.xy + uint2(dx, dy);
-
- for (uint c = 0; c < X.channels; ++c)
- [unroll]
- for (q = 0; q < SIZE_H*SIZE_W; ++q)
- if (all(pos[q] >= leftCorner) && all(pos[q] < rightCorner))
- acc[q] = fastfma(X.Get(n, pos[q] - leftCorner, c), K.Get(dy, dx, c, k), acc[q]);
- }
- }
-
- [unroll]
- for (q = 0; q < SIZE_H*SIZE_W; ++q)
- O.Set(n, y*SIZE_H+(q/SIZE_W), x*SIZE_W+(q%SIZE_W), k, acc[q]);
- }
-}
-#undef SIZE_W
-#undef SIZE_H
-
-#define CONV2D_L1CACHED(L1CACHESIZE, SIZE, FMA) \
-groupshared float Conv2D_L1Cached##L1CACHESIZE##_Reg_Loop_safe_X[SIZE*SIZE][L1CACHESIZE];\
-[numthreads(L1CACHESIZE, 1, 1)]\
-void Conv2D_L1Cached##L1CACHESIZE##_RegisterBlock##SIZE##x##SIZE(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)\
-{\
- DISPATCH_ARGS(K.kernelCount, O.width, O.height);\
- TENSOR_SHARED2_ARGS4(X, K, B, WBK, O);\
-\
- uint k = L1CACHESIZE * groupID.x + groupThreadID.x;\
- uint x = groupID.y;\
- uint y = groupID.z;\
-\
- if (x*SIZE >= O.width) return;\
- if (y*SIZE >= O.height) return;\
-\
- for (uint n = 0; n < O.batch; ++n)\
- {\
- float acc[SIZE*SIZE];\
- [unroll]\
- for (uint q = 0; q < SIZE*SIZE; ++q)\
- acc[q] = B.SafeGet(k);\
-\
- for (uint dy = 0; dy < K.GetKernelHeight(); ++dy)\
- {\
- for (uint dx = 0; dx < K.GetKernelWidth(); ++dx)\
- {\
- uint2 pos[SIZE*SIZE];\
- [unroll]\
- for (uint q = 0; q < SIZE*SIZE; ++q)\
- pos[q] = uint2(x*SIZE+(q%SIZE), y*SIZE+(q/SIZE)) * _Stride.xy + uint2(dx, dy);\
-\
- for (uint c = 0; c < X.channels; c += L1CACHESIZE)\
- {\
- uint dc = groupThreadID.x;\
- [unroll]\
- for (q = 0; q < SIZE*SIZE; ++q)\
- Conv2D_L1Cached##L1CACHESIZE##_Reg_Loop_safe_X[q][dc] = X.SafeGet(n, pos[q], c + dc, _Pad.xy);\
- GroupMemoryBarrierWithGroupSync();\
-\
- if (k < K.channels)\
- {\
- uint kIndex = K.Index(dy, dx, c, k);\
- for (dc = 0; dc < L1CACHESIZE; ++dc)\
- {\
- [unroll]\
- for (q = 0; q < SIZE*SIZE; ++q)\
- acc[q] = FMA(Conv2D_L1Cached##L1CACHESIZE##_Reg_Loop_safe_X[q][dc], K.data[kIndex], acc[q]);\
- kIndex += K.channels;\
- }\
- }\
- GroupMemoryBarrierWithGroupSync();\
- }\
- }\
- }\
-\
- uint remainderW = (O.width - x*SIZE);\
- uint remainderH = (O.height - y*SIZE);\
-\
- if (k < K.channels)\
- [unroll]\
- for (q = 0; q < SIZE*SIZE; ++q)\
- if (q/SIZE < remainderH && q%SIZE < remainderW)\
- O.Set(n, y*SIZE+(q/SIZE), x*SIZE+(q%SIZE), k, acc[q]);\
- }\
-\
-}
-
-CONV2D_L1CACHED(64,4, fastfma)
-CONV2D_L1CACHED(32,4, fastfma)
-
-
-// IDEA: iterate over channels in the inner loop - needs channels first layout
-NUMTHREADS((16,4,4), (8,4,4), (4,4,4))
-void DepthwiseConv2D(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(K.kernelCount, O.width, O.height);
- TENSOR_SHARED2_ARGS4(X, K, B, WBK, O);
-
- uint k = dispatchThreadID.x;
- uint x = dispatchThreadID.y;
- uint y = dispatchThreadID.z;
-
- if (k >= K.channels) return;
- if (x >= O.width) return;
- if (y >= O.height) return;
-
- uint2 leftCorner = _Pad.xy;
- uint2 rightCorner = uint2(X.width, X.height) + _Pad.xy;
-
- uint2 leftKernelCorner = uint2(x, y) * _Stride.xy;
- uint2 rightKernelCorner = leftKernelCorner + uint2(K.GetKernelWidth(), K.GetKernelHeight());
-
- if (any(leftKernelCorner < leftCorner) || any(rightKernelCorner >= rightCorner))
- {
- // path with edge-cases checks
- for (uint n = 0; n < O.batch; ++n)
- {
- float acc = B.Get(k);
- for (uint dy = 0; dy < K.GetKernelHeight(); ++dy)
- for (uint dx = 0; dx < K.GetKernelWidth(); ++dx)
- {
- uint2 pos = leftKernelCorner + uint2(dx, dy);
- if (any(pos < leftCorner)) continue;
- if (any(pos >= rightCorner)) continue;
-
- acc = fastfma(
- X.Get(n, pos.y - leftCorner.y, pos.x - leftCorner.x, k),
- K.Get(dy, dx, 0, k),
- acc);
- }
-
- O.Set(n, y, x, k, acc);
- }
- }
- else
- {
- // kernel is guaranteed to be within X,
- // no need to check against edge-cases
- leftKernelCorner -= leftCorner;
- for (uint n = 0; n < O.batch; ++n)
- {
- float acc = B.Get(k);
- for (uint dy = 0; dy < K.GetKernelHeight(); ++dy)
- for (uint dx = 0; dx < K.GetKernelWidth(); ++dx)
- {
- uint2 pos = leftKernelCorner + uint2(dx, dy);
-
- acc = fastfma(
- X.Get(n, pos, k),
- K.Get(dy, dx, 0, k),
- acc);
- }
-
- O.Set(n, y, x, k, acc);
- }
- }
-}
-
-
-// Significantly faster than Conv2DTrans
-[numthreads(16,2,2)]
-void Conv2DTrans(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- // NOTE: dispatched over X (not O)
- DISPATCH_ARGS(K.kernelCount, X.width, X.height);
- TENSOR_SHARED2_ARGS4(X, K, B, WBK, O);
-
- uint k = dispatchThreadID.x;
- uint x = dispatchThreadID.y;
- uint y = dispatchThreadID.z;
-
- if (k >= K.channels) return;
- if (x >= X.width) return;
- if (y >= X.height) return;
-
- uint2 pad = _Pad.xy / _Stride.xy;
- uint2 leftCorner = pad;
- uint2 rightCorner = uint2(X.width, X.height) + pad;
-
- for (uint n = 0; n < O.batch; ++n)
- {
- for (uint sy = 0; sy < _Stride.y; ++sy)
- {
- for (uint sx = 0; sx < _Stride.x; ++sx)
- {
- float acc = B.Get(k);
- for (uint dy = sy; dy < K.GetKernelHeight(); dy += _Stride.y)
- {
- for (uint dx = sx; dx < K.GetKernelWidth(); dx += _Stride.x)
- {
- uint2 pos = uint2(x, y) + uint2(sx + dx, sy + dy) / _Stride.xy;
-
- if (any(pos < leftCorner)) continue;
- if (any(pos >= rightCorner)) continue;
-
- for (uint c = 0; c < X.channels; ++c)
- {
- acc = fastfma( X.Get(n, pos - leftCorner, c),
- K.Get( K.GetKernelHeight() - 1 - dy,
- K.GetKernelWidth() - 1 - dx, c, k),
- acc);
- }
- }
- }
-
- uint oy = y * _Stride.y + sy;
- uint ox = x * _Stride.x + sx;
- if (oy < O.height && ox < O.width)
- O.Set(n, oy, ox, k, acc);
- }
- }
- }
-}
-
-#undef L1CACHESIZE
-#define L1CACHESIZE 64
-#undef SIZE
-#define SIZE 2
-groupshared float Conv2DTrans_L1Cached64_Reg_Loop_safe_X[SIZE*SIZE][L1CACHESIZE];
-[numthreads(L1CACHESIZE, 1, 1)]
-void Conv2DTrans_L1Cached64_RegisterBlock2x2(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- // NOTE: dispatched over X (not O)
- DISPATCH_ARGS(K.kernelCount, X.width / SIZE, X.height / SIZE);
- TENSOR_SHARED2_ARGS4(X, K, B, WBK, O);
-
- #define X_ Conv2DTrans_L1Cached64_Reg_Loop_safe_X
-
- uint k = L1CACHESIZE * groupID.x + groupThreadID.x;
- uint x = groupID.y;
- uint y = groupID.z;
-
- // need all threads to load channels, thus will do late check against kernel count
- if (x*SIZE >= X.width) return;
- if (y*SIZE >= X.height) return;
-
- uint2 pad = _Pad.xy / _Stride.xy;
-
- for (uint n = 0; n < O.batch; ++n)
- {
- for (uint sy = 0; sy < _Stride.y; ++sy)
- {
- for (uint sx = 0; sx < _Stride.x; ++sx)
- {
- float acc[SIZE*SIZE];
- [unroll]
- for (uint q = 0; q < SIZE*SIZE; ++q)
- acc[q] = B.SafeGet(k);
-
- for (uint dy = sy; dy < K.GetKernelHeight(); dy += _Stride.y)
- {
- for (uint dx = sx; dx < K.GetKernelWidth(); dx += _Stride.x)
- {
- uint2 pos[SIZE*SIZE];
- [unroll]
- for (uint q = 0; q < SIZE*SIZE; ++q)
- pos[q] = uint2(x*SIZE+(q%SIZE), y*SIZE+(q/SIZE)) + uint2(dx+sx, dy+sy) / _Stride.xy;
-
- for (uint c = 0; c < X.channels; c += L1CACHESIZE)
- {
- // Cache X
- uint dc = groupThreadID.x;
- [unroll]
- for (q = 0; q < SIZE*SIZE; ++q)
- X_[q][dc] = X.SafeGet(n, pos[q], c + dc, pad);
- GroupMemoryBarrierWithGroupSync();
-
- // X * K
- if (k < K.channels) // need all threads to load channels, thus late check against kernel count
- {
- //uint kIndex = K.Index(dy, dx, c, k);
- for (dc = 0; dc < L1CACHESIZE; ++dc)
- {
- [unroll]
- for (q = 0; q < SIZE*SIZE; ++q)
- acc[q] = fastfma( X_[q][dc],
- K.Get( K.GetKernelHeight() - 1 - dy,
- K.GetKernelWidth() - 1 - dx, c + dc, k),
- acc[q]);
- //kIndex += K.channels;
- }
- }
- GroupMemoryBarrierWithGroupSync();
- }
- }
- }
-
-
- if (k < K.channels) // need all threads to load channels, thus late check against kernel count
- [unroll]
- for (q = 0; q < SIZE*SIZE; ++q)
- {
- uint ox = (x*SIZE+(q%SIZE)) * _Stride.x + sx;
- uint oy = (y*SIZE+(q/SIZE)) * _Stride.y + sy;
- if (ox < O.width && oy < O.height)
- O.Set(n, oy, ox, k, acc[q]);
- }
- }
- }
- }
-
- #undef X_
-}
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/Conv.compute.meta b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/Conv.compute.meta
deleted file mode 100644
index bc66c8b6ef..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/Conv.compute.meta
+++ /dev/null
@@ -1,9 +0,0 @@
-fileFormatVersion: 2
-guid: 7f508b82f984146e8bf0ad8520c316c7
-timeCreated: 1507457340
-licenseType: Pro
-ComputeShaderImporter:
- currentAPIMask: 196608
- userData:
- assetBundleName:
- assetBundleVariant:
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/ConvOld.compute b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/ConvOld.compute
deleted file mode 100644
index 81b5e4b68a..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/ConvOld.compute
+++ /dev/null
@@ -1,418 +0,0 @@
-//#pragma kernel Conv2D_Kmod16_Nmod8_KNY
-//#pragma kernel Conv2D_Cache_KCmod32_KNyx
-//#pragma kernel Conv2D_Cache_KCmod32_KNyxDiv2
-// NOTE: DISABLED 64 version because as it is slower than 32 version on AMD GPU
-//#pragma kernel Conv2D_Cache_KCmod64_KNyx
-
-#include "Tensor.cginc"
-
-TENSOR_DECL(X)
-TENSOR_DECL(K)
-TENSOR_DECL(B)
-TENSOR_DECL(WBK)
-TENSOR_DECL_RW(O)
-
-uint4 _Pad;
-uint4 _Stride;
-
-NUMTHREADS((16,8,1), (16,8,1), (16,4,1))
-void Conv2D_Kmod16_Nmod8_KNY(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(K.channels, O.batch, O.height);
- TENSOR_SHARED2_ARGS4(X, K, B, WBK, O);
-
- uint k = dispatchThreadID.x;
- uint n = dispatchThreadID.y;
- uint y = dispatchThreadID.z;
-
- for (uint x = 0; x < O.width; ++x)
- {
- float v = B.Get(k);
- for (uint dy = 0; dy < K.GetKernelHeight(); ++dy)
- {
- for (uint dx = 0; dx < K.GetKernelWidth(); ++dx)
- {
- uint oy = y * _Stride.y + dy;
- uint ox = x * _Stride.x + dx;
- // @TODO: investigate
- // WARNING: had to move both y check into the loop (as opposed to checking y in parent loop) - due to potential bug in Metal compiler
- if (oy < _Pad.y) continue;
- if (oy - _Pad.w >= X.height) continue;
- if (ox < _Pad.x) continue;
- if (ox - _Pad.z >= X.width) continue;
-
- for (uint c = 0; c < X.channels; ++c)
- {
- v += X.Get(n, oy-_Pad.y, ox-_Pad.x, c) * K.Get(dy, dx, c, k);
- }
- }
- }
- O.Set(n, y, x, k, v);
- }
-}
-
-#undef CTILE
-#define CTILE NUMTHREAD(16, 8, 8)
-groupshared float Conv_Xcache[4][CTILE][CTILE];
-groupshared float Conv_Kcache[4][CTILE][CTILE];
-[numthreads(CTILE, CTILE, 1)]
-void Conv2D_Cache_KCmod32_KNyx(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- DISPATCH_ARGS(K.kernelCount / 2, O.batch * O.height * O.width / 2, 1);
- TENSOR_SHARED2_ARGS4(X, K, B, WBK, O);
-
- #define X_ Conv_Xcache
- #define K_ Conv_Kcache
-
- uint gx = groupThreadID.x;
- uint gy = groupThreadID.y;
-
- uint k = CTILE * groupID.x + groupThreadID.x;
- uint nyx = CTILE * groupID.y + groupThreadID.y;
-
- uint width = O.width;
- uint height = O.height;
-
- uint x = nyx % width;
- uint ny = nyx / width;
- uint y = ny % height;
- uint n = ny / height;
-
- float b0 = B.Get(k*2+0);
- float b1 = B.Get(k*2+1);
- float4 v = float4(b0, b1,
- b0, b1);
-
- for (uint dy = 0; dy < K.GetKernelHeight(); ++dy)
- {
- for (uint dx = 0; dx < K.GetKernelWidth(); ++dx)
- {
- bool mask = true;
- uint oy = y * _Stride.y + dy;
- uint ox = x * _Stride.x + dx;
- // @TODO: investigate
- // WARNING: had to move both y check into the loop (as opposed to checking y in parent loop) - due to potential bug in Metal compiler
- if (oy < _Pad.y) mask = false;
- if (oy - _Pad.w >= X.height) mask = false;
- if (ox < _Pad.x) mask = false;
- if (ox - _Pad.z >= X.width) mask = false;
-
- for (uint m = 0; m < X.channels/(CTILE*2); ++m)
- {
- float x0 = 0;
- float x1 = 0;
- float x2 = 0;
- float x3 = 0;
-
- if (mask)
- {
- x0 = X.Get(n*2+0, oy-_Pad.y, ox-_Pad.x, (m*CTILE + gx)*2+0);
- x1 = X.Get(n*2+0, oy-_Pad.y, ox-_Pad.x, (m*CTILE + gx)*2+1);
- x2 = X.Get(n*2+1, oy-_Pad.y, ox-_Pad.x, (m*CTILE + gx)*2+0);
- x3 = X.Get(n*2+1, oy-_Pad.y, ox-_Pad.x, (m*CTILE + gx)*2+1);
- }
-
- float k0 = K.Get(dy, dx, (m*CTILE + gy)*2+0, k*2+0);
- float k1 = K.Get(dy, dx, (m*CTILE + gy)*2+0, k*2+1);
- float k2 = K.Get(dy, dx, (m*CTILE + gy)*2+1, k*2+0);
- float k3 = K.Get(dy, dx, (m*CTILE + gy)*2+1, k*2+1);
-
- //X_[gy][gx] = float4(x0, x1,
- // x2, x3);
- //K_[gy][gx] = float4(k0, k1,
- // k2, k3);
- X_[0][gy][gx] = x0;
- X_[1][gy][gx] = x1;
- X_[2][gy][gx] = x2;
- X_[3][gy][gx] = x3;
-
- K_[0][gy][gx] = k0;
- K_[1][gy][gx] = k1;
- K_[2][gy][gx] = k2;
- K_[3][gy][gx] = k3;
-
- GroupMemoryBarrierWithGroupSync();
-
- [unroll]
- for (uint i = 0; i < CTILE; ++i)
- {
- float4 x = //X_[gy][i];
- float4( X_[0][gy][i],
- X_[1][gy][i],
- X_[2][gy][i],
- X_[3][gy][i]);
- float4 k = //K_[i][gx];
- float4( K_[0][i][gx],
- K_[1][i][gx],
- K_[2][i][gx],
- K_[3][i][gx]);
-
- v.x = mad(k.x, x.x, v.x);
- v.x = mad(k.z, x.y, v.x);
-
- v.y = mad(k.y, x.x, v.y);
- v.y = mad(k.w, x.y, v.y);
-
- v.z = mad(k.x, x.z, v.z);
- v.z = mad(k.z, x.w, v.z);
-
- v.w = mad(k.y, x.z, v.w);
- v.w = mad(k.w, x.w, v.w);
-
- //v.x += k.x*x.x + k.z*x.y;
- //v.y += k.y*x.x + k.w*x.y;
- //v.z += k.x*x.z + k.z*x.w;
- //v.w += k.y*x.z + k.w*x.w;
- }
-
- GroupMemoryBarrierWithGroupSync();
- }
- }
- }
-
- O.Set(n*2+0, y, x, k*2+0, v.x);
- O.Set(n*2+0, y, x, k*2+1, v.y);
- O.Set(n*2+1, y, x, k*2+0, v.z);
- O.Set(n*2+1, y, x, k*2+1, v.w);
-
- #undef X_
- #undef K_
-}
-
-#undef CTILE
-//#define CTILE NUMTHREAD(16, 8, 8)
-#define CTILE 16
-groupshared float Conv_Xcache2[4][CTILE][CTILE];
-groupshared float Conv_Kcache2[4][CTILE][CTILE];
-[numthreads(CTILE, CTILE, 1)]
-void Conv2D_Cache_KCmod32_KNyxDiv2(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- DISPATCH_ARGS(K.kernelCount / 2, O.batch * O.height * O.width / 2, 1);
- TENSOR_SHARED2_ARGS4(X, K, B, WBK, O);
-
- #define X_ Conv_Xcache2
- #define K_ Conv_Kcache2
-
- uint gx = groupThreadID.x;
- uint gy = groupThreadID.y;
-
- uint k = CTILE * groupID.x + groupThreadID.x;
- uint nyx = CTILE * groupID.y + groupThreadID.y;
-
- uint width = O.width / 2;
- uint height = O.height;
-
- uint x = nyx % width;
- uint ny = nyx / width;
- uint y = ny % height;
- uint n = ny / height;
-
- float b0 = B.Get(k*2+0);
- float b1 = B.Get(k*2+1);
- float4 v = float4(b0, b1,
- b0, b1);
-
- bool mask = n < O.batch;
-
- for (uint dy = 0; dy < K.GetKernelHeight(); ++dy)
- {
- for (uint dx = 0; dx < K.GetKernelWidth(); ++dx)
- {
- // @TODO: investigate
- // WARNING: had to move both y check into the loop (as opposed to checking y in parent loop) - due to potential bug in Metal compiler
- bool maskY = mask;
- uint oy = y * _Stride.y + dy;
- if (oy < _Pad.y) maskY = false;
- if (oy - _Pad.w >= X.height) maskY = false;
-
- bool maskL = maskY;
- uint oxL = (x*2+0) * _Stride.x + dx;
- if (oxL < _Pad.x) maskL = false;
- if (oxL - _Pad.z >= X.width) maskL = false;
-
- bool maskR = maskY;
- uint oxR = (x*2+1) * _Stride.x + dx;
- if (oxR < _Pad.x) maskR = false;
- if (oxR - _Pad.z >= X.width) maskR = false;
-
- for (uint m = 0; m < X.channels/(CTILE*2); ++m)
- {
- if (maskL)
- {
- X_[0][gy][gx] = X.Get(n, oy-_Pad.y, oxL-_Pad.x, (m*CTILE + gx)*2+0);
- X_[1][gy][gx] = X.Get(n, oy-_Pad.y, oxL-_Pad.x, (m*CTILE + gx)*2+1);
- }
- else
- {
- X_[0][gy][gx] = X_[1][gy][gx] = 0;
- }
-
- if (maskR)
- {
- X_[2][gy][gx] = X.Get(n, oy-_Pad.y, oxR-_Pad.x, (m*CTILE + gx)*2+0);
- X_[3][gy][gx] = X.Get(n, oy-_Pad.y, oxR-_Pad.x, (m*CTILE + gx)*2+1);
- }
- else
- {
- X_[2][gy][gx] = X_[3][gy][gx] = 0;
- }
-
-
- K_[0][gy][gx] = K.Get(dy, dx, (m*CTILE + gy)*2+0, k*2+0);
- K_[1][gy][gx] = K.Get(dy, dx, (m*CTILE + gy)*2+0, k*2+1);
- K_[2][gy][gx] = K.Get(dy, dx, (m*CTILE + gy)*2+1, k*2+0);
- K_[3][gy][gx] = K.Get(dy, dx, (m*CTILE + gy)*2+1, k*2+1);
-
- GroupMemoryBarrierWithGroupSync();
-
- [unroll]
- for (uint i = 0; i < CTILE; ++i)
- {
- float4 x =
- float4( X_[0][gy][i],
- X_[1][gy][i],
- X_[2][gy][i],
- X_[3][gy][i]);
- float4 k =
- float4( K_[0][i][gx],
- K_[1][i][gx],
- K_[2][i][gx],
- K_[3][i][gx]);
-
- v.x = mad(k.x, x.x, v.x);
- v.x = mad(k.z, x.y, v.x);
-
- v.y = mad(k.y, x.x, v.y);
- v.y = mad(k.w, x.y, v.y);
-
- v.z = mad(k.x, x.z, v.z);
- v.z = mad(k.z, x.w, v.z);
-
- v.w = mad(k.y, x.z, v.w);
- v.w = mad(k.w, x.w, v.w);
- }
-
- GroupMemoryBarrierWithGroupSync();
- }
- }
- }
-
- O.Set(n, y, x*2+0, k*2+0, v.x);
- O.Set(n, y, x*2+0, k*2+1, v.y);
- if (mask && x*2+1 < O.width)
- {
- O.Set(n, y, x*2+1, k*2+0, v.z);
- O.Set(n, y, x*2+1, k*2+1, v.w);
- }
-
- #undef X_
- #undef K_
-}
-
-
-#undef CTILE
-//#define CTILE NUMTHREAD(16, 8, 8)
-#define CTILE 16
-#define RTILE 4
-groupshared float Conv_XcacheR[RTILE*RTILE][CTILE*CTILE];
-groupshared float Conv_KcacheR[RTILE*RTILE][CTILE*CTILE];
-[numthreads(CTILE, CTILE, 1)]
-void Conv2D_Cache_KCmod64_KNyx(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- DISPATCH_ARGS(K.kernelCount / 4, O.batch * O.height * O.width / 4, 1);
- TENSOR_SHARED2_ARGS4(X, K, B, WBK, O);
-
- #define X_ Conv_XcacheR
- #define K_ Conv_KcacheR
-
- uint gx = groupThreadID.x;
- uint gy = groupThreadID.y;
-
- uint k = CTILE * groupID.x + groupThreadID.x;
- uint nyx = CTILE * groupID.y + groupThreadID.y;
-
- uint x = nyx % O.width;
- uint ny = nyx / O.width;
- uint y = ny % O.height;
- uint n = ny / O.height;
-
- float v[RTILE][RTILE];
- for (uint xxxx = 0; xxxx < RTILE; ++xxxx)
- {
- float b = B.Get(k*RTILE+xxxx);
- for (uint yyyy = 0; yyyy < RTILE; ++yyyy)
- v[yyyy][xxxx] = b;
- }
-
- for (uint dy = 0; dy < K.GetKernelHeight(); ++dy)
- {
- for (uint dx = 0; dx < K.GetKernelWidth(); ++dx)
- {
- bool mask = true;
- uint oy = y * _Stride.y + dy;
- uint ox = x * _Stride.x + dx;
- // @TODO: investigate
- // WARNING: had to move both y check into the loop (as opposed to checking y in parent loop) - due to potential bug in Metal compiler
- if (oy < _Pad.y) mask = false;
- if (oy - _Pad.w >= X.height) mask = false;
- if (ox < _Pad.x) mask = false;
- if (ox - _Pad.z >= X.width) mask = false;
-
- for (uint m = 0; m < X.channels/(CTILE*RTILE); ++m)
- {
- for (uint yy = 0; yy < RTILE; ++yy)
- for (uint xx = 0; xx < RTILE; ++xx)
- {
- if (mask)
- X_[yy*RTILE+xx][gy*CTILE+gx] = X.Get(n*RTILE+yy, oy - _Pad.y, ox - _Pad.x, (m*CTILE + gx)*RTILE+xx);
- else
- X_[yy*RTILE+xx][gy*CTILE+gx] = 0;
- K_[yy*RTILE+xx][gy*CTILE+gx] = K.Get(dy, dx, (m*CTILE + gy)*RTILE+yy, k*RTILE+xx);
- }
-
- GroupMemoryBarrierWithGroupSync();
-
- for (uint ii = 0; ii < CTILE; ++ii)
- {
- float x[RTILE][RTILE];
- float k[RTILE][RTILE];
-
- [unroll]
- for (uint yy = 0; yy < RTILE; ++yy)
- {
- [unroll]
- for (uint xx = 0; xx < RTILE; ++xx)
- {
- x[yy][xx] = X_[yy*RTILE+xx][gy*CTILE+ii];
- k[yy][xx] = K_[yy*RTILE+xx][ii*CTILE+gx];
- }
- }
-
-
- [unroll]
- for (uint yyy = 0; yyy < RTILE; ++yyy)
- {
- [unroll]
- for (uint xxx = 0; xxx < RTILE; ++xxx)
- {
- [unroll]
- for (uint i = 0; i < RTILE; ++i)
- {
- v[yyy][xxx] = mad(x[yyy][i], k[i][xxx], v[yyy][xxx]);
- }
- }
- }
- }
-
- GroupMemoryBarrierWithGroupSync();
- }
- }
- }
-
- for (uint yy = 0; yy < RTILE; ++yy)
- for (uint xx = 0; xx < RTILE; ++xx)
- O.Set(n*RTILE+yy, y, x, k*RTILE+xx, v[yy][xx]);
-
- #undef X_
- #undef K_
-}
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/ConvOld.compute.meta b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/ConvOld.compute.meta
deleted file mode 100644
index dae45fc790..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/ConvOld.compute.meta
+++ /dev/null
@@ -1,8 +0,0 @@
-fileFormatVersion: 2
-guid: a89bb2d7cde74429c8475f7cd8bcdb01
-ComputeShaderImporter:
- externalObjects: {}
- currentAPIMask: 0
- userData:
- assetBundleName:
- assetBundleVariant:
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/Dense.compute b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/Dense.compute
deleted file mode 100644
index bc5f328db9..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/Dense.compute
+++ /dev/null
@@ -1,1672 +0,0 @@
-#pragma kernel Dense_L1Cached64
-#pragma kernel DenseTiled16x16
-#pragma kernel DenseTiled32x32
-#pragma kernel DenseTiled64x64
-
-//#pragma kernel Dense_T8x8_R8x8 DENSE=1 BLOCK_SIZE=8
-#pragma kernel Dense_T8x8_R4x4 DENSE=1 BLOCK_SIZE=4
-#pragma kernel Dense_T16x16_R4x4 DENSE=1 BLOCK_SIZE=4
-
-#include "Tensor.cginc"
-
-TENSOR_DECL(X)
-TENSOR_DECL(W)
-TENSOR_DECL(B)
-TENSOR_DECL(WBK)
-TENSOR_DECL_RW(O)
-
-#if DENSE
-float ffma(float a, float b, float c) { return dot(float2(a,c), float2(b,1)); } //return a*b+c;} //fastfma(a,b,c); }
-#define FUNC_NAME(KERNEL, SIZE) KERNEL##SIZE##x##SIZE
-#define CACHE_NAME(KERNEL, SIZE, TENSOR) KERNEL##SIZE##x##SIZE##_Cache_##TENSOR
-
-
-//CACHE_DEPTH
-// T >>X
-//16: 178ms 272ms 181ms
-// 8: 173ms 395ms 205ms
-// 4: 176ms 630ms 260ms
-// 2: 205ms 495ms 420ms
-// 1: 209ms 980ms --
-
-
-//@HARDCODED_DIMS + BUF_OFFSET + lds read index alu opt
-//CACHE_DEPTH
-// T >>X
-//16: 169ms 241ms 173ms
-// 8: 169ms 356ms 178ms
-// 4: 170ms 612ms 209ms
-// 2: 178ms 900ms 380ms
-// 1: 250ms 875ms --
-
-//@BLOCKED_W + HARDCODED_DIMS + BUF_OFFSET + lds read index alu opt
-//!INCLUDING ValidateData by mistake!
-//CACHE_DEPTH
-// T >>X
-//16: 144ms 241ms 155ms
-// 8: 158ms 357ms 164ms
-// 4: 151ms 630ms 202ms
-// 2: 180ms 815ms 350ms
-// 1: 258ms 883ms --
-// @TODO: try 32
-
-
-//============================================
-//@BLOCKED_W + BUF_OFFSET + lds read index alu opt
-//CACHE_DEPTH
-// T T >>X
-// hard_dims
-//32: 167ms
-//16: 122ms 141ms 140ms
-// 8: 136ms 147ms 154ms
-// 4: 130ms 141ms 189ms
-// 2: 159ms ***ms ***ms
-// 1: 220ms ***ms ***ms
-//
-//Vega
-//32: 172ms
-//16: 154ms
-// 8: 156ms
-// 4: 161ms
-// 2: 162ms
-// 1: 245ms
-//iOS(8layers)
-//32: 28ms
-
-
-//@BLOCKED_W + lds read index alu opt
-//16: 134ms 142ms 146ms
-
-
-//@BLOCKED_W + BUF_OFFSET + optimized read indices
-//CACHE_DEPTH
-//16: 123ms 131ms 135ms
-
-
-#define KERNEL_NAME Dense_T16x16_R
-#if BLOCK_SIZE == 4
-#define TRANSPOSED_X 0
-#define SHIFTED_X 1
-#define BLOCKED_W 1
-#define HARDCODED_DIMS 0
-#define BUF_OFFSET 0
-#define DOUBLE_BUFFER_LDS_READS 0
-#define CACHE_DEPTH 16
-groupshared float CACHE_NAME(KERNEL_NAME, BLOCK_SIZE, X)[CACHE_DEPTH*16*BLOCK_SIZE+SHIFTED_X*CACHE_DEPTH];
-groupshared float CACHE_NAME(KERNEL_NAME, BLOCK_SIZE, W)[CACHE_DEPTH*16*BLOCK_SIZE];
-[numthreads(16,16,1)]
-void FUNC_NAME(KERNEL_NAME, BLOCK_SIZE)(uint3 dispatchThreadID : SV_DispatchThreadID, uint3 groupThreadID : SV_GroupThreadID, uint threadIndex : SV_GroupIndex)
-{
- DISPATCH_ARGS(O.flatWidth, O.flatHeight, 1);
- TENSOR_SHARED2_ARGS4(X, W, B, WBK, O);
-
- int x = (int)dispatchThreadID.x * BLOCK_SIZE;
- int y = (int)dispatchThreadID.y * BLOCK_SIZE;
- int tx = (int)groupThreadID.x;
- int ty = (int)groupThreadID.y;
- int bx = ((int)dispatchThreadID.x - (int)groupThreadID.x) * BLOCK_SIZE;
- int by = ((int)dispatchThreadID.y - (int)groupThreadID.y) * BLOCK_SIZE;
- int ti = (int)threadIndex;
- int n = (int)X.GetFlatWidth();
- int strideX = (int)X.GetFlatWidth();
- int strideW = (int)W.GetFlatWidth();
- int strideO = (int)O.GetFlatWidth();
- int offsetX = BUF_OFFSET;
- int offsetW = BUF_OFFSET;
- int offsetO = BUF_OFFSET;
-#if HARDCODED_DIMS == 1
- n = 1024;
- strideX = 1024;
- strideW = 1024;
- strideO = 1024;
-#endif
-
- #define X_ CACHE_NAME(KERNEL_NAME, BLOCK_SIZE, X)
- #define W_ CACHE_NAME(KERNEL_NAME, BLOCK_SIZE, W)
-
- //if (x >= (int)O.GetFlatWidth()) return;
- //if (y >= (int)O.GetFlatHeight()) return;
-
- float4 dstA_0, dstA_1, dstA_2, dstA_3;
-
- dstA_0.x = B.Get(x+0);
- dstA_1.x = B.Get(x+0);
- dstA_2.x = B.Get(x+0);
- dstA_3.x = B.Get(x+0);
- dstA_0.y = B.Get(x+1);
- dstA_1.y = B.Get(x+1);
- dstA_2.y = B.Get(x+1);
- dstA_3.y = B.Get(x+1);
- dstA_0.z = B.Get(x+2);
- dstA_1.z = B.Get(x+2);
- dstA_2.z = B.Get(x+2);
- dstA_3.z = B.Get(x+2);
- dstA_0.w = B.Get(x+3);
- dstA_1.w = B.Get(x+3);
- dstA_2.w = B.Get(x+3);
- dstA_3.w = B.Get(x+3);
-
- int j;
- int readW = strideW * (ti>>6) + bx + (ti&63) + offsetW;
- #if TRANSPOSED_X == 1
- int readX = strideX * (ti>>6) + by + (ti&63) + offsetX;
- #elif SHIFTED_X == 1
- int4 readX = int4(
- strideX * (by + (ti>>4) + 0) + (ti&15) + offsetX,
- strideX * (by + (ti>>4) +16) + (ti&15) + offsetX,
- strideX * (by + (ti>>4) +32) + (ti&15) + offsetX,
- strideX * (by + (ti>>4) +48) + (ti&15) + offsetX);
- #endif
-
- for (int i = 0; i < n; i += CACHE_DEPTH)
- {
-
- #if CACHE_DEPTH == 32
- #if BLOCKED_W == 1
- W_[((ti>>6)<<6) + ((ti&3)<<4) + ((ti&63)>>2) ] = W.data[strideW * (i + (ti>>6) + 0) + bx + (ti&63) + offsetW];
- W_[((ti>>6)<<6) + ((ti&3)<<4) + ((ti&63)>>2)+256] = W.data[strideW * (i + (ti>>6) + 4) + bx + (ti&63) + offsetW];
- W_[((ti>>6)<<6) + ((ti&3)<<4) + ((ti&63)>>2)+512] = W.data[strideW * (i + (ti>>6) + 8) + bx + (ti&63) + offsetW];
- W_[((ti>>6)<<6) + ((ti&3)<<4) + ((ti&63)>>2)+768] = W.data[strideW * (i + (ti>>6) +12) + bx + (ti&63) + offsetW];
- W_[((ti>>6)<<6) + ((ti&3)<<4) + ((ti&63)>>2)+1024]= W.data[strideW * (i + (ti>>6) +16) + bx + (ti&63) + offsetW];
- W_[((ti>>6)<<6) + ((ti&3)<<4) + ((ti&63)>>2)+1280]= W.data[strideW * (i + (ti>>6) +20) + bx + (ti&63) + offsetW];
- W_[((ti>>6)<<6) + ((ti&3)<<4) + ((ti&63)>>2)+1536]= W.data[strideW * (i + (ti>>6) +24) + bx + (ti&63) + offsetW];
- W_[((ti>>6)<<6) + ((ti&3)<<4) + ((ti&63)>>2)+1792]= W.data[strideW * (i + (ti>>6) +28) + bx + (ti&63) + offsetW];
- #else
- #endif
-
- #if TRANSPOSED_X == 1
- X_[ti ] = X.data[strideX * (i + (ti>>6) + 0) + by + (ti&63) + offsetX];
- X_[ti+256] = X.data[strideX * (i + (ti>>6) + 4) + by + (ti&63) + offsetX];
- X_[ti+512] = X.data[strideX * (i + (ti>>6) + 8) + by + (ti&63) + offsetX];
- X_[ti+768] = X.data[strideX * (i + (ti>>6) +12) + by + (ti&63) + offsetX];
- X_[ti+1024]= X.data[strideX * (i + (ti>>6) +16) + by + (ti&63) + offsetX];
- X_[ti+1280]= X.data[strideX * (i + (ti>>6) +20) + by + (ti&63) + offsetX];
- X_[ti+1536]= X.data[strideX * (i + (ti>>6) +24) + by + (ti&63) + offsetX];
- X_[ti+1792]= X.data[strideX * (i + (ti>>6) +28) + by + (ti&63) + offsetX];
- #elif SHIFTED_X == 1
- // 16x64 => 64x16
- X_[(ti>>5) + 65*(ti&31) + 0] = X.data[strideX * (by + (ti>>5) + 0) + i + (ti&31) + offsetX];
- X_[(ti>>5) + 65*(ti&31) + 8] = X.data[strideX * (by + (ti>>5) + 8) + i + (ti&31) + offsetX];
- X_[(ti>>5) + 65*(ti&31) +16] = X.data[strideX * (by + (ti>>5) +16) + i + (ti&31) + offsetX];
- X_[(ti>>5) + 65*(ti&31) +24] = X.data[strideX * (by + (ti>>5) +24) + i + (ti&31) + offsetX];
- X_[(ti>>5) + 65*(ti&31) +32] = X.data[strideX * (by + (ti>>5) +32) + i + (ti&31) + offsetX];
- X_[(ti>>5) + 65*(ti&31) +40] = X.data[strideX * (by + (ti>>5) +40) + i + (ti&31) + offsetX];
- X_[(ti>>5) + 65*(ti&31) +48] = X.data[strideX * (by + (ti>>5) +48) + i + (ti&31) + offsetX];
- X_[(ti>>5) + 65*(ti&31) +56] = X.data[strideX * (by + (ti>>5) +56) + i + (ti&31) + offsetX];
- #else
- // 16x64 => 64x16
- #endif
-
-
- #elif CACHE_DEPTH == 16
- #if BLOCKED_W == 1
- #if HARDCODED_DIMS
- W_[((ti>>6)<<6) + ((ti&3)<<4) + ((ti&63)>>2) ] = W.data[strideW * (i + (ti>>6) + 0) + bx + (ti&63) + offsetW];
- W_[((ti>>6)<<6) + ((ti&3)<<4) + ((ti&63)>>2)+256] = W.data[strideW * (i + (ti>>6) + 4) + bx + (ti&63) + offsetW];
- W_[((ti>>6)<<6) + ((ti&3)<<4) + ((ti&63)>>2)+512] = W.data[strideW * (i + (ti>>6) + 8) + bx + (ti&63) + offsetW];
- W_[((ti>>6)<<6) + ((ti&3)<<4) + ((ti&63)>>2)+768] = W.data[strideW * (i + (ti>>6) +12) + bx + (ti&63) + offsetW];
- #else
- [unroll] for (j = 0; j < 4; ++j, readW += strideW * 4)
- W_[((ti>>6)<<6) + ((ti&3)<<4) + ((ti&63)>>2) + 256*j] = W.data[readW];
- #endif
- #else
- W_[ti ] = W.data[strideW * (i + (ti>>6) + 0) + bx + (ti&63) + offsetW];
- W_[ti+256] = W.data[strideW * (i + (ti>>6) + 4) + bx + (ti&63) + offsetW];
- W_[ti+512] = W.data[strideW * (i + (ti>>6) + 8) + bx + (ti&63) + offsetW];
- W_[ti+768] = W.data[strideW * (i + (ti>>6) +12) + bx + (ti&63) + offsetW];
- #endif
-
- #if TRANSPOSED_X == 1
- #if HARDCODED_DIMS
- X_[ti ] = X.data[strideX * (i + (ti>>6) + 0) + by + (ti&63) + offsetX];
- X_[ti+256] = X.data[strideX * (i + (ti>>6) + 4) + by + (ti&63) + offsetX];
- X_[ti+512] = X.data[strideX * (i + (ti>>6) + 8) + by + (ti&63) + offsetX];
- X_[ti+768] = X.data[strideX * (i + (ti>>6) +12) + by + (ti&63) + offsetX];
- #else
- [unroll] for (j = 0; j < 4; ++j, readX += strideX * 4)
- X_[ti + 256*j] = X.data[readX];
- #endif
-
- #elif SHIFTED_X == 1
- // 16x64 => 64x16
- #if HARDCODED_DIMS
- X_[(ti>>4) + 65*(ti&15) + 0] = X.data[strideX * (by + (ti>>4) + 0) + i + (ti&15) + offsetX];
- X_[(ti>>4) + 65*(ti&15) +16] = X.data[strideX * (by + (ti>>4) +16) + i + (ti&15) + offsetX];
- X_[(ti>>4) + 65*(ti&15) +32] = X.data[strideX * (by + (ti>>4) +32) + i + (ti&15) + offsetX];
- X_[(ti>>4) + 65*(ti&15) +48] = X.data[strideX * (by + (ti>>4) +48) + i + (ti&15) + offsetX];
- #else
- [unroll] for (j = 0; j < 4; ++j)
- X_[(ti>>4) + 65*(ti&15) + 16*j] = X.data[readX[j]];
- readX += CACHE_DEPTH;
- #endif
- #else
- // 16x64 => 64x16
- X_[ti ] = X.data[strideX * (by + (ti&63)) + i + (ti>>6) + 0 + offsetX];
- X_[ti+256] = X.data[strideX * (by + (ti&63)) + i + (ti>>6) + 4 + offsetX];
- X_[ti+512] = X.data[strideX * (by + (ti&63)) + i + (ti>>6) + 8 + offsetX];
- X_[ti+768] = X.data[strideX * (by + (ti&63)) + i + (ti>>6) +12 + offsetX];
- #endif
-
- #elif CACHE_DEPTH == 8
- #if BLOCKED_W == 1
- W_[((ti>>6)<<6) + ((ti&3)<<4) + ((ti&63)>>2) ] = W.data[strideW * (i + (ti>>6) + 0) + bx + (ti&63) + offsetW];
- W_[((ti>>6)<<6) + ((ti&3)<<4) + ((ti&63)>>2)+256] = W.data[strideW * (i + (ti>>6) + 4) + bx + (ti&63) + offsetW];
- #else
- W_[ti ] = W.data[strideW * (i + (ti>>6) + 0) + bx + (ti&63) + offsetW];
- W_[ti+256] = W.data[strideW * (i + (ti>>6) + 4) + bx + (ti&63) + offsetW];
- #endif
-
- #if TRANSPOSED_X == 1
- X_[ti ] = X.data[strideX * (i + (ti>>6) + 0) + by + (ti&63) + offsetX];
- X_[ti+256] = X.data[strideX * (i + (ti>>6) + 4) + by + (ti&63) + offsetX];
- #elif SHIFTED_X == 1
- // 8x64 => 64x8
- X_[(ti>>3) + 65*(ti&7) + 0] = X.data[strideX * (by + (ti>>3) + 0) + i + (ti&7) + offsetX];
- X_[(ti>>3) + 65*(ti&7) +32] = X.data[strideX * (by + (ti>>3) +32) + i + (ti&7) + offsetX];
- #else
- // 8x64 => 64x8
- X_[ti ] = X.data[strideX * (by + (ti&63)) + i + (ti>>6) + 0 + offsetX];
- X_[ti+256] = X.data[strideX * (by + (ti&63)) + i + (ti>>6) + 4 + offsetX];
- #endif
-
- #elif CACHE_DEPTH == 4
- #if BLOCKED_W == 1
- W_[((ti>>6)<<6) + ((ti&3)<<4) + ((ti&63)>>2) ] = W.data[strideW * (i + (ti>>6) + 0) + bx + (ti&63) + offsetW];
- #else
- W_[ti ] = W.data[strideW * (i + (ti>>6) + 0) + bx + (ti&63) + offsetW];
- #endif
- #if TRANSPOSED_X == 1
- X_[ti ] = X.data[strideX * (i + (ti>>6) + 0) + by + (ti&63) + offsetX];
- #elif SHIFTED_X == 1
- // 4x64 => 64x4
- X_[(ti>>2) + 65*(ti&3) + 0] = X.data[strideX * (by + (ti>>2) + 0) + i + (ti&3) + offsetX];
- #else
- // 4x64 => 64x4
- X_[ti ] = X.data[strideX * (by + (ti&63)) + i + (ti>>6) + 0 + offsetX];
- #endif
-
- #elif CACHE_DEPTH == 2
- if (ti < 128)
- {
- #if BLOCKED_W == 1
- W_[((ti>>6)<<6) + ((ti&3)<<4) + ((ti&63)>>2) ] = W.data[strideW * (i + (ti>>6) + 0) + bx + (ti&63) + offsetW];
- #else
- W_[ti ] = W.data[strideW * (i + (ti>>6) + 0) + bx + (ti&63) + offsetW];
- #endif
- #if TRANSPOSED_X == 1
- X_[ti ] = X.data[strideX * (i + (ti>>6) + 0) + by + (ti&63) + offsetX];
- #elif SHIFTED_X == 1
- X_[(ti>>1) + 65*(ti&1) + 0] = X.data[strideX * (by + (ti>>1) + 0) + i + (ti&1) + offsetX];
- #else
- X_[ti ] = X.data[strideX * (by + (ti&63)) + i + (ti>>6) + 0 + offsetX];
- #endif
- }
-
- #elif CACHE_DEPTH == 1
- if (ti < 64)
- {
- #if BLOCKED_W == 1
- W_[((ti&3)<<4) + ((ti&63)>>2) ] = W.data[strideW * i + bx + ti + offsetW];
- #else
- W_[ti] = W.data[strideW * i + bx + ti + offsetW];
- #endif
- #if TRANSPOSED_X == 1
- X_[ti] = X.data[strideX * i + by + ti + offsetX];
- #else
- //X_[ti] = X.Get(by+ti, i);
- X_[ti] = X.data[strideX * (by + ti) + i + offsetX];
- #endif
- }
- #endif
-
- GroupMemoryBarrierWithGroupSync();
-
- int4 idX = int4(0,1,2,3);
- int4 idW = int4(0,1,2,3);
- #if BLOCKED_W == 1
- idW = int4(0,16,32,48);
- #endif
- int incX = 64 + (SHIFTED_X & ~TRANSPOSED_X);
- int incW = 64;
-#if 0 //DOUBLE_BUFFER_LDS_READS == 1
- float4 srcW_ = float4(
- #if BLOCKED_W == 1
- W_[idW.x + tx],
- W_[idW.y + tx],
- W_[idW.z + tx],
- W_[idW.w + tx]
- #else
- W_[idW.x + tx*4],
- W_[idW.y + tx*4],
- W_[idW.z + tx*4],
- W_[idW.w + tx*4]
- #endif
- );
- idW += incW;
-
- //int lastX = idX.x + (CACHE_DEPTH - 2) * incX.x;
- //while (idX.x < lastX.x)
- for (int di = 0; di < CACHE_DEPTH - 2; di+=2)
- {
- float4 srcX, srcW;
- srcX = float4(
- X_[idX.x + ty*4],
- X_[idX.y + ty*4],
- X_[idX.z + ty*4],
- X_[idX.w + ty*4]);
- srcW = float4(
- #if BLOCKED_W == 1
- W_[idW.x + tx],
- W_[idW.y + tx],
- W_[idW.z + tx],
- W_[idW.w + tx]
- #else
- W_[idW.x + tx*4],
- W_[idW.y + tx*4],
- W_[idW.z + tx*4],
- W_[idW.w + tx*4]
- #endif
- );
- idX += incX;
- idW += incW;
-
- dstA_0.x = ffma(srcX.x, srcW_.x, dstA_0.x);
- dstA_0.y = ffma(srcX.x, srcW_.y, dstA_0.y);
- dstA_0.z = ffma(srcX.x, srcW_.z, dstA_0.z);
- dstA_0.w = ffma(srcX.x, srcW_.w, dstA_0.w);
-
- dstA_1.x = ffma(srcX.y, srcW_.x, dstA_1.x);
- dstA_1.y = ffma(srcX.y, srcW_.y, dstA_1.y);
- dstA_1.z = ffma(srcX.y, srcW_.z, dstA_1.z);
- dstA_1.w = ffma(srcX.y, srcW_.w, dstA_1.w);
-
- dstA_2.x = ffma(srcX.z, srcW_.x, dstA_2.x);
- dstA_2.y = ffma(srcX.z, srcW_.y, dstA_2.y);
- dstA_2.z = ffma(srcX.z, srcW_.z, dstA_2.z);
- dstA_2.w = ffma(srcX.z, srcW_.w, dstA_2.w);
-
- dstA_3.x = ffma(srcX.w, srcW_.x, dstA_3.x);
- dstA_3.y = ffma(srcX.w, srcW_.y, dstA_3.y);
- dstA_3.z = ffma(srcX.w, srcW_.z, dstA_3.z);
- dstA_3.w = ffma(srcX.w, srcW_.w, dstA_3.w);
-
- srcX = float4(
- X_[idX.x + ty*4],
- X_[idX.y + ty*4],
- X_[idX.z + ty*4],
- X_[idX.w + ty*4]);
- srcW_ = float4(
- #if BLOCKED_W == 1
- W_[idW.x + tx],
- W_[idW.y + tx],
- W_[idW.z + tx],
- W_[idW.w + tx]
- #else
- W_[idW.x + tx*4],
- W_[idW.y + tx*4],
- W_[idW.z + tx*4],
- W_[idW.w + tx*4]
- #endif
- );
- idX += incX;
- idW += incW;
-
- dstA_0.x = ffma(srcX.x, srcW.x, dstA_0.x);
- dstA_0.y = ffma(srcX.x, srcW.y, dstA_0.y);
- dstA_0.z = ffma(srcX.x, srcW.z, dstA_0.z);
- dstA_0.w = ffma(srcX.x, srcW.w, dstA_0.w);
-
- dstA_1.x = ffma(srcX.y, srcW.x, dstA_1.x);
- dstA_1.y = ffma(srcX.y, srcW.y, dstA_1.y);
- dstA_1.z = ffma(srcX.y, srcW.z, dstA_1.z);
- dstA_1.w = ffma(srcX.y, srcW.w, dstA_1.w);
-
- dstA_2.x = ffma(srcX.z, srcW.x, dstA_2.x);
- dstA_2.y = ffma(srcX.z, srcW.y, dstA_2.y);
- dstA_2.z = ffma(srcX.z, srcW.z, dstA_2.z);
- dstA_2.w = ffma(srcX.z, srcW.w, dstA_2.w);
-
- dstA_3.x = ffma(srcX.w, srcW.x, dstA_3.x);
- dstA_3.y = ffma(srcX.w, srcW.y, dstA_3.y);
- dstA_3.z = ffma(srcX.w, srcW.z, dstA_3.z);
- dstA_3.w = ffma(srcX.w, srcW.w, dstA_3.w);
- }
-
- float4 srcX = float4(
- X_[idX.x + ty*4],
- X_[idX.y + ty*4],
- X_[idX.z + ty*4],
- X_[idX.w + ty*4]);
- float4 srcW = float4(
- #if BLOCKED_W == 1
- W_[idW.x + tx],
- W_[idW.y + tx],
- W_[idW.z + tx],
- W_[idW.w + tx]
- #else
- W_[idW.x + tx*4],
- W_[idW.y + tx*4],
- W_[idW.z + tx*4],
- W_[idW.w + tx*4]
- #endif
- );
-
- dstA_0.x = ffma(srcX.x, srcW_.x, dstA_0.x);
- dstA_0.y = ffma(srcX.x, srcW_.y, dstA_0.y);
- dstA_0.z = ffma(srcX.x, srcW_.z, dstA_0.z);
- dstA_0.w = ffma(srcX.x, srcW_.w, dstA_0.w);
-
- dstA_1.x = ffma(srcX.y, srcW_.x, dstA_1.x);
- dstA_1.y = ffma(srcX.y, srcW_.y, dstA_1.y);
- dstA_1.z = ffma(srcX.y, srcW_.z, dstA_1.z);
- dstA_1.w = ffma(srcX.y, srcW_.w, dstA_1.w);
-
- dstA_2.x = ffma(srcX.z, srcW_.x, dstA_2.x);
- dstA_2.y = ffma(srcX.z, srcW_.y, dstA_2.y);
- dstA_2.z = ffma(srcX.z, srcW_.z, dstA_2.z);
- dstA_2.w = ffma(srcX.z, srcW_.w, dstA_2.w);
-
- dstA_3.x = ffma(srcX.w, srcW_.x, dstA_3.x);
- dstA_3.y = ffma(srcX.w, srcW_.y, dstA_3.y);
- dstA_3.z = ffma(srcX.w, srcW_.z, dstA_3.z);
- dstA_3.w = ffma(srcX.w, srcW_.w, dstA_3.w);
-
- srcX = float4(
- X_[idX.x + ty*4],
- X_[idX.y + ty*4],
- X_[idX.z + ty*4],
- X_[idX.w + ty*4]);
- idX += incX;
-
- dstA_0.x = ffma(srcX.x, srcW.x, dstA_0.x);
- dstA_0.y = ffma(srcX.x, srcW.y, dstA_0.y);
- dstA_0.z = ffma(srcX.x, srcW.z, dstA_0.z);
- dstA_0.w = ffma(srcX.x, srcW.w, dstA_0.w);
-
- dstA_1.x = ffma(srcX.y, srcW.x, dstA_1.x);
- dstA_1.y = ffma(srcX.y, srcW.y, dstA_1.y);
- dstA_1.z = ffma(srcX.y, srcW.z, dstA_1.z);
- dstA_1.w = ffma(srcX.y, srcW.w, dstA_1.w);
-
- dstA_2.x = ffma(srcX.z, srcW.x, dstA_2.x);
- dstA_2.y = ffma(srcX.z, srcW.y, dstA_2.y);
- dstA_2.z = ffma(srcX.z, srcW.z, dstA_2.z);
- dstA_2.w = ffma(srcX.z, srcW.w, dstA_2.w);
-
- dstA_3.x = ffma(srcX.w, srcW.x, dstA_3.x);
- dstA_3.y = ffma(srcX.w, srcW.y, dstA_3.y);
- dstA_3.z = ffma(srcX.w, srcW.z, dstA_3.z);
- dstA_3.w = ffma(srcX.w, srcW.w, dstA_3.w);
-
-
- GroupMemoryBarrierWithGroupSync();
- }
-#else // DOUBLE_BUFFER_LDS_READS
-
-#define CACHE_UNROLL 1
- for (int di = 0; di < CACHE_DEPTH; di+=CACHE_UNROLL)
- {
- float4 srcX = float4(
- X_[idX.x + /*ti+0**/ ty*4],
- X_[idX.y + /*ti+0**/ ty*4],
- X_[idX.z + /*ti+0**/ ty*4],
- X_[idX.w + /*ti+0**/ ty*4]);
- //X_[di*_64 + ty*4 + 0],
- //X_[di*_64 + ty*4 + 1],
- //X_[di*_64 + ty*4 + 2],
- //X_[di*_64 + ty*4 + 3]);
- //X.Get(y+0, i+di),
- //X.Get(y+1, i+di),
- //X.Get(y+2, i+di),
- //X.Get(y+3, i+di));
- float4 srcW = float4(
- #if BLOCKED_W == 1
- W_[idW.x + tx],
- W_[idW.y + tx],
- W_[idW.z + tx],
- W_[idW.w + tx]
- #else
- W_[idW.x + tx*4],
- W_[idW.y + tx*4],
- W_[idW.z + tx*4],
- W_[idW.w + tx*4]
- #endif
- //W_[di*64 + tx*4 + 0],
- //W_[di*64 + tx*4 + 1],
- //W_[di*64 + tx*4 + 2],
- //W_[di*64 + tx*4 + 3]
- //W.Get(i+di, x+0),
- //W.Get(i+di, x+1),
- //W.Get(i+di, x+2),
- //W.Get(i+di, x+3)
- );
- idX += incX;
- idW += incW;
-
- dstA_0.x = ffma(srcX.x, srcW.x, dstA_0.x);
- dstA_0.y = ffma(srcX.x, srcW.y, dstA_0.y);
- dstA_0.z = ffma(srcX.x, srcW.z, dstA_0.z);
- dstA_0.w = ffma(srcX.x, srcW.w, dstA_0.w);
-
- dstA_1.x = ffma(srcX.y, srcW.x, dstA_1.x);
- dstA_1.y = ffma(srcX.y, srcW.y, dstA_1.y);
- dstA_1.z = ffma(srcX.y, srcW.z, dstA_1.z);
- dstA_1.w = ffma(srcX.y, srcW.w, dstA_1.w);
-
- dstA_2.x = ffma(srcX.z, srcW.x, dstA_2.x);
- dstA_2.y = ffma(srcX.z, srcW.y, dstA_2.y);
- dstA_2.z = ffma(srcX.z, srcW.z, dstA_2.z);
- dstA_2.w = ffma(srcX.z, srcW.w, dstA_2.w);
-
- dstA_3.x = ffma(srcX.w, srcW.x, dstA_3.x);
- dstA_3.y = ffma(srcX.w, srcW.y, dstA_3.y);
- dstA_3.z = ffma(srcX.w, srcW.z, dstA_3.z);
- dstA_3.w = ffma(srcX.w, srcW.w, dstA_3.w);
-
-#if CACHE_UNROLL>=2
-#endif
-#if CACHE_UNROLL>=3
-#endif
-#if CACHE_UNROLL>=4
-#endif
- }
-
- GroupMemoryBarrierWithGroupSync();
- }
-#undef CACHE_UNROLL
-#endif //DOUBLE_BUFFER_LDS_READS
-
- O.data[strideO * (y+0) + x+0 + offsetO] = dstA_0.x;
- O.data[strideO * (y+0) + x+1 + offsetO] = dstA_0.y;
- O.data[strideO * (y+0) + x+2 + offsetO] = dstA_0.z;
- O.data[strideO * (y+0) + x+3 + offsetO] = dstA_0.w;
- O.data[strideO * (y+1) + x+0 + offsetO] = dstA_1.x;
- O.data[strideO * (y+1) + x+1 + offsetO] = dstA_1.y;
- O.data[strideO * (y+1) + x+2 + offsetO] = dstA_1.z;
- O.data[strideO * (y+1) + x+3 + offsetO] = dstA_1.w;
- O.data[strideO * (y+2) + x+0 + offsetO] = dstA_2.x;
- O.data[strideO * (y+2) + x+1 + offsetO] = dstA_2.y;
- O.data[strideO * (y+2) + x+2 + offsetO] = dstA_2.z;
- O.data[strideO * (y+2) + x+3 + offsetO] = dstA_2.w;
- O.data[strideO * (y+3) + x+0 + offsetO] = dstA_3.x;
- O.data[strideO * (y+3) + x+1 + offsetO] = dstA_3.y;
- O.data[strideO * (y+3) + x+2 + offsetO] = dstA_3.z;
- O.data[strideO * (y+3) + x+3 + offsetO] = dstA_3.w;
-
- #undef X_
- #undef W_
-}
-#undef TRANSPOSED_X
-#undef SHIFTED_X
-#undef BLOCKED_W
-#undef HARDCODED_DIMS
-#undef BUF_OFFSET
-#undef DOUBLE_BUFFER_LDS_READS
-#undef CACHE_DEPTH
-#else
-[numthreads(16,16,1)]
-void FUNC_NAME(KERNEL_NAME, BLOCK_SIZE)(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.flatWidth, O.flatHeight, 1);
- TENSOR_SHARED2_ARGS4(X, W, B, WBK, O);
-
- int x = (int)dispatchThreadID.x * BLOCK_SIZE;
- int y = (int)dispatchThreadID.y * BLOCK_SIZE;
- int n = (int)X.GetFlatWidth();
-
- if (x >= (int)O.GetFlatWidth()) return;
- if (y >= (int)O.GetFlatHeight()) return;
-
- float dstA[BLOCK_SIZE][BLOCK_SIZE];
- float srcX[BLOCK_SIZE];
-
- int dy, dx;
- for (dx = 0; dx < BLOCK_SIZE; ++dx)
- for (dy = 0; dy < BLOCK_SIZE; ++dy)
- dstA[dy][dx] = B.data[x+dx+B.offset];//B.Get(x+dx);
-
- for (int i = 0; i < n; ++i)
- {
- for (dy = 0; dy < BLOCK_SIZE; ++dy)
- srcX[dy] = X.data[(y+dy)*X.channels+i];//X.Get(y+dy, i);
-
- for (dx = 0; dx < BLOCK_SIZE; ++dx)
- {
- float srcW = W.data[i*W.channels+x+dx];//W.Get(i, x+dx);
- for (dy = 0; dy < BLOCK_SIZE; ++dy)
- dstA[dy][dx] += srcX[dy] * srcW;
- }
- }
-
- for (dx = 0; dx < BLOCK_SIZE; ++dx)
- for (dy = 0; dy < BLOCK_SIZE; ++dy)
- O.Set(y+dy, x+dx, dstA[dy][dx]);
-}
-#endif
-#undef KERNEL_NAME
-
-
-//CACHE_DEPTH
-// T >>X
-//16: 183ms 207ms
-// 8: 158ms 202ms
-// 4: 162ms 334ms
-// 2: 159ms ***ms
-// 1: 173ms --
-
-#define KERNEL_NAME Dense_T8x8_R
-#if BLOCK_SIZE == 8
-#define UNROLL_INNER_LOOP 0
-#define TRANSPOSED_X 0
-#define HARDCODED_DIMS 0
-#define BUF_OFFSET 0
-#define CACHE_DEPTH 8
-groupshared float CACHE_NAME(KERNEL_NAME, BLOCK_SIZE, X)[CACHE_DEPTH*8*BLOCK_SIZE+(1-TRANSPOSED_X)*CACHE_DEPTH];
-groupshared float CACHE_NAME(KERNEL_NAME, BLOCK_SIZE, W)[CACHE_DEPTH*8*BLOCK_SIZE];
-[numthreads(8,8,1)]
-void FUNC_NAME(KERNEL_NAME, BLOCK_SIZE)(uint3 dispatchThreadID : SV_DispatchThreadID, uint3 groupThreadID : SV_GroupThreadID, uint threadIndex : SV_GroupIndex)
-{
- DISPATCH_ARGS(O.flatWidth, O.flatHeight, 1);
- TENSOR_SHARED2_ARGS4(X, W, B, WBK, O);
-
- int x = (int)dispatchThreadID.x * BLOCK_SIZE;
- int y = (int)dispatchThreadID.y * BLOCK_SIZE;
- int tx = (int)groupThreadID.x;
- int ty = (int)groupThreadID.y;
- int bx = ((int)dispatchThreadID.x - (int)groupThreadID.x) * BLOCK_SIZE;
- int by = ((int)dispatchThreadID.y - (int)groupThreadID.y) * BLOCK_SIZE;
- int ti = (int)threadIndex;
- int n = (int)X.GetFlatWidth();
- int strideX = (int)X.GetFlatWidth();
- int strideW = (int)W.GetFlatWidth();
- int strideO = (int)O.GetFlatWidth();
- int offsetX = BUF_OFFSET;
- int offsetW = BUF_OFFSET;
- int offsetO = BUF_OFFSET;
-#if HARDCODED_DIMS == 1
- n = 1024;
- strideX = 1024;
- strideW = 1024;
- strideO = 1024;
-#endif
-
- #define X_ CACHE_NAME(KERNEL_NAME, BLOCK_SIZE, X)
- #define W_ CACHE_NAME(KERNEL_NAME, BLOCK_SIZE, W)
-
-#if UNROLL_INNER_LOOP
- float4 dstA_0, dstA_1, dstA_2, dstA_3;
- float4 dstB_0, dstB_1, dstB_2, dstB_3;
- float4 dstC_0, dstC_1, dstC_2, dstC_3;
- float4 dstD_0, dstD_1, dstD_2, dstD_3;
-
- dstA_0.x = dstC_0.x = B.Get(x+0);
- dstA_1.x = dstC_1.x = B.Get(x+0);
- dstA_2.x = dstC_2.x = B.Get(x+0);
- dstA_3.x = dstC_3.x = B.Get(x+0);
- dstA_0.y = dstC_0.y = B.Get(x+1);
- dstA_1.y = dstC_1.y = B.Get(x+1);
- dstA_2.y = dstC_2.y = B.Get(x+1);
- dstA_3.y = dstC_3.y = B.Get(x+1);
- dstA_0.z = dstC_0.z = B.Get(x+2);
- dstA_1.z = dstC_1.z = B.Get(x+2);
- dstA_2.z = dstC_2.z = B.Get(x+2);
- dstA_3.z = dstC_3.z = B.Get(x+2);
- dstA_0.w = dstC_0.w = B.Get(x+3);
- dstA_1.w = dstC_1.w = B.Get(x+3);
- dstA_2.w = dstC_2.w = B.Get(x+3);
- dstA_3.w = dstC_3.w = B.Get(x+3);
-
- dstB_0.x = dstD_0.x = B.Get(x+4);
- dstB_1.x = dstD_1.x = B.Get(x+4);
- dstB_2.x = dstD_2.x = B.Get(x+4);
- dstB_3.x = dstD_3.x = B.Get(x+4);
- dstB_0.y = dstD_0.y = B.Get(x+5);
- dstB_1.y = dstD_1.y = B.Get(x+5);
- dstB_2.y = dstD_2.y = B.Get(x+5);
- dstB_3.y = dstD_3.y = B.Get(x+5);
- dstB_0.z = dstD_0.z = B.Get(x+6);
- dstB_1.z = dstD_1.z = B.Get(x+6);
- dstB_2.z = dstD_2.z = B.Get(x+6);
- dstB_3.z = dstD_3.z = B.Get(x+6);
- dstB_0.w = dstD_0.w = B.Get(x+7);
- dstB_1.w = dstD_1.w = B.Get(x+7);
- dstB_2.w = dstD_2.w = B.Get(x+7);
- dstB_3.w = dstD_3.w = B.Get(x+7);
-#else
- float4 dstA_0[4], dstA_1[4], dstA_2[4], dstA_3[4];
- dstA_0[0].x = dstA_0[2].x = B.Get(x+0);
- dstA_1[0].x = dstA_1[2].x = B.Get(x+0);
- dstA_2[0].x = dstA_2[2].x = B.Get(x+0);
- dstA_3[0].x = dstA_3[2].x = B.Get(x+0);
- dstA_0[0].y = dstA_0[2].y = B.Get(x+1);
- dstA_1[0].y = dstA_1[2].y = B.Get(x+1);
- dstA_2[0].y = dstA_2[2].y = B.Get(x+1);
- dstA_3[0].y = dstA_3[2].y = B.Get(x+1);
- dstA_0[0].z = dstA_0[2].z = B.Get(x+2);
- dstA_1[0].z = dstA_1[2].z = B.Get(x+2);
- dstA_2[0].z = dstA_2[2].z = B.Get(x+2);
- dstA_3[0].z = dstA_3[2].z = B.Get(x+2);
- dstA_0[0].w = dstA_0[2].w = B.Get(x+3);
- dstA_1[0].w = dstA_1[2].w = B.Get(x+3);
- dstA_2[0].w = dstA_2[2].w = B.Get(x+3);
- dstA_3[0].w = dstA_3[2].w = B.Get(x+3);
-
- dstA_0[1].x = dstA_0[3].x = B.Get(x+4);
- dstA_1[1].x = dstA_1[3].x = B.Get(x+4);
- dstA_2[1].x = dstA_2[3].x = B.Get(x+4);
- dstA_3[1].x = dstA_3[3].x = B.Get(x+4);
- dstA_0[1].y = dstA_0[3].y = B.Get(x+5);
- dstA_1[1].y = dstA_1[3].y = B.Get(x+5);
- dstA_2[1].y = dstA_2[3].y = B.Get(x+5);
- dstA_3[1].y = dstA_3[3].y = B.Get(x+5);
- dstA_0[1].z = dstA_0[3].z = B.Get(x+6);
- dstA_1[1].z = dstA_1[3].z = B.Get(x+6);
- dstA_2[1].z = dstA_2[3].z = B.Get(x+6);
- dstA_3[1].z = dstA_3[3].z = B.Get(x+6);
- dstA_0[1].w = dstA_0[3].w = B.Get(x+7);
- dstA_1[1].w = dstA_1[3].w = B.Get(x+7);
- dstA_2[1].w = dstA_2[3].w = B.Get(x+7);
- dstA_3[1].w = dstA_3[3].w = B.Get(x+7);
-
-#endif
-
- for (int i = 0; i < n; i += CACHE_DEPTH)
- {
- #if TRANSPOSED_X == 1
- [unroll]
- for (int j = 0; j < CACHE_DEPTH; ++j)
- {
- X_[ti + j*64] = X.data[strideX * (i + j) + by + ti + offsetX];
-
- // split 64 into 8 blocks and interleave them
- // 000000001111111122222222... => 012345678012345678...
- W_[((ti&7)<<3) + (ti>>3) + j*64] = W.data[strideW * (i + j) + bx + ti + offsetW];
- }
- #else
- int tiDiv = (uint)ti/CACHE_DEPTH;
- int tiMod = ti&(CACHE_DEPTH-1);
- int jStride = 64/CACHE_DEPTH;
-
- [unroll]
- for (int j = 0; j < CACHE_DEPTH; ++j)
- {
- // CACHE_DEPTHx64 => 64xCACHE_DEPTH
- X_[tiDiv + 65*tiMod + j*jStride] = X.data[strideX * (by + tiDiv + j*jStride) + i + tiMod];
-
- // split 64 into 8 blocks and interleave them
- // 000000001111111122222222... => 012345678012345678...
- W_[((ti&7)<<3) + (ti>>3) + j*64] = W.data[strideW * (i + j) + bx + ti + offsetW];
- }
- #endif
-
- GroupMemoryBarrierWithGroupSync();
-
-#if UNROLL_INNER_LOOP
- int4 idX0 = int4(0,1,2,3); int4 idX1 = int4(4,5,6,7);
- int4 idW0 = int4(0,8,16,24); int4 idW1 = int4(32,40,48,56);
-#else
- int4 idX[2], idW[2];
- idX[0] = int4(0,1,2,3); idX[1] = int4(4,5,6,7);
- idW[0] = int4(0,8,16,24); idW[1] = int4(32,40,48,56);
-#endif
- int incX = 64 + (TRANSPOSED_X?0:1);
- int incW = 64;
- for (int di = 0; di < CACHE_DEPTH; di++)
- {
-#if UNROLL_INNER_LOOP
- float4 srcX0 = float4(
- X_[idX0.x + ty*8],
- X_[idX0.y + ty*8],
- X_[idX0.z + ty*8],
- X_[idX0.w + ty*8]);
- float4 srcX1 = float4(
- X_[idX1.x + ty*8],
- X_[idX1.y + ty*8],
- X_[idX1.z + ty*8],
- X_[idX1.w + ty*8]);
- float4 srcW0 = float4(
- W_[idW0.x + tx],
- W_[idW0.y + tx],
- W_[idW0.z + tx],
- W_[idW0.w + tx]);
- float4 srcW1 = float4(
- W_[idW1.x + tx],
- W_[idW1.y + tx],
- W_[idW1.z + tx],
- W_[idW1.w + tx]);
- idX0 += incX; idX1 += incX;
- idW0 += incW; idW1 += incW;
-
- dstA_0.x = ffma(srcX0.x, srcW0.x, dstA_0.x);
- dstA_0.y = ffma(srcX0.x, srcW0.y, dstA_0.y);
- dstA_0.z = ffma(srcX0.x, srcW0.z, dstA_0.z);
- dstA_0.w = ffma(srcX0.x, srcW0.w, dstA_0.w);
- dstA_1.x = ffma(srcX0.y, srcW0.x, dstA_1.x);
- dstA_1.y = ffma(srcX0.y, srcW0.y, dstA_1.y);
- dstA_1.z = ffma(srcX0.y, srcW0.z, dstA_1.z);
- dstA_1.w = ffma(srcX0.y, srcW0.w, dstA_1.w);
- dstA_2.x = ffma(srcX0.z, srcW0.x, dstA_2.x);
- dstA_2.y = ffma(srcX0.z, srcW0.y, dstA_2.y);
- dstA_2.z = ffma(srcX0.z, srcW0.z, dstA_2.z);
- dstA_2.w = ffma(srcX0.z, srcW0.w, dstA_2.w);
- dstA_3.x = ffma(srcX0.w, srcW0.x, dstA_3.x);
- dstA_3.y = ffma(srcX0.w, srcW0.y, dstA_3.y);
- dstA_3.z = ffma(srcX0.w, srcW0.z, dstA_3.z);
- dstA_3.w = ffma(srcX0.w, srcW0.w, dstA_3.w);
-
- //
- dstB_0.x = ffma(srcX0.x, srcW1.x, dstB_0.x);
- dstB_0.y = ffma(srcX0.x, srcW1.y, dstB_0.y);
- dstB_0.z = ffma(srcX0.x, srcW1.z, dstB_0.z);
- dstB_0.w = ffma(srcX0.x, srcW1.w, dstB_0.w);
- dstB_1.x = ffma(srcX0.y, srcW1.x, dstB_1.x);
- dstB_1.y = ffma(srcX0.y, srcW1.y, dstB_1.y);
- dstB_1.z = ffma(srcX0.y, srcW1.z, dstB_1.z);
- dstB_1.w = ffma(srcX0.y, srcW1.w, dstB_1.w);
- dstB_2.x = ffma(srcX0.z, srcW1.x, dstB_2.x);
- dstB_2.y = ffma(srcX0.z, srcW1.y, dstB_2.y);
- dstB_2.z = ffma(srcX0.z, srcW1.z, dstB_2.z);
- dstB_2.w = ffma(srcX0.z, srcW1.w, dstB_2.w);
- dstB_3.x = ffma(srcX0.w, srcW1.x, dstB_3.x);
- dstB_3.y = ffma(srcX0.w, srcW1.y, dstB_3.y);
- dstB_3.z = ffma(srcX0.w, srcW1.z, dstB_3.z);
- dstB_3.w = ffma(srcX0.w, srcW1.w, dstB_3.w);
-
- //
- dstC_0.x = ffma(srcX1.x, srcW0.x, dstC_0.x);
- dstC_0.y = ffma(srcX1.x, srcW0.y, dstC_0.y);
- dstC_0.z = ffma(srcX1.x, srcW0.z, dstC_0.z);
- dstC_0.w = ffma(srcX1.x, srcW0.w, dstC_0.w);
- dstC_1.x = ffma(srcX1.y, srcW0.x, dstC_1.x);
- dstC_1.y = ffma(srcX1.y, srcW0.y, dstC_1.y);
- dstC_1.z = ffma(srcX1.y, srcW0.z, dstC_1.z);
- dstC_1.w = ffma(srcX1.y, srcW0.w, dstC_1.w);
- dstC_2.x = ffma(srcX1.z, srcW0.x, dstC_2.x);
- dstC_2.y = ffma(srcX1.z, srcW0.y, dstC_2.y);
- dstC_2.z = ffma(srcX1.z, srcW0.z, dstC_2.z);
- dstC_2.w = ffma(srcX1.z, srcW0.w, dstC_2.w);
- dstC_3.x = ffma(srcX1.w, srcW0.x, dstC_3.x);
- dstC_3.y = ffma(srcX1.w, srcW0.y, dstC_3.y);
- dstC_3.z = ffma(srcX1.w, srcW0.z, dstC_3.z);
- dstC_3.w = ffma(srcX1.w, srcW0.w, dstC_3.w);
-
- //
- dstD_0.x = ffma(srcX1.x, srcW1.x, dstD_0.x);
- dstD_0.y = ffma(srcX1.x, srcW1.y, dstD_0.y);
- dstD_0.z = ffma(srcX1.x, srcW1.z, dstD_0.z);
- dstD_0.w = ffma(srcX1.x, srcW1.w, dstD_0.w);
- dstD_1.x = ffma(srcX1.y, srcW1.x, dstD_1.x);
- dstD_1.y = ffma(srcX1.y, srcW1.y, dstD_1.y);
- dstD_1.z = ffma(srcX1.y, srcW1.z, dstD_1.z);
- dstD_1.w = ffma(srcX1.y, srcW1.w, dstD_1.w);
- dstD_2.x = ffma(srcX1.z, srcW1.x, dstD_2.x);
- dstD_2.y = ffma(srcX1.z, srcW1.y, dstD_2.y);
- dstD_2.z = ffma(srcX1.z, srcW1.z, dstD_2.z);
- dstD_2.w = ffma(srcX1.z, srcW1.w, dstD_2.w);
- dstD_3.x = ffma(srcX1.w, srcW1.x, dstD_3.x);
- dstD_3.y = ffma(srcX1.w, srcW1.y, dstD_3.y);
- dstD_3.z = ffma(srcX1.w, srcW1.z, dstD_3.z);
- dstD_3.w = ffma(srcX1.w, srcW1.w, dstD_3.w);
-
-#else
- float4 srcX[2], srcW[2];
- srcX[0] = float4(
- X_[idX[0].x + ty*8],
- X_[idX[0].y + ty*8],
- X_[idX[0].z + ty*8],
- X_[idX[0].w + ty*8]);
- srcX[1] = float4(
- X_[idX[1].x + ty*8],
- X_[idX[1].y + ty*8],
- X_[idX[1].z + ty*8],
- X_[idX[1].w + ty*8]);
- srcW[0] = float4(
- W_[idW[0].x + tx],
- W_[idW[0].y + tx],
- W_[idW[0].z + tx],
- W_[idW[0].w + tx]);
- srcW[1] = float4(
- W_[idW[1].x + tx],
- W_[idW[1].y + tx],
- W_[idW[1].z + tx],
- W_[idW[1].w + tx]);
- idX[0] += incX; idX[1] += incX;
- idW[0] += incW; idW[1] += incW;
-
-
- [loop]
- for (uint qw = 0; qw < 4; ++qw)
- {
- uint q = qw >> 1;
- uint w = qw & 1;
- dstA_0[qw].x = ffma(srcX[q].x, srcW[w].x, dstA_0[qw].x);
- dstA_0[qw].y = ffma(srcX[q].x, srcW[w].y, dstA_0[qw].y);
- dstA_0[qw].z = ffma(srcX[q].x, srcW[w].z, dstA_0[qw].z);
- dstA_0[qw].w = ffma(srcX[q].x, srcW[w].w, dstA_0[qw].w);
- dstA_1[qw].x = ffma(srcX[q].y, srcW[w].x, dstA_1[qw].x);
- dstA_1[qw].y = ffma(srcX[q].y, srcW[w].y, dstA_1[qw].y);
- dstA_1[qw].z = ffma(srcX[q].y, srcW[w].z, dstA_1[qw].z);
- dstA_1[qw].w = ffma(srcX[q].y, srcW[w].w, dstA_1[qw].w);
- dstA_2[qw].x = ffma(srcX[q].z, srcW[w].x, dstA_2[qw].x);
- dstA_2[qw].y = ffma(srcX[q].z, srcW[w].y, dstA_2[qw].y);
- dstA_2[qw].z = ffma(srcX[q].z, srcW[w].z, dstA_2[qw].z);
- dstA_2[qw].w = ffma(srcX[q].z, srcW[w].w, dstA_2[qw].w);
- dstA_3[qw].x = ffma(srcX[q].w, srcW[w].x, dstA_3[qw].x);
- dstA_3[qw].y = ffma(srcX[q].w, srcW[w].y, dstA_3[qw].y);
- dstA_3[qw].z = ffma(srcX[q].w, srcW[w].z, dstA_3[qw].z);
- dstA_3[qw].w = ffma(srcX[q].w, srcW[w].w, dstA_3[qw].w);
- }
-#endif
- }
-
- GroupMemoryBarrierWithGroupSync();
- }
-#if UNROLL_INNER_LOOP
- O.data[strideO * (y+0) + x+0 + offsetO] = dstA_0.x;
- O.data[strideO * (y+0) + x+1 + offsetO] = dstA_0.y;
- O.data[strideO * (y+0) + x+2 + offsetO] = dstA_0.z;
- O.data[strideO * (y+0) + x+3 + offsetO] = dstA_0.w;
- O.data[strideO * (y+0) + x+4 + offsetO] = dstB_0.x;
- O.data[strideO * (y+0) + x+5 + offsetO] = dstB_0.y;
- O.data[strideO * (y+0) + x+6 + offsetO] = dstB_0.z;
- O.data[strideO * (y+0) + x+7 + offsetO] = dstB_0.w;
- O.data[strideO * (y+1) + x+0 + offsetO] = dstA_1.x;
- O.data[strideO * (y+1) + x+1 + offsetO] = dstA_1.y;
- O.data[strideO * (y+1) + x+2 + offsetO] = dstA_1.z;
- O.data[strideO * (y+1) + x+3 + offsetO] = dstA_1.w;
- O.data[strideO * (y+1) + x+4 + offsetO] = dstB_1.x;
- O.data[strideO * (y+1) + x+5 + offsetO] = dstB_1.y;
- O.data[strideO * (y+1) + x+6 + offsetO] = dstB_1.z;
- O.data[strideO * (y+1) + x+7 + offsetO] = dstB_1.w;
- O.data[strideO * (y+2) + x+0 + offsetO] = dstA_2.x;
- O.data[strideO * (y+2) + x+1 + offsetO] = dstA_2.y;
- O.data[strideO * (y+2) + x+2 + offsetO] = dstA_2.z;
- O.data[strideO * (y+2) + x+3 + offsetO] = dstA_2.w;
- O.data[strideO * (y+2) + x+4 + offsetO] = dstB_2.x;
- O.data[strideO * (y+2) + x+5 + offsetO] = dstB_2.y;
- O.data[strideO * (y+2) + x+6 + offsetO] = dstB_2.z;
- O.data[strideO * (y+2) + x+7 + offsetO] = dstB_2.w;
- O.data[strideO * (y+3) + x+0 + offsetO] = dstA_3.x;
- O.data[strideO * (y+3) + x+1 + offsetO] = dstA_3.y;
- O.data[strideO * (y+3) + x+2 + offsetO] = dstA_3.z;
- O.data[strideO * (y+3) + x+3 + offsetO] = dstA_3.w;
- O.data[strideO * (y+3) + x+4 + offsetO] = dstB_3.x;
- O.data[strideO * (y+3) + x+5 + offsetO] = dstB_3.y;
- O.data[strideO * (y+3) + x+6 + offsetO] = dstB_3.z;
- O.data[strideO * (y+3) + x+7 + offsetO] = dstB_3.w;
-
- O.data[strideO * (y+4) + x+0 + offsetO] = dstC_0.x;
- O.data[strideO * (y+4) + x+1 + offsetO] = dstC_0.y;
- O.data[strideO * (y+4) + x+2 + offsetO] = dstC_0.z;
- O.data[strideO * (y+4) + x+3 + offsetO] = dstC_0.w;
- O.data[strideO * (y+4) + x+4 + offsetO] = dstD_0.x;
- O.data[strideO * (y+4) + x+5 + offsetO] = dstD_0.y;
- O.data[strideO * (y+4) + x+6 + offsetO] = dstD_0.z;
- O.data[strideO * (y+4) + x+7 + offsetO] = dstD_0.w;
- O.data[strideO * (y+5) + x+0 + offsetO] = dstC_1.x;
- O.data[strideO * (y+5) + x+1 + offsetO] = dstC_1.y;
- O.data[strideO * (y+5) + x+2 + offsetO] = dstC_1.z;
- O.data[strideO * (y+5) + x+3 + offsetO] = dstC_1.w;
- O.data[strideO * (y+5) + x+4 + offsetO] = dstD_1.x;
- O.data[strideO * (y+5) + x+5 + offsetO] = dstD_1.y;
- O.data[strideO * (y+5) + x+6 + offsetO] = dstD_1.z;
- O.data[strideO * (y+5) + x+7 + offsetO] = dstD_1.w;
- O.data[strideO * (y+6) + x+0 + offsetO] = dstC_2.x;
- O.data[strideO * (y+6) + x+1 + offsetO] = dstC_2.y;
- O.data[strideO * (y+6) + x+2 + offsetO] = dstC_2.z;
- O.data[strideO * (y+6) + x+3 + offsetO] = dstC_2.w;
- O.data[strideO * (y+6) + x+4 + offsetO] = dstD_2.x;
- O.data[strideO * (y+6) + x+5 + offsetO] = dstD_2.y;
- O.data[strideO * (y+6) + x+6 + offsetO] = dstD_2.z;
- O.data[strideO * (y+6) + x+7 + offsetO] = dstD_2.w;
- O.data[strideO * (y+7) + x+0 + offsetO] = dstC_3.x;
- O.data[strideO * (y+7) + x+1 + offsetO] = dstC_3.y;
- O.data[strideO * (y+7) + x+2 + offsetO] = dstC_3.z;
- O.data[strideO * (y+7) + x+3 + offsetO] = dstC_3.w;
- O.data[strideO * (y+7) + x+4 + offsetO] = dstD_3.x;
- O.data[strideO * (y+7) + x+5 + offsetO] = dstD_3.y;
- O.data[strideO * (y+7) + x+6 + offsetO] = dstD_3.z;
- O.data[strideO * (y+7) + x+7 + offsetO] = dstD_3.w;
-#else
- O.data[strideO * (y+0) + x+0 + offsetO] = dstA_0[0].x;
- O.data[strideO * (y+0) + x+1 + offsetO] = dstA_0[0].y;
- O.data[strideO * (y+0) + x+2 + offsetO] = dstA_0[0].z;
- O.data[strideO * (y+0) + x+3 + offsetO] = dstA_0[0].w;
- O.data[strideO * (y+0) + x+4 + offsetO] = dstA_0[1].x;
- O.data[strideO * (y+0) + x+5 + offsetO] = dstA_0[1].y;
- O.data[strideO * (y+0) + x+6 + offsetO] = dstA_0[1].z;
- O.data[strideO * (y+0) + x+7 + offsetO] = dstA_0[1].w;
- O.data[strideO * (y+1) + x+0 + offsetO] = dstA_1[0].x;
- O.data[strideO * (y+1) + x+1 + offsetO] = dstA_1[0].y;
- O.data[strideO * (y+1) + x+2 + offsetO] = dstA_1[0].z;
- O.data[strideO * (y+1) + x+3 + offsetO] = dstA_1[0].w;
- O.data[strideO * (y+1) + x+4 + offsetO] = dstA_1[1].x;
- O.data[strideO * (y+1) + x+5 + offsetO] = dstA_1[1].y;
- O.data[strideO * (y+1) + x+6 + offsetO] = dstA_1[1].z;
- O.data[strideO * (y+1) + x+7 + offsetO] = dstA_1[1].w;
- O.data[strideO * (y+2) + x+0 + offsetO] = dstA_2[0].x;
- O.data[strideO * (y+2) + x+1 + offsetO] = dstA_2[0].y;
- O.data[strideO * (y+2) + x+2 + offsetO] = dstA_2[0].z;
- O.data[strideO * (y+2) + x+3 + offsetO] = dstA_2[0].w;
- O.data[strideO * (y+2) + x+4 + offsetO] = dstA_2[1].x;
- O.data[strideO * (y+2) + x+5 + offsetO] = dstA_2[1].y;
- O.data[strideO * (y+2) + x+6 + offsetO] = dstA_2[1].z;
- O.data[strideO * (y+2) + x+7 + offsetO] = dstA_2[1].w;
- O.data[strideO * (y+3) + x+0 + offsetO] = dstA_3[0].x;
- O.data[strideO * (y+3) + x+1 + offsetO] = dstA_3[0].y;
- O.data[strideO * (y+3) + x+2 + offsetO] = dstA_3[0].z;
- O.data[strideO * (y+3) + x+3 + offsetO] = dstA_3[0].w;
- O.data[strideO * (y+3) + x+4 + offsetO] = dstA_3[1].x;
- O.data[strideO * (y+3) + x+5 + offsetO] = dstA_3[1].y;
- O.data[strideO * (y+3) + x+6 + offsetO] = dstA_3[1].z;
- O.data[strideO * (y+3) + x+7 + offsetO] = dstA_3[1].w;
-
- O.data[strideO * (y+4) + x+0 + offsetO] = dstA_0[2].x;
- O.data[strideO * (y+4) + x+1 + offsetO] = dstA_0[2].y;
- O.data[strideO * (y+4) + x+2 + offsetO] = dstA_0[2].z;
- O.data[strideO * (y+4) + x+3 + offsetO] = dstA_0[2].w;
- O.data[strideO * (y+4) + x+4 + offsetO] = dstA_0[3].x;
- O.data[strideO * (y+4) + x+5 + offsetO] = dstA_0[3].y;
- O.data[strideO * (y+4) + x+6 + offsetO] = dstA_0[3].z;
- O.data[strideO * (y+4) + x+7 + offsetO] = dstA_0[3].w;
- O.data[strideO * (y+5) + x+0 + offsetO] = dstA_1[2].x;
- O.data[strideO * (y+5) + x+1 + offsetO] = dstA_1[2].y;
- O.data[strideO * (y+5) + x+2 + offsetO] = dstA_1[2].z;
- O.data[strideO * (y+5) + x+3 + offsetO] = dstA_1[2].w;
- O.data[strideO * (y+5) + x+4 + offsetO] = dstA_1[3].x;
- O.data[strideO * (y+5) + x+5 + offsetO] = dstA_1[3].y;
- O.data[strideO * (y+5) + x+6 + offsetO] = dstA_1[3].z;
- O.data[strideO * (y+5) + x+7 + offsetO] = dstA_1[3].w;
- O.data[strideO * (y+6) + x+0 + offsetO] = dstA_2[2].x;
- O.data[strideO * (y+6) + x+1 + offsetO] = dstA_2[2].y;
- O.data[strideO * (y+6) + x+2 + offsetO] = dstA_2[2].z;
- O.data[strideO * (y+6) + x+3 + offsetO] = dstA_2[2].w;
- O.data[strideO * (y+6) + x+4 + offsetO] = dstA_2[3].x;
- O.data[strideO * (y+6) + x+5 + offsetO] = dstA_2[3].y;
- O.data[strideO * (y+6) + x+6 + offsetO] = dstA_2[3].z;
- O.data[strideO * (y+6) + x+7 + offsetO] = dstA_2[3].w;
- O.data[strideO * (y+7) + x+0 + offsetO] = dstA_3[2].x;
- O.data[strideO * (y+7) + x+1 + offsetO] = dstA_3[2].y;
- O.data[strideO * (y+7) + x+2 + offsetO] = dstA_3[2].z;
- O.data[strideO * (y+7) + x+3 + offsetO] = dstA_3[2].w;
- O.data[strideO * (y+7) + x+4 + offsetO] = dstA_3[3].x;
- O.data[strideO * (y+7) + x+5 + offsetO] = dstA_3[3].y;
- O.data[strideO * (y+7) + x+6 + offsetO] = dstA_3[3].z;
- O.data[strideO * (y+7) + x+7 + offsetO] = dstA_3[3].w;
-#endif
-
- #undef X_
- #undef W_
-}
-#undef TRANSPOSED_X
-#undef BLOCKED_W
-#undef HARDCODED_DIMS
-#undef BUF_OFFSET
-#undef CACHE_DEPTH
-#elif BLOCK_SIZE == 4
-#define TRANSPOSED_X 0
-#define SHIFTED_X 0
-#define CACHE_DEPTH 4
-groupshared float CACHE_NAME(KERNEL_NAME, BLOCK_SIZE, X)[CACHE_DEPTH*8*BLOCK_SIZE+SHIFTED_X*CACHE_DEPTH];
-groupshared float CACHE_NAME(KERNEL_NAME, BLOCK_SIZE, W)[CACHE_DEPTH*8*BLOCK_SIZE];
-[numthreads(8,8,1)]
-void FUNC_NAME(KERNEL_NAME, BLOCK_SIZE)(uint3 dispatchThreadID : SV_DispatchThreadID, uint3 groupThreadID : SV_GroupThreadID, uint threadIndex : SV_GroupIndex)
-{
- DISPATCH_ARGS(O.flatWidth, O.flatHeight, 1);
- TENSOR_SHARED2_ARGS4(X, W, B, WBK, O);
-
- int x = (int)dispatchThreadID.x * BLOCK_SIZE;
- int y = (int)dispatchThreadID.y * BLOCK_SIZE;
- int tx = (int)groupThreadID.x;
- int ty = (int)groupThreadID.y;
- int bx = ((int)dispatchThreadID.x - (int)groupThreadID.x) * BLOCK_SIZE;
- int by = ((int)dispatchThreadID.y - (int)groupThreadID.y) * BLOCK_SIZE;
- int ti = (int)threadIndex;
- int n = (int)X.GetFlatWidth();
- int strideX = (int)X.GetFlatWidth();
- int strideW = (int)W.GetFlatWidth();
-
- #define X_ CACHE_NAME(KERNEL_NAME, BLOCK_SIZE, X)
- #define W_ CACHE_NAME(KERNEL_NAME, BLOCK_SIZE, W)
-
- //if (x >= (int)O.GetFlatWidth()) return;
- //if (y >= (int)O.GetFlatHeight()) return;
-
- float4 dstA_0, dstA_1, dstA_2, dstA_3;
-
- dstA_0.x = B.Get(x+0);
- dstA_1.x = B.Get(x+0);
- dstA_2.x = B.Get(x+0);
- dstA_3.x = B.Get(x+0);
- dstA_0.y = B.Get(x+1);
- dstA_1.y = B.Get(x+1);
- dstA_2.y = B.Get(x+1);
- dstA_3.y = B.Get(x+1);
- dstA_0.z = B.Get(x+2);
- dstA_1.z = B.Get(x+2);
- dstA_2.z = B.Get(x+2);
- dstA_3.z = B.Get(x+2);
- dstA_0.w = B.Get(x+3);
- dstA_1.w = B.Get(x+3);
- dstA_2.w = B.Get(x+3);
- dstA_3.w = B.Get(x+3);
-
- for (int i = 0; i < n; i += CACHE_DEPTH)
- {
- #if CACHE_DEPTH == 16
- W_[ti ] = W.data[strideW * (i + (ti>>5) + 0) + bx + (ti&31)];
- W_[ti+ 64] = W.data[strideW * (i + (ti>>5) + 2) + bx + (ti&31)];
- W_[ti+128] = W.data[strideW * (i + (ti>>5) + 4) + bx + (ti&31)];
- W_[ti+192] = W.data[strideW * (i + (ti>>5) + 6) + bx + (ti&31)];
- W_[ti+256] = W.data[strideW * (i + (ti>>5) + 8) + bx + (ti&31)];
- W_[ti+320] = W.data[strideW * (i + (ti>>5) +10) + bx + (ti&31)];
- W_[ti+384] = W.data[strideW * (i + (ti>>5) +12) + bx + (ti&31)];
- W_[ti+448] = W.data[strideW * (i + (ti>>5) +14) + bx + (ti&31)];
- #if TRANSPOSED_X == 1
- X_[ti ] = X.data[strideX * (i + (ti>>5) + 0) + by + (ti&31)];
- X_[ti+ 64] = X.data[strideX * (i + (ti>>5) + 2) + by + (ti&31)];
- X_[ti+128] = X.data[strideX * (i + (ti>>5) + 4) + by + (ti&31)];
- X_[ti+192] = X.data[strideX * (i + (ti>>5) + 6) + by + (ti&31)];
- X_[ti+256] = X.data[strideX * (i + (ti>>5) + 8) + by + (ti&31)];
- X_[ti+320] = X.data[strideX * (i + (ti>>5) +10) + by + (ti&31)];
- X_[ti+384] = X.data[strideX * (i + (ti>>5) +12) + by + (ti&31)];
- X_[ti+448] = X.data[strideX * (i + (ti>>5) +14) + by + (ti&31)];
- #elif SHIFTED_X == 1
- /*
- g=ti/16
- j=ti&15
-
- g0 j0123456789ABCDEF
- g1 j0123456789ABCDEF
- g2 j0123456789ABCDEF
- g3 j0123456789ABCDEF
- g0.j0 g1.j0 g2.j0 g3.j0 g0.j1 g1.j1 g2.j1 g3.j1
-
- 16x32 => 32x16
- */
- X_[(ti>>4) + 33*(ti&15) + 0] = X.data[strideX * (by + (ti>>4) + 0) + i + (ti&15) ];
- X_[(ti>>4) + 33*(ti&15) + 4] = X.data[strideX * (by + (ti>>4) + 4) + i + (ti&15) ];
- X_[(ti>>4) + 33*(ti&15) + 8] = X.data[strideX * (by + (ti>>4) + 8) + i + (ti&15) ];
- X_[(ti>>4) + 33*(ti&15) +12] = X.data[strideX * (by + (ti>>4) +12) + i + (ti&15) ];
- X_[(ti>>4) + 33*(ti&15) +16] = X.data[strideX * (by + (ti>>4) +16) + i + (ti&15) ];
- X_[(ti>>4) + 33*(ti&15) +20] = X.data[strideX * (by + (ti>>4) +20) + i + (ti&15) ];
- X_[(ti>>4) + 33*(ti&15) +24] = X.data[strideX * (by + (ti>>4) +24) + i + (ti&15) ];
- X_[(ti>>4) + 33*(ti&15) +28] = X.data[strideX * (by + (ti>>4) +28) + i + (ti&15) ];
- #else
- //X_[ti] = X.Get(by + (ti/16), i + (ti&15));
- X_[ti ] = X.data[strideX * (by + (ti&31)) + i + (ti>>5) + 0];
- X_[ti+ 64] = X.data[strideX * (by + (ti&31)) + i + (ti>>5) + 2];
- X_[ti+128] = X.data[strideX * (by + (ti&31)) + i + (ti>>5) + 4];
- X_[ti+192] = X.data[strideX * (by + (ti&31)) + i + (ti>>5) + 6];
- X_[ti+256] = X.data[strideX * (by + (ti&31)) + i + (ti>>5) + 8];
- X_[ti+320] = X.data[strideX * (by + (ti&31)) + i + (ti>>5) +10];
- X_[ti+384] = X.data[strideX * (by + (ti&31)) + i + (ti>>5) +12];
- X_[ti+448] = X.data[strideX * (by + (ti&31)) + i + (ti>>5) +14];
- #endif
-
- #elif CACHE_DEPTH == 8
- W_[ti ] = W.data[strideW * (i + (ti>>5) + 0) + bx + (ti&31)];
- W_[ti+ 64] = W.data[strideW * (i + (ti>>5) + 2) + bx + (ti&31)];
- W_[ti+128] = W.data[strideW * (i + (ti>>5) + 4) + bx + (ti&31)];
- W_[ti+192] = W.data[strideW * (i + (ti>>5) + 6) + bx + (ti&31)];
- #if TRANSPOSED_X == 1
- X_[ti ] = X.data[strideX * (i + (ti>>5) + 0) + by + (ti&31)];
- X_[ti+ 64] = X.data[strideX * (i + (ti>>5) + 2) + by + (ti&31)];
- X_[ti+128] = X.data[strideX * (i + (ti>>5) + 4) + by + (ti&31)];
- X_[ti+192] = X.data[strideX * (i + (ti>>5) + 6) + by + (ti&31)];
- #elif SHIFTED_X == 1
- // 8x32 => 32x8
- X_[(ti>>3) + 33*(ti&7) + 0] = X.data[strideX * (by + (ti>>3) + 0) + i + (ti&7) ];
- X_[(ti>>3) + 33*(ti&7) + 8] = X.data[strideX * (by + (ti>>3) + 8) + i + (ti&7) ];
- X_[(ti>>3) + 33*(ti&7) +16] = X.data[strideX * (by + (ti>>3) +16) + i + (ti&7) ];
- X_[(ti>>3) + 33*(ti&7) +24] = X.data[strideX * (by + (ti>>3) +24) + i + (ti&7) ];
- #else
- // 8x32 => 32x8
- X_[ti ] = X.data[strideX * (by + (ti&31)) + i + (ti>>5) + 0];
- X_[ti+ 64] = X.data[strideX * (by + (ti&31)) + i + (ti>>5) + 2];
- X_[ti+128] = X.data[strideX * (by + (ti&31)) + i + (ti>>5) + 4];
- X_[ti+192] = X.data[strideX * (by + (ti&31)) + i + (ti>>5) + 6];
- #endif
-
- #elif CACHE_DEPTH == 4
- W_[ti ] = W.data[strideW * (i + (ti>>5) + 0) + bx + (ti&31)];
- W_[ti+ 64] = W.data[strideW * (i + (ti>>5) + 2) + bx + (ti&31)];
- #if TRANSPOSED_X == 1
- X_[ti ] = X.data[strideX * (i + (ti>>5) + 0) + by + (ti&31)];
- X_[ti+ 64] = X.data[strideX * (i + (ti>>5) + 2) + by + (ti&31)];
- #elif SHIFTED_X == 1
- // 4x32 => 32x4
- X_[(ti>>2) + 33*(ti&3) + 0] = X.data[strideX * (by + (ti>>2) + 0) + i + (ti&3) ];
- X_[(ti>>2) + 33*(ti&3) +16] = X.data[strideX * (by + (ti>>2) + 16) + i + (ti&3) ];
- #else
- // 4x32 => 32x4
- X_[ti ] = X.data[strideX * (by + (ti&31)) + i + (ti>>5) + 0];
- X_[ti+ 64] = X.data[strideX * (by + (ti&31)) + i + (ti>>5) + 2];
- #endif
-
- #elif CACHE_DEPTH == 2
- W_[ti ] = W.data[strideW * (i + (ti>>5) + 0) + bx + (ti&31)];
- #if TRANSPOSED_X == 1
- X_[ti ] = X.data[strideX * (i + (ti>>5) + 0) + by + (ti&31)];
- #elif SHIFTED_X == 1
- // 2x32 => 32x2
- X_[(ti>>1) + 33*(ti&1) + 0] = X.data[strideX * (by + (ti>>1) + 0) + i + (ti&1) ];
- #else
- X_[ti ] = X.data[strideX * (by + (ti&31)) + i + (ti>>5) + 0];
- #endif
-
- #elif CACHE_DEPTH == 1
- if (ti < 32)
- {
- W_[ti] = W.data[strideW * i + bx + ti];
- #if TRANSPOSED_X == 1
- X_[ti] = X.data[strideX * i + by + ti];
- #else
- //X_[ti] = X.Get(by+ti, i);
- X_[ti] = X.data[strideX * (by + ti) + i];
- #endif
- }
- #endif
-
- GroupMemoryBarrierWithGroupSync();
-
- for (int di = 0; di < CACHE_DEPTH; di++)
- {
- int _32 = 32 + SHIFTED_X;
- float4 srcX = float4(
- X_[di*_32 + ty*4 + 0],
- X_[di*_32 + ty*4 + 1],
- X_[di*_32 + ty*4 + 2],
- X_[di*_32 + ty*4 + 3]);
- float4 srcW = float4(
- W_[di*32 + tx*4 + 0],
- W_[di*32 + tx*4 + 1],
- W_[di*32 + tx*4 + 2],
- W_[di*32 + tx*4 + 3]);
-
- dstA_0.x = ffma(srcX.x, srcW.x, dstA_0.x);
- dstA_0.y = ffma(srcX.x, srcW.y, dstA_0.y);
- dstA_0.z = ffma(srcX.x, srcW.z, dstA_0.z);
- dstA_0.w = ffma(srcX.x, srcW.w, dstA_0.w);
-
- dstA_1.x = ffma(srcX.y, srcW.x, dstA_1.x);
- dstA_1.y = ffma(srcX.y, srcW.y, dstA_1.y);
- dstA_1.z = ffma(srcX.y, srcW.z, dstA_1.z);
- dstA_1.w = ffma(srcX.y, srcW.w, dstA_1.w);
-
- dstA_2.x = ffma(srcX.z, srcW.x, dstA_2.x);
- dstA_2.y = ffma(srcX.z, srcW.y, dstA_2.y);
- dstA_2.z = ffma(srcX.z, srcW.z, dstA_2.z);
- dstA_2.w = ffma(srcX.z, srcW.w, dstA_2.w);
-
- dstA_3.x = ffma(srcX.w, srcW.x, dstA_3.x);
- dstA_3.y = ffma(srcX.w, srcW.y, dstA_3.y);
- dstA_3.z = ffma(srcX.w, srcW.z, dstA_3.z);
- dstA_3.w = ffma(srcX.w, srcW.w, dstA_3.w);
- }
-
- GroupMemoryBarrierWithGroupSync();
- }
-
- O.Set(y+0, x+0, dstA_0.x);
- O.Set(y+0, x+1, dstA_0.y);
- O.Set(y+0, x+2, dstA_0.z);
- O.Set(y+0, x+3, dstA_0.w);
- O.Set(y+1, x+0, dstA_1.x);
- O.Set(y+1, x+1, dstA_1.y);
- O.Set(y+1, x+2, dstA_1.z);
- O.Set(y+1, x+3, dstA_1.w);
- O.Set(y+2, x+0, dstA_2.x);
- O.Set(y+2, x+1, dstA_2.y);
- O.Set(y+2, x+2, dstA_2.z);
- O.Set(y+2, x+3, dstA_2.w);
- O.Set(y+3, x+0, dstA_3.x);
- O.Set(y+3, x+1, dstA_3.y);
- O.Set(y+3, x+2, dstA_3.z);
- O.Set(y+3, x+3, dstA_3.w);
- /*for (dx = 0; dx < BLOCK_SIZE; ++dx)
- for (dy = 0; dy < BLOCK_SIZE; ++dy)
- O.Set(y+dy, x+dx, dstA[dy][dx]);
- */
- #undef X_
- #undef W_
-}
-#undef TRANSPOSED_X
-#undef SHIFTED_X
-#undef CACHE_DEPTH
-#else
-[numthreads(8,8,1)]
-void FUNC_NAME(KERNEL_NAME, BLOCK_SIZE)(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.flatWidth, O.flatHeight, 1);
- TENSOR_SHARED2_ARGS4(X, W, B, WBK, O);
-
- int x = (int)dispatchThreadID.x * BLOCK_SIZE;
- int y = (int)dispatchThreadID.y * BLOCK_SIZE;
- int n = (int)X.GetFlatWidth();
-
- if (x >= (int)O.GetFlatWidth()) return;
- if (y >= (int)O.GetFlatHeight()) return;
-
- float dstA[BLOCK_SIZE][BLOCK_SIZE];
- float srcX[BLOCK_SIZE];
-
- int dy, dx;
- for (dx = 0; dx < BLOCK_SIZE; ++dx)
- for (dy = 0; dy < BLOCK_SIZE; ++dy)
- dstA[dy][dx] = B.data[x+dx+B.offset];//B.Get(x+dx);
-
- for (int i = 0; i < n; ++i)
- {
- for (dy = 0; dy < BLOCK_SIZE; ++dy)
- srcX[dy] = X.data[(y+dy)*X.channels+i];//X.Get(y+dy, i);
-
- for (dx = 0; dx < BLOCK_SIZE; ++dx)
- {
- float srcW = W.data[i*W.channels+x+dx];//W.Get(i, x+dx);
- for (dy = 0; dy < BLOCK_SIZE; ++dy)
- dstA[dy][dx] += srcX[dy] * srcW;
- }
- }
-
- for (dx = 0; dx < BLOCK_SIZE; ++dx)
- for (dy = 0; dy < BLOCK_SIZE; ++dy)
- O.Set(y+dy, x+dx, dstA[dy][dx]);
-}
-#endif
-#undef KERNEL_NAME
-
-#endif // DENSE
-
-// NOTE: usually this path is used for <16 batches
-#undef CACHESIZE
-#define CACHESIZE 64
-groupshared float Dense_L1Cached64_X[CACHESIZE];
-[numthreads(CACHESIZE, 1, 1)]
-void Dense_L1Cached64(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- DISPATCH_ARGS(O.flatWidth, O.flatHeight, 1);
- TENSOR_SHARED2_ARGS4(X, W, B, WBK, O);
-
- #define X_ Dense_L1Cached64_X
-
- uint x = CACHESIZE * groupID.x + groupThreadID.x;
- uint y = groupID.y;
-
- uint wIndex = W.Index(0, x);
-
- float acc = B.Get(x);
- // loop over X columns (flatWidth) and W rows (height) in CACHESIZE steps
- for (uint i = 0; i < X.GetFlatWidth(); i += CACHESIZE)
- {
- // Cache X
- // coalescent reads
- X_[groupThreadID.x] = X.SafeGet(y, i + groupThreadID.x);
- GroupMemoryBarrierWithGroupSync();
-
- // X * W
- if (i + CACHESIZE <= X.GetFlatWidth())
- {
- [unroll]
- for (uint di = 0; di < CACHESIZE; ++di)
- {
- acc = fastfma(X_[di], W.data[wIndex], acc);
- wIndex += W.GetFlatWidth();
- }
- }
- else
- {
- // handle remainder of the line < CACHESIZE
- for (uint di = 0; i + di < X.GetFlatWidth(); ++di)
- {
- acc = fastfma(X_[di], W.data[wIndex], acc);
- wIndex += W.GetFlatWidth();
- }
- }
-
- GroupMemoryBarrierWithGroupSync();
- }
-
- // needed all threads to load matrix line, x might be out of the bounds for writing
- if (x < O.GetFlatWidth())
- O.Set(y, x, acc);
-
- #undef X_
-}
-
-
-#undef TILE_WIDTH
-#define TILE_WIDTH NUMTHREAD(16,8,8)
-groupshared float DenseTiled_Xcache[TILE_WIDTH][TILE_WIDTH];
-groupshared float DenseTiled_Wcache[TILE_WIDTH][TILE_WIDTH];
-[numthreads(TILE_WIDTH,TILE_WIDTH,1)]
-void DenseTiled16x16(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- DISPATCH_ARGS(O.flatWidth, O.flatHeight, 1);
- TENSOR_SHARED2_ARGS4(X, W, B, WBK, O);
-
- #define X_ DenseTiled_Xcache
- #define W_ DenseTiled_Wcache
-
- uint tx = groupThreadID.x;
- uint ty = groupThreadID.y;
- uint x = groupID.x*TILE_WIDTH + tx;
- uint y = groupID.y*TILE_WIDTH + ty;
-
- bool mask = (x < O.GetFlatWidth() && y < O.GetFlatHeight());
-
- float v = B.Get(x);
- for (uint m = 0; m < X.GetFlatWidth()/TILE_WIDTH; ++m)
- {
- if (mask)
- {
- X_[ty][tx] = X.Get(y, m*TILE_WIDTH + tx);
- W_[ty][tx] = W.Get(m*TILE_WIDTH + ty, x);
- }
- else
- {
- X_[ty][tx] = 0;
- W_[ty][tx] = 0;
- }
-
- GroupMemoryBarrierWithGroupSync();
-
- [unroll]
- for (uint i = 0; i < TILE_WIDTH; ++i)
- {
- v = fastfma(X_[ty][i], W_[i][tx], v);
- }
-
- GroupMemoryBarrierWithGroupSync();
- }
-
- if (mask)
- O.Set(y, x, v);
-
- #undef X_
- #undef W_
-}
-
-#undef TILE_WIDTH
-#define TILE_WIDTH NUMTHREAD(16,8,8) // 32 crashes on MacBookPro/AMD
-groupshared float DenseTiled_Xcache32[2*2][TILE_WIDTH][TILE_WIDTH];
-groupshared float DenseTiled_Wcache32[2*2][TILE_WIDTH][TILE_WIDTH];
-[numthreads(TILE_WIDTH,TILE_WIDTH,1)]
-void DenseTiled32x32(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- DISPATCH_ARGS(O.flatWidth / 2, O.flatHeight / 2, 1);
- TENSOR_SHARED2_ARGS4(X, W, B, WBK, O);
-
- #define X_ DenseTiled_Xcache32
- #define W_ DenseTiled_Wcache32
-
- uint tx = groupThreadID.x;
- uint ty = groupThreadID.y;
- uint x = groupID.x*TILE_WIDTH + tx;
- uint y = groupID.y*TILE_WIDTH + ty;
-
- float b0 = B.Get(x*2+0);
- float b1 = B.Get(x*2+1);
- float4 v = float4(b0, b1,
- b0, b1);
-
- for (uint m = 0; m < X.GetFlatWidth()/(TILE_WIDTH*2);)
- {
- float x0 = X.Get(y*2+0, m*TILE_WIDTH*2 + tx*2+0);
- float x1 = X.Get(y*2+0, m*TILE_WIDTH*2 + tx*2+1);
- float x2 = X.Get(y*2+1, m*TILE_WIDTH*2 + tx*2+0);
- float x3 = X.Get(y*2+1, m*TILE_WIDTH*2 + tx*2+1);
-
- float w0 = W.Get(m*TILE_WIDTH*2 + ty*2+0, x*2+0);
- float w1 = W.Get(m*TILE_WIDTH*2 + ty*2+0, x*2+1);
- float w2 = W.Get(m*TILE_WIDTH*2 + ty*2+1, x*2+0);
- float w3 = W.Get(m*TILE_WIDTH*2 + ty*2+1, x*2+1);
-
- ++m;
-
- X_[0][ty][tx] = x0;
- X_[1][ty][tx] = x1;
- X_[2][ty][tx] = x2;
- X_[3][ty][tx] = x3;
-
- W_[0][ty][tx] = w0;
- W_[1][ty][tx] = w1;
- W_[2][ty][tx] = w2;
- W_[3][ty][tx] = w3;
-
- GroupMemoryBarrierWithGroupSync();
-
- [unroll]
- for (uint i = 0; i < TILE_WIDTH; ++i)
- {
- float4 x =
- float4( X_[0][ty][i],
- X_[1][ty][i],
- X_[2][ty][i],
- X_[3][ty][i]);
- float4 w =
- float4( W_[0][i][tx],
- W_[1][i][tx],
- W_[2][i][tx],
- W_[3][i][tx]);
-
- v.x = fastfma(w.x, x.x, v.x);
- v.y = fastfma(w.y, x.x, v.y);
- v.z = fastfma(w.x, x.z, v.z);
- v.w = fastfma(w.y, x.z, v.w);
-
- v.x = fastfma(w.z, x.y, v.x);
- v.y = fastfma(w.w, x.y, v.y);
- v.z = fastfma(w.z, x.w, v.z);
- v.w = fastfma(w.w, x.w, v.w);
- }
-
- GroupMemoryBarrierWithGroupSync();
- }
-
- O.Set(y*2+0, x*2+0, v.x);
- O.Set(y*2+0, x*2+1, v.y);
- O.Set(y*2+1, x*2+0, v.z);
- O.Set(y*2+1, x*2+1, v.w);
-
- #undef X_
- #undef W_
-}
-
-#undef TILE_WIDTH
-#define TILE_WIDTH NUMTHREAD(16,8,8)
-groupshared float DenseTiled_Xcache64[4*4][TILE_WIDTH*TILE_WIDTH];
-groupshared float DenseTiled_Wcache64[4*4][TILE_WIDTH*TILE_WIDTH];
-[numthreads(TILE_WIDTH,TILE_WIDTH,1)]
-void DenseTiled64x64(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- DISPATCH_ARGS(O.flatWidth / 4, O.flatHeight / 4, 1);
- TENSOR_SHARED2_ARGS4(X, W, B, WBK, O);
-
- #define X_ DenseTiled_Xcache64
- #define W_ DenseTiled_Wcache64
-
- uint tx = groupThreadID.x;
- uint ty = groupThreadID.y;
- uint x = groupID.x*TILE_WIDTH + tx;
- uint y = groupID.y*TILE_WIDTH + ty;
-
- float b0 = B.Get(x*4+0);
- float b1 = B.Get(x*4+1);
- float b2 = B.Get(x*4+2);
- float b3 = B.Get(x*4+3);
-
- float4 v0, v1, v2, v3;
- v0 = v1 = v2 = v3 = float4(b0, b1, b2, b3);
-
- for (uint m = 0; m < X.GetFlatWidth()/(TILE_WIDTH*4); ++m)
- {
- for (uint yy = 0; yy < 4; ++yy)
- for (uint xx = 0; xx < 4; ++xx)
- {
- X_[yy*4+xx][ty*TILE_WIDTH+tx] = X.Get(y*4+yy, (m*TILE_WIDTH + tx)*4+xx);
- W_[yy*4+xx][ty*TILE_WIDTH+tx] = W.Get((m*TILE_WIDTH + ty)*4+yy, x*4+xx);
- }
-
- GroupMemoryBarrierWithGroupSync();
-
- for (uint i = 0; i < TILE_WIDTH; ++i)
- {
- [unroll]
- for (uint q = 0; q < 4; ++q)
- {
- float x0 = X_[0*4+q][ty*TILE_WIDTH+i];
- float x1 = X_[1*4+q][ty*TILE_WIDTH+i];
- float x2 = X_[2*4+q][ty*TILE_WIDTH+i];
- float x3 = X_[3*4+q][ty*TILE_WIDTH+i];
-
- float w0 = W_[q*4+0][i*TILE_WIDTH+tx];
- float w1 = W_[q*4+1][i*TILE_WIDTH+tx];
- float w2 = W_[q*4+2][i*TILE_WIDTH+tx];
- float w3 = W_[q*4+3][i*TILE_WIDTH+tx];
-
- v0.x = fastfma(x0, w0, v0.x); //--
- v1.x = fastfma(x1, w0, v1.x);
- v2.x = fastfma(x2, w0, v2.x);
- v3.x = fastfma(x3, w0, v3.x);
- v0.y = fastfma(x0, w1, v0.y); //--
- v1.y = fastfma(x1, w1, v1.y);
- v2.y = fastfma(x2, w1, v2.y);
- v3.y = fastfma(x3, w1, v3.y);
- v0.z = fastfma(x0, w2, v0.z); //--
- v1.z = fastfma(x1, w2, v1.z);
- v2.z = fastfma(x2, w2, v2.z);
- v3.z = fastfma(x3, w2, v3.z);
- v0.w = fastfma(x0, w3, v0.w); //--
- v1.w = fastfma(x1, w3, v1.w);
- v2.w = fastfma(x2, w3, v2.w);
- v3.w = fastfma(x3, w3, v3.w);
- }
-
- GroupMemoryBarrierWithGroupSync();
- }
- }
-
- O.Set(y*4+0, x*4+0, v0.x);
- O.Set(y*4+0, x*4+1, v0.y);
- O.Set(y*4+0, x*4+2, v0.z);
- O.Set(y*4+0, x*4+3, v0.w);
-
- O.Set(y*4+1, x*4+0, v1.x);
- O.Set(y*4+1, x*4+1, v1.y);
- O.Set(y*4+1, x*4+2, v1.z);
- O.Set(y*4+1, x*4+3, v1.w);
-
- O.Set(y*4+2, x*4+0, v2.x);
- O.Set(y*4+2, x*4+1, v2.y);
- O.Set(y*4+2, x*4+2, v2.z);
- O.Set(y*4+2, x*4+3, v2.w);
-
- O.Set(y*4+3, x*4+0, v3.x);
- O.Set(y*4+3, x*4+1, v3.y);
- O.Set(y*4+3, x*4+2, v3.z);
- O.Set(y*4+3, x*4+3, v3.w);
-
- #undef X_
- #undef W_
-}
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/Dense.compute.meta b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/Dense.compute.meta
deleted file mode 100644
index 33ad83caf1..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/Dense.compute.meta
+++ /dev/null
@@ -1,9 +0,0 @@
-fileFormatVersion: 2
-guid: 6b08c0ac202ad41deb8881132b21894c
-timeCreated: 1507457322
-licenseType: Pro
-ComputeShaderImporter:
- currentAPIMask: 196608
- userData:
- assetBundleName:
- assetBundleVariant:
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/DenseFP16.compute b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/DenseFP16.compute
deleted file mode 100644
index 7f9f763144..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/DenseFP16.compute
+++ /dev/null
@@ -1,72 +0,0 @@
-#pragma kernel DenseFP16Div2
-
-#include "Tensor.cginc"
-
-TENSOR_DECL(X)
-TENSOR_DECL(W)
-TENSOR_DECL(B)
-TENSOR_DECL(WBK)
-TENSOR_DECL_RW(O)
-
-float f16tof32_(uint src)
-{
- // Based on Fabian Giesen's public domain half_to_float_fast3
- const uint magic = 113 << 23;
- const uint shiftedExp = 0x7c00 << 13; // exponent mask after shift
-
- // Mask out sign bit
- uint o = src & 0x7fff;
- if (o)
- {
- // Move exponent + mantissa to correct bits
- o <<= 13;
- uint exponent = o & shiftedExp;
- if (exponent == 0)
- {
- // Handle denormal
- o = asuint(asfloat(o + magic) - asfloat(magic));
- }
- else if (exponent == shiftedExp) // Inf/NaN
- o += (255 - 31) << 23;
- else
- o += (127 - 15) << 23;
- }
-
- // Copy sign bit
- o |= (src & 0x8000) << 16;
-
- return asfloat(o);
-}
-
-float2 Unpack(SharedTensor t, uint y, uint x)
-{
- uint v = asuint(t.data[t.Index(y, x) >> 1]);
- // TEMPORARY: f16tof32 is broken in GLSL/Metal compiler
- // using custom conversion function for now
- //return float2(f16tof32(v), f16tof32(v>>16));
- return float2(f16tof32_(v), f16tof32_(v>>16));
-}
-
-// NOTE: usually this path is used for <16 batches
-NUMTHREADS((256,1,1), (128,1,1), (64,1,1))
-void DenseFP16Div2(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.flatWidth/2, O.flatHeight, 1);
- TENSOR_SHARED2_ARGS4(X, W, B, WBK, O);
-
- uint x = dispatchThreadID.x;
- uint y = dispatchThreadID.y;
-
- if (x*2 >= O.GetFlatWidth()) return;
- if (y >= O.GetFlatHeight()) return;
-
- float2 acc = Unpack(B, 0, x*2);
- for (uint i = 0; i < X.width; ++i)
- {
- float2 w = Unpack(W, i, x*2);
- acc += X.Get(y, i) * w;
- }
-
- O.Set(y, x*2+0, acc[0]);
- O.Set(y, x*2+1, acc[1]);
-}
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/DenseFP16.compute.meta b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/DenseFP16.compute.meta
deleted file mode 100644
index f0111a6226..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/DenseFP16.compute.meta
+++ /dev/null
@@ -1,9 +0,0 @@
-fileFormatVersion: 2
-guid: cff3cb66e54744fa4888ef91a11ec90c
-timeCreated: 1508334838
-licenseType: Pro
-ComputeShaderImporter:
- currentAPIMask: 196608
- userData:
- assetBundleName:
- assetBundleVariant:
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/Experimental.compute b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/Experimental.compute
deleted file mode 100644
index 76856062aa..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/Experimental.compute
+++ /dev/null
@@ -1,4284 +0,0 @@
-#if EXPERIMENTAL_KERNELS_ENABLED
-/*
-#pragma kernel Dense
-#pragma kernel DenseTiled
-#pragma kernel Dense10x16
-#pragma kernel DenseTiled32x32
-#pragma kernel DenseTiled64x64
-#pragma kernel Dense64
-#pragma kernel Relu
-#pragma kernel Relu256xV
-#pragma kernel Relu16x16
-#pragma kernel ReluChannelsFirst16x2x16
-#pragma kernel Relu_Cmod16_CNyx
-#pragma kernel Relu_Nyxc
-#pragma kernel Softmax
-#pragma kernel Softmax256x2
-#pragma kernel MaxPooling2D
-#pragma kernel MaxPooling2D16x4x4
-*/
-/*
-#pragma kernel Conv2D_Kernel3x3_32Channel
-#pragma kernel Conv2D_Kernel3x3_1Channel
-#pragma kernel Conv2D
-//#pragma kernel Conv2DTiled16x16_Kernel3x3
-#pragma kernel Conv2DTiled14x14_Kernel3x3
-#pragma kernel Conv2DTiled13x13_Kernel3x3
-//#pragma kernel Conv2DTiled12x12_Kernel3x3
-#pragma kernel Fill
-
-#pragma kernel Conv2D_Kernel3x3_Kmod16_Cmod4_KN
-#pragma kernel Conv2D_Kernel3x3_Kmod16_Cmod4_KNyx
-//#pragma kernel Conv2D_Kernel3x3_Cache_KCmod32_KNyx
-//#pragma kernel Conv2D_Kernel3x3_Cache_KCmod64_KNyx
-*/
-
-
-// @TODO: BIAS and WEIGHTS have changed format
-// BIAS (0,0,x,0) -> (0,0,0,x) --> (x)
-// WEIGHTS (y,0,x,0) -> (y,0,0,x) --> (y,x)
-// DENSE_OUT (y,0,x,0) -> (y,0,0,x) --> (y,x)
-
-
-//#pragma kernel Conv2D_Kmod16_Nmod8_KNY
-//#pragma kernel Conv2D_Kernel3x3_64
-
-#define BOUNDS_CHECKS 0
-
-RWStructuredBuffer<int> Edata;
-
-struct Tensor
-{
- uint batch, height, width, channels;
- uint offset;
- uint dataLength;
-
- uint Index(uint b, uint h, uint w, uint ch)
- {
- uint index =
- b * height * width * channels +
- h * width * channels +
- w * channels +
- ch;
- return index + offset;
- }
-    void Set(uint b, uint h, uint w, uint ch, float v, RWStructuredBuffer<float> data)
- {
- data[Index(b,h,w,ch)] = v;
- }
-    void Set(int b, uint h, uint w, uint ch, float v, RWStructuredBuffer<float> data, int dataLength)
- {
- uint index = Index(b,h,w,ch);
- #if BOUNDS_CHECKS
- if (index < 0 || index >= dataLength)
- {
- InterlockedAdd(Edata[1], 1);
- return;
- }
- #endif
-
- data[Index(b,h,w,ch)] = v;
- }
-
-    float Get(uint b, uint h, uint w, uint ch, StructuredBuffer<float> data)
- {
- return data[Index(b,h,w,ch)];
- }
-    float Get(uint b, uint h, uint w, uint ch, StructuredBuffer<float> data, int dataLength)
- {
- int index = Index(b,h,w,ch);
- #if BOUNDS_CHECKS
- if (index < 0 || index >= dataLength)
- {
- InterlockedAdd(Edata[0], 1);
- return 0.0f;
- }
- #endif
-
- return data[Index(b,h,w,ch)];
- }
-};
-
-#define X ((Tensor)Xdecl)
-int4 Xdecl[2];
-StructuredBuffer<float> Xdata;
-
-#define O ((Tensor)Odecl)
-int4 Odecl[2];
-RWStructuredBuffer<float> Odata;
-
-#define W ((Tensor)Wdecl)
-int4 Wdecl[2];
-
-#define B ((Tensor)Bdecl)
-int4 Bdecl[2];
-
-#define K ((Tensor)Kdecl)
-int4 Kdecl[2];
-
-#define WBK ((Tensor)WBKdecl)
-int4 WBKdecl[2];
-StructuredBuffer<float> WBKdata;
-
-uint _FilterSize;
-uint _Border;
-uint _Offset;
-
-[numthreads(1,1,1)]
-void Dense(uint3 groupID : SV_GroupID)
-{
- uint b = groupID.y;
- uint x = groupID.x;
- float v = B.Get(0, 0, x, 0, WBKdata, WBK.dataLength);
- for (uint i = 0; i < X.width; ++i)
- v += X.Get(b, 0, i, 0, Xdata) * W.Get(0, i, x, 0, WBKdata, WBK.dataLength);
-
- O.Set(b, 0, x, 0, v, Odata, O.dataLength);
-}
-
-[numthreads(10,16,1)]
-void Dense10x16(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- uint x = 10*groupID.x + groupThreadID.x;
- uint b = 16*groupID.y + groupThreadID.y;
- float v = B.Get(0, 0, x, 0, WBKdata, WBK.dataLength);
-
- for (uint i = 0; i < X.width;)
- {
- // can unroll up to 16 because numthreads.y=16
- v += X.Get(b, 0, i, 0, Xdata) * W.Get(0, i, x, 0, WBKdata, WBK.dataLength); ++i;
- v += X.Get(b, 0, i, 0, Xdata) * W.Get(0, i, x, 0, WBKdata, WBK.dataLength); ++i;
- v += X.Get(b, 0, i, 0, Xdata) * W.Get(0, i, x, 0, WBKdata, WBK.dataLength); ++i;
- v += X.Get(b, 0, i, 0, Xdata) * W.Get(0, i, x, 0, WBKdata, WBK.dataLength); ++i;
-
- v += X.Get(b, 0, i, 0, Xdata) * W.Get(0, i, x, 0, WBKdata, WBK.dataLength); ++i;
- v += X.Get(b, 0, i, 0, Xdata) * W.Get(0, i, x, 0, WBKdata, WBK.dataLength); ++i;
- v += X.Get(b, 0, i, 0, Xdata) * W.Get(0, i, x, 0, WBKdata, WBK.dataLength); ++i;
- v += X.Get(b, 0, i, 0, Xdata) * W.Get(0, i, x, 0, WBKdata, WBK.dataLength); ++i;
- }
- O.Set(b, 0, x, 0, v, Odata);
-}
-
-
-#undef THREAD_COUNT
-#define THREAD_COUNT 64 // ATM support only 8x8
-
-#undef BLOCK_WIDTH
-#define BLOCK_WIDTH 8
-
-#undef LOAD_WIDTH
-#define LOAD_WIDTH THREAD_COUNT
-
-#undef LOAD_DEPTH
-#define LOAD_DEPTH BLOCK_WIDTH
-
-groupshared float Conv_KcacheR[LOAD_DEPTH][LOAD_WIDTH];
-groupshared float Conv_XcacheR[LOAD_DEPTH][LOAD_WIDTH];
-[numthreads(THREAD_COUNT, 1, 1)]
-void Conv2D_Kernel3x3_64(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- #define X_ Conv_XcacheR
- #define K_ Conv_KcacheR
-
- uint id = groupThreadID.x;
- uint bx = groupID.x;
- uint by = groupID.y;
-
- uint bbx = id % BLOCK_WIDTH;
- uint bby = id / BLOCK_WIDTH;
-
- uint width = O.width;
- uint height = O.height;
-
- // ASSERT(LOAD_WIDTH == THREAD_COUNT)
- uint loadNYX = by*LOAD_WIDTH + id; // only works for 8x8
- uint loadX = loadNYX % width;
- uint loadNY = loadNYX / width;
- uint loadY = loadNY % height;
- uint loadN = loadNY / height;
-
- float v[BLOCK_WIDTH][BLOCK_WIDTH];
- for (uint yy = 0; yy < BLOCK_WIDTH; ++yy)
- for (uint xx = 0; xx < BLOCK_WIDTH; ++xx)
- {
- float bias = B.Get(0, 0, bx*LOAD_WIDTH + bbx*BLOCK_WIDTH + xx, 0, WBKdata, WBK.dataLength);
- v[yy][xx] = bias;
- }
-
- for (uint dy = 0; dy < 3; ++dy)
- {
- bool mask = true;
-
- if (loadY+dy < _Offset) mask = false;
- if (loadY+dy-_Offset >= X.height) mask = false;
-
- for (uint dx = 0; dx < 3; ++dx)
- {
- if (loadX+dx < _Offset) mask = false;
- if (loadX+dx-_Offset >= X.width) mask = false;
-
- for (uint m = 0; m < X.channels/LOAD_DEPTH; ++m)
- {
- for (uint q = 0; q < LOAD_DEPTH; ++q)
- {
- if (mask)
- X_[q][id] = X.Get(loadN, loadY+dy-_Offset, loadX+dx-_Offset, m*LOAD_DEPTH + q, Xdata);
- else
- X_[q][id] = 0;
- K_[q][id] = K.Get(dy, dx, m*LOAD_DEPTH + q, bx*LOAD_WIDTH + id, WBKdata, WBK.dataLength);
- }
-
- GroupMemoryBarrierWithGroupSync();
-
- for (uint yyy = 0; yyy < BLOCK_WIDTH; ++yyy)
- [unroll] for (uint xxx = 0; xxx < BLOCK_WIDTH; ++xxx)
- [unroll] for (uint i = 0; i < LOAD_DEPTH; ++i)
- {
- v[yyy][xxx] += X_[i][bby*BLOCK_WIDTH + yyy] * K_[i][bbx*BLOCK_WIDTH + xxx];
- }
-
- GroupMemoryBarrierWithGroupSync();
- }
- }
- }
-
- for (uint yyy = 0; yyy < BLOCK_WIDTH; ++yyy)
- for (uint xxx = 0; xxx < BLOCK_WIDTH; ++xxx)
- {
- //O.Set(by*LOAD_WIDTH + bby*BLOCK_WIDTH + yyy, y, x, bx*LOAD_WIDTH + bbx*BLOCK_WIDTH + xxx, v[yyy][xxx], Odata);
- uint saveNYX = by*LOAD_WIDTH + bby*BLOCK_WIDTH + yyy;
- //uint saveNYX = by*LOAD_WIDTH + ((id>>3)<<3) + yyy;
- uint saveX = saveNYX % width;
- uint saveNY = saveNYX / width;
- uint saveY = saveNY % height;
- uint saveN = saveNY / height;
-
- uint saveK = bx*LOAD_WIDTH + bbx*BLOCK_WIDTH + xxx;
- O.Set(saveN, saveY, saveX, saveK, v[yyy][xxx], Odata);
- }
-
- #undef X_
- #undef K_
-}
-
-
-#undef THREAD_COUNT
-#define THREAD_COUNT 64 // ATM support only 8x8
-
-#undef BLOCK_WIDTH
-#define BLOCK_WIDTH 8
-
-#undef LOAD_WIDTH
-#define LOAD_WIDTH THREAD_COUNT
-
-#undef LOAD_DEPTH
-#define LOAD_DEPTH BLOCK_WIDTH
-
-#if 1
-
-groupshared float DenseTiled_XcacheR[32][LOAD_WIDTH];
-groupshared float DenseTiled_WcacheR[LOAD_DEPTH][LOAD_WIDTH];
-
-[numthreads(THREAD_COUNT, 1, 1)]
-void Dense64(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- #define X_ DenseTiled_XcacheR
- #define W_ DenseTiled_WcacheR
-
- uint id = groupThreadID.x;
- uint bx = groupID.x;
- uint by = groupID.y;
-
- uint bbx = id % BLOCK_WIDTH;
- uint bby = id / BLOCK_WIDTH;
-
- float v[BLOCK_WIDTH][BLOCK_WIDTH];
- for (uint yy = 0; yy < BLOCK_WIDTH; ++yy)
- for (uint xx = 0; xx < BLOCK_WIDTH; ++xx)
- {
- float bias = B.Get(0, 0, bx*LOAD_WIDTH + bbx*BLOCK_WIDTH + xx, 0, WBKdata, WBK.dataLength);
- v[yy][xx] = bias;
- }
-
- for (uint m = 0; m < X.width/LOAD_DEPTH; ++m)
- {
- for (uint q = 0; q < LOAD_DEPTH; ++q)
- {
- X_[q][id] = X.Get(by*LOAD_WIDTH + id, 0, m*LOAD_DEPTH + q, 0, Xdata);
- W_[q][id] = W.Get(0, m*LOAD_DEPTH + q, bx*LOAD_WIDTH + id, 0, WBKdata, WBK.dataLength);
- }
-
- GroupMemoryBarrierWithGroupSync();
-
- for (uint yyy = 0; yyy < BLOCK_WIDTH; ++yyy)
- {
- X_[yyy][id] = X.Get(by*LOAD_WIDTH + id, 0, m*LOAD_DEPTH + yyy, 0, Xdata);
- [unroll] for (uint xxx = 0; xxx < BLOCK_WIDTH; ++xxx)
- [unroll] for (uint i = 0; i < LOAD_DEPTH; ++i)
- {
- v[yyy][xxx] += X_[i][bby*BLOCK_WIDTH + yyy] * W_[i][bbx*BLOCK_WIDTH + xxx];
- }
- }
-
- GroupMemoryBarrierWithGroupSync();
- }
-
- for (uint yyy = 0; yyy < BLOCK_WIDTH; ++yyy)
- for (uint xxx = 0; xxx < BLOCK_WIDTH; ++xxx)
- O.Set(by*LOAD_WIDTH + bby*BLOCK_WIDTH + yyy, 0, bx*LOAD_WIDTH + bbx*BLOCK_WIDTH + xxx, 0, v[yyy][xxx], Odata);
-
- #undef X_
- #undef W_
-}
-
-#elif 1
-groupshared float DenseTiled_XcacheR[LOAD_DEPTH][LOAD_WIDTH];
-groupshared float DenseTiled_WcacheR[LOAD_DEPTH][LOAD_WIDTH];
-
-[numthreads(THREAD_COUNT, 1, 1)]
-void Dense64(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- #define X_ DenseTiled_XcacheR
- #define W_ DenseTiled_WcacheR
-
- uint id = groupThreadID.x;
- uint bx = groupID.x;
- uint by = groupID.y;
-
- uint bbx = id % BLOCK_WIDTH;
- uint bby = id / BLOCK_WIDTH;
-
- float v[BLOCK_WIDTH][BLOCK_WIDTH];
- for (uint yy = 0; yy < BLOCK_WIDTH; ++yy)
- for (uint xx = 0; xx < BLOCK_WIDTH; ++xx)
- {
- float bias = B.Get(0, 0, bx*LOAD_WIDTH + bbx*BLOCK_WIDTH + xx, 0, WBKdata, WBK.dataLength);
- v[yy][xx] = bias;
- }
-
- for (uint m = 0; m < X.width/LOAD_DEPTH; ++m)
- {
- for (uint q = 0; q < LOAD_DEPTH; ++q)
- {
- X_[q][id] = X.Get(by*LOAD_WIDTH + id, 0, m*LOAD_DEPTH + q, 0, Xdata);
- W_[q][id] = W.Get(0, m*LOAD_DEPTH + q, bx*LOAD_WIDTH + id, 0, WBKdata, WBK.dataLength);
- }
-
- GroupMemoryBarrierWithGroupSync();
-
- for (uint yyy = 0; yyy < BLOCK_WIDTH; ++yyy)
- [unroll] for (uint xxx = 0; xxx < BLOCK_WIDTH; ++xxx)
- [unroll] for (uint i = 0; i < LOAD_DEPTH; ++i)
- {
- //v[yyy][xxx] += X_[i][bby*BLOCK_WIDTH + yyy] * W_[i][bbx*BLOCK_WIDTH + xxx];
- v[yyy][xxx] = mad(X_[i][bby*BLOCK_WIDTH + yyy], W_[i][bbx*BLOCK_WIDTH + xxx], v[yyy][xxx]);
- }
-
- GroupMemoryBarrierWithGroupSync();
- }
-
- for (uint yyy = 0; yyy < BLOCK_WIDTH; ++yyy)
- for (uint xxx = 0; xxx < BLOCK_WIDTH; ++xxx)
- O.Set(by*LOAD_WIDTH + bby*BLOCK_WIDTH + yyy, 0, bx*LOAD_WIDTH + bbx*BLOCK_WIDTH + xxx, 0, v[yyy][xxx], Odata);
-
- #undef X_
- #undef W_
-}
-
-#elif 1
-
-// unroll array to help some "naive" compilers to map to regs
-// could be easier to lay out zigzagging patterns
-groupshared float DenseTiled_XcacheR[LOAD_DEPTH][LOAD_WIDTH];
-groupshared float DenseTiled_WcacheR[LOAD_DEPTH][LOAD_WIDTH];
-
-[numthreads(THREAD_COUNT, 1, 1)]
-void Dense64(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- #define X_ DenseTiled_XcacheR
- #define W_ DenseTiled_WcacheR
-
- uint id = groupThreadID.x;
- uint bx = groupID.x;
- uint by = groupID.y;
-
- uint bbx = id % BLOCK_WIDTH;
- uint bby = id / BLOCK_WIDTH;
-
- //float v[BLOCK_WIDTH][BLOCK_WIDTH];
- float
- v00, v01, v02, v03, v04, v05, v06, v07,
- v10, v11, v12, v13, v14, v15, v16, v17,
- v20, v21, v22, v23, v24, v25, v26, v27,
- v30, v31, v32, v33, v34, v35, v36, v37,
- v40, v41, v42, v43, v44, v45, v46, v47,
- v50, v51, v52, v53, v54, v55, v56, v57,
- v60, v61, v62, v63, v64, v65, v66, v67,
- v70, v71, v72, v73, v74, v75, v76, v77;
-
- float b0 = B.Get(0, 0, bx*LOAD_WIDTH + bbx*BLOCK_WIDTH + 0, 0, WBKdata, WBK.dataLength);
- float b1 = B.Get(0, 0, bx*LOAD_WIDTH + bbx*BLOCK_WIDTH + 1, 0, WBKdata, WBK.dataLength);
- float b2 = B.Get(0, 0, bx*LOAD_WIDTH + bbx*BLOCK_WIDTH + 2, 0, WBKdata, WBK.dataLength);
- float b3 = B.Get(0, 0, bx*LOAD_WIDTH + bbx*BLOCK_WIDTH + 3, 0, WBKdata, WBK.dataLength);
- float b4 = B.Get(0, 0, bx*LOAD_WIDTH + bbx*BLOCK_WIDTH + 4, 0, WBKdata, WBK.dataLength);
- float b5 = B.Get(0, 0, bx*LOAD_WIDTH + bbx*BLOCK_WIDTH + 5, 0, WBKdata, WBK.dataLength);
- float b6 = B.Get(0, 0, bx*LOAD_WIDTH + bbx*BLOCK_WIDTH + 6, 0, WBKdata, WBK.dataLength);
- float b7 = B.Get(0, 0, bx*LOAD_WIDTH + bbx*BLOCK_WIDTH + 7, 0, WBKdata, WBK.dataLength);
-
- #define L_(y, x) v##y##x = b##x
- L_(0,0); L_(0,1); L_(0,2); L_(0,3); L_(0,4); L_(0,5); L_(0,6); L_(0,7);
- L_(1,0); L_(1,1); L_(1,2); L_(1,3); L_(1,4); L_(1,5); L_(1,6); L_(1,7);
- L_(2,0); L_(2,1); L_(2,2); L_(2,3); L_(2,4); L_(2,5); L_(2,6); L_(2,7);
- L_(3,0); L_(3,1); L_(3,2); L_(3,3); L_(3,4); L_(3,5); L_(3,6); L_(3,7);
- L_(4,0); L_(4,1); L_(4,2); L_(4,3); L_(4,4); L_(4,5); L_(4,6); L_(4,7);
- L_(5,0); L_(5,1); L_(5,2); L_(5,3); L_(5,4); L_(5,5); L_(5,6); L_(5,7);
- L_(6,0); L_(6,1); L_(6,2); L_(6,3); L_(6,4); L_(6,5); L_(6,6); L_(6,7);
- L_(7,0); L_(7,1); L_(7,2); L_(7,3); L_(7,4); L_(7,5); L_(7,6); L_(7,7);
- #undef L_
-
- for (uint m = 0; m < X.width/LOAD_DEPTH; ++m)
- {
- for (uint q = 0; q < LOAD_DEPTH; ++q)
- {
- X_[q][id] = X.Get(by*LOAD_WIDTH + id, 0, m*LOAD_DEPTH + q, 0, Xdata);
- W_[q][id] = W.Get(0, m*LOAD_DEPTH + q, bx*LOAD_WIDTH + id, 0, WBKdata, WBK.dataLength);
- }
-
- GroupMemoryBarrierWithGroupSync();
-
- [unroll] for (uint i = 0; i < LOAD_DEPTH; ++i)
- {
- //v[yyy][xxx] += X_[i][bby*BLOCK_WIDTH + yyy] * W_[i][bbx*BLOCK_WIDTH + xxx];
- #define XW_(y, x) v##y##x += X_[i][bby*BLOCK_WIDTH + ##y] * W_[i][bbx*BLOCK_WIDTH + ##x]
- XW_(0,0); XW_(0,1); XW_(0,2); XW_(0,3); XW_(0,4); XW_(0,5); XW_(0,6); XW_(0,7);
- XW_(1,0); XW_(1,1); XW_(1,2); XW_(1,3); XW_(1,4); XW_(1,5); XW_(1,6); XW_(1,7);
- XW_(2,0); XW_(2,1); XW_(2,2); XW_(2,3); XW_(2,4); XW_(2,5); XW_(2,6); XW_(2,7);
- XW_(3,0); XW_(3,1); XW_(3,2); XW_(3,3); XW_(3,4); XW_(3,5); XW_(3,6); XW_(3,7);
- XW_(4,0); XW_(4,1); XW_(4,2); XW_(4,3); XW_(4,4); XW_(4,5); XW_(4,6); XW_(4,7);
- XW_(5,0); XW_(5,1); XW_(5,2); XW_(5,3); XW_(5,4); XW_(5,5); XW_(5,6); XW_(5,7);
- XW_(6,0); XW_(6,1); XW_(6,2); XW_(6,3); XW_(6,4); XW_(6,5); XW_(6,6); XW_(6,7);
- XW_(7,0); XW_(7,1); XW_(7,2); XW_(7,3); XW_(7,4); XW_(7,5); XW_(7,6); XW_(7,7);
- #undef XW_
- }
-
- GroupMemoryBarrierWithGroupSync();
- }
-
- #define S_(a, b) O.Set(by*LOAD_WIDTH + bby*BLOCK_WIDTH + ##a, 0, bx*LOAD_WIDTH + bbx*BLOCK_WIDTH + ##b, 0, v##a##b, Odata)
- S_(0,0); S_(0,1); S_(0,2); S_(0,3); S_(0,4); S_(0,5); S_(0,6); S_(0,7);
- S_(1,0); S_(1,1); S_(1,2); S_(1,3); S_(1,4); S_(1,5); S_(1,6); S_(1,7);
- S_(2,0); S_(2,1); S_(2,2); S_(2,3); S_(2,4); S_(2,5); S_(2,6); S_(2,7);
- S_(3,0); S_(3,1); S_(3,2); S_(3,3); S_(3,4); S_(3,5); S_(3,6); S_(3,7);
- S_(4,0); S_(4,1); S_(4,2); S_(4,3); S_(4,4); S_(4,5); S_(4,6); S_(4,7);
- S_(5,0); S_(5,1); S_(5,2); S_(5,3); S_(5,4); S_(5,5); S_(5,6); S_(5,7);
- S_(6,0); S_(6,1); S_(6,2); S_(6,3); S_(6,4); S_(6,5); S_(6,6); S_(6,7);
- S_(7,0); S_(7,1); S_(7,2); S_(7,3); S_(7,4); S_(7,5); S_(7,6); S_(7,7);
- #undef S_
-
- #undef X_
- #undef W_
-}
-
-#elif 1
-
-groupshared float DenseTiled_XcacheR[2][LOAD_DEPTH][LOAD_WIDTH];
-groupshared float DenseTiled_WcacheR[2][LOAD_DEPTH][LOAD_WIDTH];
-
-[numthreads(THREAD_COUNT, 1, 1)]
-void Dense64(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- #define X_ DenseTiled_XcacheR
- #define W_ DenseTiled_WcacheR
-
- uint id = groupThreadID.x;
- uint bx = groupID.x;
- uint by = groupID.y;
-
- uint bbx = id % BLOCK_WIDTH;
- uint bby = id / BLOCK_WIDTH;
-
- float v[BLOCK_WIDTH][BLOCK_WIDTH];
- for (uint yy = 0; yy < BLOCK_WIDTH; ++yy)
- [unroll] for (uint xx = 0; xx < BLOCK_WIDTH; ++xx)
- {
- float bias = B.Get(0, 0, bx*LOAD_WIDTH + bbx*BLOCK_WIDTH + xx, 0, WBKdata, WBK.dataLength);
- v[yy][xx] = bias;
- }
-
- uint m = 0;
- for (uint q = 0; q < LOAD_DEPTH; ++q)
- {
- X_[0][q][id] = X.Get(by*LOAD_WIDTH + id, 0, m*LOAD_DEPTH + q, 0, Xdata);
- W_[0][q][id] = W.Get(0, m*LOAD_DEPTH + q, bx*LOAD_WIDTH + id, 0, WBKdata, WBK.dataLength);
- }
- GroupMemoryBarrierWithGroupSync();
-
- ++m;
-
- for (; m < X.width/LOAD_DEPTH; ++m)
- {
- for (uint q = 0; q < LOAD_DEPTH; ++q)
- {
- X_[1][q][id] = X.Get(by*LOAD_WIDTH + id, 0, m*LOAD_DEPTH + q, 0, Xdata);
- W_[1][q][id] = W.Get(0, m*LOAD_DEPTH + q, bx*LOAD_WIDTH + id, 0, WBKdata, WBK.dataLength);
- }
-
- for (uint yyy = 0; yyy < BLOCK_WIDTH; ++yyy)
- [unroll] for (uint xxx = 0; xxx < BLOCK_WIDTH; ++xxx)
- [unroll]
- for (uint i = 0; i < LOAD_DEPTH; ++i)
- {
- v[yyy][xxx] += X_[0][i][bby*BLOCK_WIDTH + yyy] * W_[0][i][bbx*BLOCK_WIDTH + xxx];
- }
-
- ++m;
- GroupMemoryBarrierWithGroupSync();
-
- if (m < X.width/LOAD_DEPTH)
- {
- for (uint q = 0; q < LOAD_DEPTH; ++q)
- {
- X_[0][q][id] = X.Get(by*LOAD_WIDTH + id, 0, m*LOAD_DEPTH + q, 0, Xdata);
- W_[0][q][id] = W.Get(0, m*LOAD_DEPTH + q, bx*LOAD_WIDTH + id, 0, WBKdata, WBK.dataLength);
- }
- }
-
- for (uint yyy = 0; yyy < BLOCK_WIDTH; ++yyy)
- [unroll] for (uint xxx = 0; xxx < BLOCK_WIDTH; ++xxx)
- [unroll]
- for (uint i = 0; i < LOAD_DEPTH; ++i)
- {
- v[yyy][xxx] += X_[1][i][bby*BLOCK_WIDTH + yyy] * W_[1][i][bbx*BLOCK_WIDTH + xxx];
- }
- GroupMemoryBarrierWithGroupSync();
- }
-
- for (uint yyy = 0; yyy < BLOCK_WIDTH; ++yyy)
- [unroll] for (uint xxx = 0; xxx < BLOCK_WIDTH; ++xxx)
- O.Set(by*LOAD_WIDTH + bby*BLOCK_WIDTH + yyy, 0, bx*LOAD_WIDTH + bbx*BLOCK_WIDTH + xxx, 0, v[yyy][xxx], Odata);
-
- #undef X_
- #undef W_
-}
-
-#else
-
-groupshared float DenseTiled_XcacheR[LOAD_DEPTH][LOAD_WIDTH];
-groupshared float DenseTiled_WcacheR[LOAD_DEPTH][LOAD_WIDTH];
-
-[numthreads(THREAD_COUNT, 1, 1)]
-void Dense64(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- #define X_ DenseTiled_XcacheR
- #define W_ DenseTiled_WcacheR
-
- uint id = groupThreadID.x;
- uint bx = groupID.x;
- uint by = groupID.y;
-
- uint n = by * LOAD_WIDTH + id;
- uint x = bx * LOAD_WIDTH + id;
-
- float v[LOAD_WIDTH];
- float bias = B.Get(0, 0, x, 0, WBKdata, WBK.dataLength);
- [unroll] for (uint xx = 0; xx < LOAD_WIDTH; ++xx)
- v[xx] = bias;
-
- for (uint m = 0; m < X.width/LOAD_DEPTH; ++m)
- {
- float ww[LOAD_DEPTH];
- for (uint q = 0; q < LOAD_DEPTH; ++q)
- {
- X_[q][id] = X.Get(n, 0, m*LOAD_DEPTH + q, 0, Xdata);
- //W_[q][id] = W.Get(0, m*LOAD_DEPTH + q, x, 0, WBKdata, WBK.dataLength);
- ww[q] = W.Get(0, m*LOAD_DEPTH + q, x, 0, WBKdata, WBK.dataLength);
- }
-
- GroupMemoryBarrierWithGroupSync();
-
- for (uint w = 0; w < LOAD_WIDTH; ++w)
- {
- [unroll]
- for (uint i = 0; i < LOAD_DEPTH; ++i)
- {
- //v[w] += X_[i][w] * W_[i][id];
- v[w] += X_[i][w] * ww[i];
- }
- }
-
- GroupMemoryBarrierWithGroupSync();
- }
-
- [unroll] for ( xx = 0; xx < LOAD_WIDTH; ++xx)
- O.Set(by * LOAD_WIDTH + xx, 0, x, 0, v[xx], Odata);
-
- #undef X_
- #undef W_
-}
-#endif
-
-#if 1
-#undef TILE_WIDTH
-#define TILE_WIDTH 16
-groupshared float DenseTiled_Xcache64[16][TILE_WIDTH*TILE_WIDTH];
-groupshared float DenseTiled_Wcache64[16][TILE_WIDTH*TILE_WIDTH];
-[numthreads(TILE_WIDTH,TILE_WIDTH,1)]
-void DenseTiled64x64(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- #define X_ DenseTiled_Xcache64
- #define W_ DenseTiled_Wcache64
-
- uint tx = groupThreadID.x;
- uint ty = groupThreadID.y;
- uint x = groupID.x*TILE_WIDTH + tx;
- uint n = groupID.y*TILE_WIDTH + ty;
-
- float b0 = B.Get(0, 0, x*4+0, 0, WBKdata, WBK.dataLength);
- float b1 = B.Get(0, 0, x*4+1, 0, WBKdata, WBK.dataLength);
- float b2 = B.Get(0, 0, x*4+2, 0, WBKdata, WBK.dataLength);
- float b3 = B.Get(0, 0, x*4+3, 0, WBKdata, WBK.dataLength);
-
- float4 v0, v1, v2, v3;
- v0 = v1 = v2 = v3 = float4(b0, b1, b2, b3);
-
- for (uint m = 0; m < X.width/(TILE_WIDTH*4); ++m)
- {
- for (uint yy = 0; yy < 4; ++yy)
- for (uint xx = 0; xx < 4; ++xx)
- {
- X_[yy*4+xx][ty*TILE_WIDTH+tx] = X.Get(n*4+yy, 0, (m*TILE_WIDTH + tx)*4+xx, 0, Xdata);
- W_[yy*4+xx][ty*TILE_WIDTH+tx] = W.Get(0, (m*TILE_WIDTH + ty)*4+yy, x*4+xx, 0, WBKdata, WBK.dataLength);
- }
-
- GroupMemoryBarrierWithGroupSync();
-
- //[unroll]
- for (uint i = 0; i < TILE_WIDTH; ++i)
- {
- [unroll]
- for (uint q = 0; q < 4; ++q)
- {
- float x0 = X_[0*4+q][ty*TILE_WIDTH+i];
- float x1 = X_[1*4+q][ty*TILE_WIDTH+i];
- float x2 = X_[2*4+q][ty*TILE_WIDTH+i];
- float x3 = X_[3*4+q][ty*TILE_WIDTH+i];
-
- float w0 = W_[q*4+0][i*TILE_WIDTH+tx];
- float w1 = W_[q*4+1][i*TILE_WIDTH+tx];
- float w2 = W_[q*4+2][i*TILE_WIDTH+tx];
- float w3 = W_[q*4+3][i*TILE_WIDTH+tx];
-
- v0.x = mad(x0, w0, v0.x); //--
- v1.x = mad(x1, w0, v1.x);
- v2.x = mad(x2, w0, v2.x);
- v3.x = mad(x3, w0, v3.x);
- v0.y = mad(x0, w1, v0.y); //--
- v1.y = mad(x1, w1, v1.y);
- v2.y = mad(x2, w1, v2.y);
- v3.y = mad(x3, w1, v3.y);
- v0.z = mad(x0, w2, v0.z); //--
- v1.z = mad(x1, w2, v1.z);
- v2.z = mad(x2, w2, v2.z);
- v3.z = mad(x3, w2, v3.z);
- v0.w = mad(x0, w3, v0.w); //--
- v1.w = mad(x1, w3, v1.w);
- v2.w = mad(x2, w3, v2.w);
- v3.w = mad(x3, w3, v3.w);
- }
-
- GroupMemoryBarrierWithGroupSync();
- }
- }
-
- O.Set(n*4+0, 0, x*4+0, 0, v0.x, Odata);
- O.Set(n*4+0, 0, x*4+1, 0, v0.y, Odata);
- O.Set(n*4+0, 0, x*4+2, 0, v0.z, Odata);
- O.Set(n*4+0, 0, x*4+3, 0, v0.w, Odata);
-
- O.Set(n*4+1, 0, x*4+0, 0, v1.x, Odata);
- O.Set(n*4+1, 0, x*4+1, 0, v1.y, Odata);
- O.Set(n*4+1, 0, x*4+2, 0, v1.z, Odata);
- O.Set(n*4+1, 0, x*4+3, 0, v1.w, Odata);
-
- O.Set(n*4+2, 0, x*4+0, 0, v2.x, Odata);
- O.Set(n*4+2, 0, x*4+1, 0, v2.y, Odata);
- O.Set(n*4+2, 0, x*4+2, 0, v2.z, Odata);
- O.Set(n*4+2, 0, x*4+3, 0, v2.w, Odata);
-
- O.Set(n*4+3, 0, x*4+0, 0, v3.x, Odata);
- O.Set(n*4+3, 0, x*4+1, 0, v3.y, Odata);
- O.Set(n*4+3, 0, x*4+2, 0, v3.z, Odata);
- O.Set(n*4+3, 0, x*4+3, 0, v3.w, Odata);
-
- #undef X_
- #undef W_
-}
-
-#else
-
-#define TILE_WIDTH 16
-#define RTILE 4
-groupshared float DenseTiled_Xcache64[RTILE*RTILE][TILE_WIDTH*TILE_WIDTH];
-groupshared float DenseTiled_Wcache64[RTILE*RTILE][TILE_WIDTH*TILE_WIDTH];
-[numthreads(TILE_WIDTH,TILE_WIDTH,1)]
-void DenseTiled64x64(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- #define X_ DenseTiled_Xcache64
- #define W_ DenseTiled_Wcache64
-
- uint tx = groupThreadID.x;
- uint ty = groupThreadID.y;
- uint x = groupID.x*TILE_WIDTH + tx;
- uint n = groupID.y*TILE_WIDTH + ty;
-
- float v[RTILE*RTILE];
- [unroll] for (uint xxxx = 0; xxxx < RTILE; ++xxxx)
- {
- float b = B.Get(0, 0, x*RTILE+xxxx, 0, WBKdata, WBK.dataLength);
- [unroll] for (uint yyyy = 0; yyyy < RTILE; ++yyyy)
- v[yyyy*RTILE+xxxx] = b;
- }
-
- for (uint m = 0; m < X.width/(TILE_WIDTH*RTILE); ++m)
- {
- for (uint yy = 0; yy < RTILE; ++yy)
- [unroll] for (uint xx = 0; xx < RTILE; ++xx)
- {
- X_[yy*RTILE+xx][ty*TILE_WIDTH+tx] = X.Get(n*RTILE+yy, 0, (m*TILE_WIDTH + tx)*RTILE+xx, 0, Xdata);
- W_[yy*RTILE+xx][ty*TILE_WIDTH+tx] = W.Get(0, (m*TILE_WIDTH + ty)*RTILE+yy, x*RTILE+xx, 0, WBKdata, WBK.dataLength);
- }
- GroupMemoryBarrierWithGroupSync();
-
- for (uint ii = 0; ii < TILE_WIDTH; ++ii)
- {
- [unroll] for (uint yy = 0; yy < RTILE; ++yy)
- [unroll] for (uint xx = 0; xx < RTILE; ++xx)
- [unroll] for (uint i = 0; i < RTILE; ++i)
- {
- float x = X_[yy*RTILE+i][ty*TILE_WIDTH+ii];
- float w = W_[i*RTILE+xx][ii*TILE_WIDTH+tx];
- v[yy*RTILE+xx] = mad(x, w, v[yy*RTILE+xx]);
- }
-
- GroupMemoryBarrierWithGroupSync();
- }
- }
-
- [unroll] for (uint yy = 0; yy < RTILE; ++yy)
- [unroll] for (uint xx = 0; xx < RTILE; ++xx)
- O.Set(n*RTILE+yy, 0, x*RTILE+xx, 0, v[yy*RTILE+xx], Odata);
-
- #undef X_
- #undef W_
-}
-
-#endif
-
-#undef TILE_WIDTH
-#define TILE_WIDTH 16 // 32 crashes on MacBookPro/AMD
-groupshared float DenseTiled_Xcache32[4][TILE_WIDTH][TILE_WIDTH];
-groupshared float DenseTiled_Wcache32[4][TILE_WIDTH][TILE_WIDTH];
-[numthreads(TILE_WIDTH,TILE_WIDTH,1)]
-void DenseTiled32x32(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- #define X_ DenseTiled_Xcache32
- #define W_ DenseTiled_Wcache32
-
- uint tx = groupThreadID.x;
- uint ty = groupThreadID.y;
- uint x = groupID.x*TILE_WIDTH + tx;
- uint n = groupID.y*TILE_WIDTH + ty;
-
- float b0 = B.Get(0, 0, x*2+0, 0, WBKdata, WBK.dataLength);
- float b1 = B.Get(0, 0, x*2+1, 0, WBKdata, WBK.dataLength);
- float4 v = float4(b0, b1,
- b0, b1);
-
- for (uint m = 0; m < X.width/(TILE_WIDTH*2);)
- {
- // @TODO: read in float2s
- float x0 = X.Get(n*2+0, 0, m*TILE_WIDTH*2 + tx*2+0, 0, Xdata);
- float x1 = X.Get(n*2+0, 0, m*TILE_WIDTH*2 + tx*2+1, 0, Xdata);
- float x2 = X.Get(n*2+1, 0, m*TILE_WIDTH*2 + tx*2+0, 0, Xdata);
- float x3 = X.Get(n*2+1, 0, m*TILE_WIDTH*2 + tx*2+1, 0, Xdata);
-
- float w0 = W.Get(0, m*TILE_WIDTH*2 + ty*2+0, x*2+0, 0, WBKdata, WBK.dataLength);
- float w1 = W.Get(0, m*TILE_WIDTH*2 + ty*2+0, x*2+1, 0, WBKdata, WBK.dataLength);
- float w2 = W.Get(0, m*TILE_WIDTH*2 + ty*2+1, x*2+0, 0, WBKdata, WBK.dataLength);
- float w3 = W.Get(0, m*TILE_WIDTH*2 + ty*2+1, x*2+1, 0, WBKdata, WBK.dataLength);
-
- ++m;
-
- X_[0][ty][tx] = x0;
- X_[1][ty][tx] = x1;
- X_[2][ty][tx] = x2;
- X_[3][ty][tx] = x3;
-
- W_[0][ty][tx] = w0;
- W_[1][ty][tx] = w1;
- W_[2][ty][tx] = w2;
- W_[3][ty][tx] = w3;
-
- GroupMemoryBarrierWithGroupSync();
-
- [unroll]
- for (uint i = 0; i < TILE_WIDTH; ++i)
- {
- float4 x = //X_[ty][i];
- float4( X_[0][ty][i],
- X_[1][ty][i],
- X_[2][ty][i],
- X_[3][ty][i]);
- float4 w = //W_[i][tx];
- float4( W_[0][i][tx],
- W_[1][i][tx],
- W_[2][i][tx],
- W_[3][i][tx]);
-
- v.x = mad(w.x, x.x, v.x);
- v.y = mad(w.y, x.x, v.y);
- v.z = mad(w.x, x.z, v.z);
- v.w = mad(w.y, x.z, v.w);
-
- v.x = mad(w.z, x.y, v.x);
- v.y = mad(w.w, x.y, v.y);
- v.z = mad(w.z, x.w, v.z);
- v.w = mad(w.w, x.w, v.w);
-
- //v.x += k.x*x.x + k.z*x.y;
- //v.y += k.y*x.x + k.w*x.y;
- //v.z += k.x*x.z + k.z*x.w;
- //v.w += k.y*x.z + k.w*x.w;
- }
-
- GroupMemoryBarrierWithGroupSync();
- }
-
- O.Set(n*2+0, 0, x*2+0, 0, v.x, Odata);
- O.Set(n*2+0, 0, x*2+1, 0, v.y, Odata);
- O.Set(n*2+1, 0, x*2+0, 0, v.z, Odata);
- O.Set(n*2+1, 0, x*2+1, 0, v.w, Odata);
-
- #undef X_
- #undef W_
-}
-
-// slightly faster on AMD (56ms vs 62ms)
-#undef TILE_WIDTH
-#define TILE_WIDTH 16
-//#define CACHE_ONLY_X
-//#define TRANSPOSE_W
-//#define TRANSPOSE_X
-groupshared float DenseTiled_XcacheF[TILE_WIDTH][TILE_WIDTH];
-#if !defined(CACHE_ONLY_X)
-groupshared float DenseTiled_WcacheF[TILE_WIDTH][TILE_WIDTH];
-#endif
-[numthreads(TILE_WIDTH,TILE_WIDTH,1)]
-void DenseTiled16x16_amd(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- #define X_ DenseTiled_XcacheF
- #define W_ DenseTiled_WcacheF
-
- uint tx = groupThreadID.x;
- uint ty = groupThreadID.y;
- uint x = groupID.x*TILE_WIDTH + tx;
- uint b = groupID.y*TILE_WIDTH + ty;
-
- float v = B.Get(0, 0, x, 0, WBKdata, WBK.dataLength);
-
- for (uint m = 0; m < X.width/TILE_WIDTH; ++m)
- {
- #if defined(TRANSPOSE_X)
- X_[tx][ty] = X.Get(b, 0, m*TILE_WIDTH + tx, 0, Xdata);
- #else
- X_[ty][tx] = X.Get(b, 0, m*TILE_WIDTH + tx, 0, Xdata);
- #endif
-
- #if defined(CACHE_ONLY_X)
- float ww = WBKdata[wi];
- #else
- #if defined(TRANSPOSE_W)
- W_[tx][ty] = W.Get(0, m*TILE_WIDTH + ty, x, 0, WBKdata, WBK.dataLength);
- #else
- W_[ty][tx] = W.Get(0, m*TILE_WIDTH + ty, x, 0, WBKdata, WBK.dataLength);
- #endif
- #endif
- GroupMemoryBarrierWithGroupSync();
-
- //[unroll(groupthreads)]
- [unroll]
- for (uint i = 0; i < TILE_WIDTH; ++i)
- {
- #if defined(TRANSPOSE_X)
- float x = X_[i][ty];
- #else
- float x = X_[ty][i];
- #endif
-
- #if defined(CACHE_ONLY_X)
- //float w = ww;
- //if (i != TILE_WIDTH-1) { wi += W.width; ww = WBKdata[wi]; }
- float w = W.Get(0, m*TILE_WIDTH + i, x, 0, WBKdata, WBK.dataLength);
- #else
- #if defined(TRANSPOSE_W)
- float w = W_[tx][i];
- #else
- float w = W_[i][tx];
- #endif
- #endif
-
- v += x * w;
- }
- }
-
- O.Set(b, 0, x, 0, v, Odata);
-
- #undef X_
- #undef W_
-}
-
-#undef TILE_WIDTH
-#define TILE_WIDTH 16
-groupshared float DenseTiled_Xcache[TILE_WIDTH][TILE_WIDTH];
-groupshared float DenseTiled_Wcache[TILE_WIDTH][TILE_WIDTH];
-[numthreads(TILE_WIDTH,TILE_WIDTH,1)]
-void DenseTiled(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- #define X_ DenseTiled_Xcache
- #define W_ DenseTiled_Wcache
-
- uint tx = groupThreadID.x;
- uint ty = groupThreadID.y;
- uint x = groupID.x*TILE_WIDTH + tx;
- uint b = groupID.y*TILE_WIDTH + ty;
-
- bool mask = (x < O.width && b < O.batch);
-
- float v = B.Get(0, 0, x, 0, WBKdata, WBK.dataLength);
-
- for (uint m = 0; m < X.width/TILE_WIDTH; ++m)
- {
- if (mask)
- {
- X_[ty][tx] = X.Get(b, 0, m*TILE_WIDTH + tx, 0, Xdata);
- W_[ty][tx] = W.Get(0, m*TILE_WIDTH + ty, x, 0, WBKdata, WBK.dataLength);
- }
- else
- {
- X_[ty][tx] = 0;
- W_[ty][tx] = 0;
- }
-
- GroupMemoryBarrierWithGroupSync();
-
- [unroll]
- for (uint i = 0; i < TILE_WIDTH; ++i)
- {
- v += X_[ty][i] * W_[i][tx];
- }
-
- GroupMemoryBarrierWithGroupSync();
- }
-
- if (mask)
- O.Set(b, 0, x, 0, v, Odata);
-
- #undef X_
- #undef W_
-}
-
-
-groupshared float DenseTiled_XcacheP[TILE_WIDTH][TILE_WIDTH];
-groupshared float DenseTiled_WcacheP[TILE_WIDTH][TILE_WIDTH];
-// Prefetch - seems to be the same performance as DenseTiled16x16 without prefetch, has higher register pressure
-[numthreads(TILE_WIDTH,TILE_WIDTH,1)]
-void DenseTiledPrefetch16x16(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- #define X_ DenseTiled_XcacheP
- #define W_ DenseTiled_WcacheP
-
- uint tx = groupThreadID.x;
- uint ty = groupThreadID.y;
- uint x = groupID.x*TILE_WIDTH + tx;
- uint b = groupID.y*TILE_WIDTH + ty;
-
- float v = B.Get(0, 0, x, 0, WBKdata, WBK.dataLength);
-
- float Xregs[TILE_WIDTH][TILE_WIDTH];
- float Wregs[TILE_WIDTH][TILE_WIDTH];
- for (uint m = 0; m < X.width/TILE_WIDTH; ++m)
- {
- Xregs[ty][tx] = X.Get(b, 0, m*TILE_WIDTH + tx, 0, Xdata);
- Wregs[ty][tx] = W.Get(0, m*TILE_WIDTH + ty, x, 0, WBKdata, WBK.dataLength);
- GroupMemoryBarrierWithGroupSync();
- }
-
- for (m = 0; m < X.width/TILE_WIDTH; ++m)
- {
- X_[ty][tx] = Xregs[ty][tx];
- W_[ty][tx] = Wregs[ty][tx];
-
- Xregs[ty][tx] = X.Get(b, 0, m*TILE_WIDTH + tx, 0, Xdata);
- Wregs[ty][tx] = W.Get(0, m*TILE_WIDTH + ty, x, 0, WBKdata, WBK.dataLength);
-
- for (uint i = 0; i < TILE_WIDTH;)
- {
- // can unroll up to 16 because TILE_WIDTH=16
- v += X_[ty][i] * W_[i][tx]; ++i;
- v += X_[ty][i] * W_[i][tx]; ++i;
- v += X_[ty][i] * W_[i][tx]; ++i;
- v += X_[ty][i] * W_[i][tx]; ++i;
-
- v += X_[ty][i] * W_[i][tx]; ++i;
- v += X_[ty][i] * W_[i][tx]; ++i;
- v += X_[ty][i] * W_[i][tx]; ++i;
- v += X_[ty][i] * W_[i][tx]; ++i;
- }
-
- GroupMemoryBarrierWithGroupSync();
- }
-
- O.Set(b, 0, x, 0, v, Odata);
- #undef X_
- #undef W_
-}
-
-[numthreads(1,1,1)]
-void Relu(uint3 groupID : SV_GroupID)
-{
- uint x = groupID.x;
- uint b = groupID.y;
- uint c = groupID.z;
- for (uint y = 0; y < X.height; ++y)
- {
- float v = X.Get(b, y, x, c, Xdata, X.dataLength);
- v = 0.5f * (v + abs(v));
- O.Set(b, y, x, c, v, Odata, O.dataLength);
- }
-}
-
-[numthreads(16,16,1)]
-void Relu_Cmod16_CNyx(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- uint c = 16*groupID.x + groupThreadID.x;
- uint nyx = 16*groupID.y + groupThreadID.y;
-
- uint width = X.width;
- uint height = X.height;
-
- uint x = nyx % width;
- uint ny = nyx / width;
- uint y = ny % height;
- uint n = ny / height;
-
- float v = X.Get(n, y, x, c, Xdata, X.dataLength);
- v = 0.5f * (v + abs(v));
- O.Set(n, y, x, c, v, Odata, O.dataLength);
-}
-
-[numthreads(512,1,1)]
-void Relu_Nyxc(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- uint nyxc = 512*groupID.x + groupThreadID.x;
-
- uint width = X.width;
- uint height = X.height;
- uint channels = X.channels;
-
- uint c = nyxc % channels;
- uint nyx = nyxc / channels;
- uint x = nyx % width;
- uint ny = nyx / width;
- uint y = ny % height;
- uint n = ny / height;
-
- float v = X.Get(n, y, x, c, Xdata, X.dataLength);
- v = 0.5f * (v + abs(v));
- O.Set(n, y, x, c, v, Odata, O.dataLength);
-}
-
-[numthreads(16,16,1)]
-void Relu16x16(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- uint x = 16*groupID.x + groupThreadID.x;
- uint b = 16*groupID.y + groupThreadID.y;
- uint c = groupID.z;
-
- for (uint y = 0; y < X.height; ++y)
- {
- float v = X.Get(b, y, x, c, Xdata, X.dataLength);
- v = 0.5f * (v + abs(v));
- O.Set(b, y, x, c, v, Odata, O.dataLength);
- }
-}
-
-[numthreads(16,16,1)]
-void Relu16x16_(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- uint x = 16*groupID.x + groupThreadID.x;
- uint b = 16*groupID.y + groupThreadID.y;
-
- for (uint y = 0; y < X.height; ++y)
- {
- for (uint c = 0; c < X.channels; ++c)
- {
- float v = X.Get(b, y, x, c, Xdata, X.dataLength);
- v = 0.5f * (v + abs(v));
- O.Set(b, y, x, c, v, Odata, O.dataLength);
- }
- }
-}
-
-
-// channels, width, batch
-[numthreads(16,2,16)]
-void ReluChannelsFirst16x2x16(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- uint c = 16*groupID.x + groupThreadID.x;
- uint x = 2*groupID.y + groupThreadID.y;
- uint b = 16*groupID.z + groupThreadID.z;
-
- for (uint y = 0; y < X.height; ++y)
- {
- float v = X.Get(b, y, x, c, Xdata, X.dataLength);
- v = 0.5f * (v + abs(v));
- O.Set(b, y, x, c, v, Odata, O.dataLength);
- }
-}
-
-[numthreads(256,1,1)]
-void Relu256xV(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- uint x = 256*groupID.x + groupThreadID.x;
- uint b = groupID.y;
- uint c = groupID.z;
-
- for (uint y = 0; y < X.height; ++y)
- {
- float v = 0;
- for (uint b = 0; b < X.batch; )
- {
- v = X.Get(b, y, x, c, Xdata, X.dataLength);
- v = 0.5f * (v + abs(v));
- O.Set(b, y, x, c, v, Odata, O.dataLength);
- ++b;
-
- v = X.Get(b, y, x, c, Xdata, X.dataLength);
- v = 0.5f * (v + abs(v));
- O.Set(b, y, x, c, v, Odata, O.dataLength);
- ++b;
-
- v = X.Get(b, y, x, c, Xdata, X.dataLength);
- v = 0.5f * (v + abs(v));
- O.Set(b, y, x, c, v, Odata, O.dataLength);
- ++b;
-
- v = X.Get(b, y, x, c, Xdata, X.dataLength);
- v = 0.5f * (v + abs(v));
- O.Set(b, y, x, c, v, Odata, O.dataLength);
- ++b;
- }
- }
-}
-
-
-#define FLT_MAX 3.402823466e+38F
-
-[numthreads(1,1,1)]
-void Softmax(uint3 groupID : SV_GroupID)
-{
- uint b = groupID.x;
- uint x = groupID.y;
-
- float maxV = -FLT_MAX;
- for (uint i = 0; i < X.width; ++i)
- {
- float v = X.Get(b, 0, i, 0, Xdata, X.dataLength);
- if (v > maxV)
- maxV = v;
- }
-
- float sum = 0.0f;
- for (i = 0; i < X.width; ++i)
- {
- float v = X.Get(b, 0, i, 0, Xdata, X.dataLength);
- sum += exp(v - maxV);
- }
-
- float v = X.Get(b, 0, x, 0, Xdata, X.dataLength);
- v = exp(v - maxV) / sum;
- O.Set(b, 0, x, 0, v, Odata, O.dataLength);
-}
-
-[numthreads(256,2,1)]
-void Softmax256x2(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- uint b = 256*groupID.x + groupThreadID.x;
- uint x = 2*groupID.y + groupThreadID.y;
-
- float maxV = -FLT_MAX;
- for (uint i = 0; i < X.width; ++i)
- {
- float v = X.Get(b, 0, i, 0, Xdata, X.dataLength);
- if (v > maxV)
- maxV = v;
- }
-
- float sum = 0.0f;
- for (i = 0; i < X.width; ++i)
- {
- float v = X.Get(b, 0, i, 0, Xdata, X.dataLength);
- sum += exp(v - maxV);
- }
-
- float v = X.Get(b, 0, x, 0, Xdata, X.dataLength);
- v = exp(v - maxV) / sum;
- O.Set(b, 0, x, 0, v, Odata, O.dataLength);
-}
-
-[numthreads(1,1,1)]
-void MaxPooling2D(uint3 groupID : SV_GroupID)
-{
- uint c = groupID.x;
- uint x = groupID.y;
- uint y = groupID.z;
-
- for (uint b = 0; b < O.batch; ++b)
- {
- float v0 = X.Get(b, y*2, x*2, c, Xdata, X.dataLength);
- float v1 = X.Get(b, y*2+1, x*2, c, Xdata, X.dataLength);
- float v2 = X.Get(b, y*2, x*2+1, c, Xdata, X.dataLength);
- float v3 = X.Get(b, y*2+1, x*2+1, c, Xdata, X.dataLength);
- float v = max(v0, max(v1, max(v2, v3)));
- O.Set(b, y, x, c, v, Odata, O.dataLength);
- }
-}
-
-[numthreads(16,4,4)]
-void MaxPooling2D16x4x4(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- uint c = 16*groupID.x + groupThreadID.x;
- uint x = 4*groupID.y + groupThreadID.y;
- uint y = 4*groupID.z + groupThreadID.z;
-
- for (uint b = 0; b < O.batch; ++b)
- {
- float v0 = X.Get(b, y*2, x*2, c, Xdata, X.dataLength);
- float v1 = X.Get(b, y*2+1, x*2, c, Xdata, X.dataLength);
- float v2 = X.Get(b, y*2, x*2+1, c, Xdata, X.dataLength);
- float v3 = X.Get(b, y*2+1, x*2+1, c, Xdata, X.dataLength);
- float v = max(v0, max(v1, max(v2, v3)));
- O.Set(b, y, x, c, v, Odata, O.dataLength);
- }
-}
-
-[numthreads(16,16,2)]
-void Conv2D_Valid(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- uint k = 16*groupID.x + groupThreadID.x;
- uint n = 16*groupID.y + groupThreadID.y;
- uint y = 2*groupID.z + groupThreadID.z + _FilterSize;
-
- //for (int y = _FilterSize; y < X.height - _FilterSize; ++y)
- {
- for (uint x = _FilterSize; x < X.width - _FilterSize; ++x)
- {
- float v = B.Get(0, 0, k, 0, WBKdata, WBK.dataLength);
- for (int i = -(int)_FilterSize; i < (int)_FilterSize + 1; ++i)
- {
- for (int j = -(int)_FilterSize; j < (int)_FilterSize + 1; ++j)
- {
- for (uint c = 0; c < X.channels; ++c)
- {
- v += X.Get(n, y+j, x+i, c, Xdata, X.dataLength) * K.Get(_FilterSize+j, _FilterSize+i, c, k, WBKdata, WBK.dataLength);
- }
- }
- }
- O.Set(n, y-_FilterSize, x-_FilterSize, k, v, Odata, O.dataLength);
- }
- }
-}
-
-[numthreads(16,8,1)]
-void Conv2D_Kmod16_Nmod8_KNY(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- uint k = 16*groupID.x + groupThreadID.x;
- uint n = 8*groupID.y + groupThreadID.y;
- uint y = 1*groupID.z + groupThreadID.z;
-
- //for (int y = _FilterSize; y < X.height - _FilterSize; ++y)
- {
- for (uint x = 0; x < X.width - _Border; ++x)
- {
- float v = B.Get(0, 0, k, 0, WBKdata, WBK.dataLength);
- for (uint j = 0; j < 2*_FilterSize+1; ++j)
- {
- if (y+j < _Offset) continue;
- if (y+j-_Offset >= X.height) continue;
-
- for (uint i = 0; i < 2*_FilterSize+1; ++i)
- {
- if (x+i < _Offset) continue;
- if (x+i-_Offset >= X.width) continue;
-
- for (uint c = 0; c < X.channels; ++c)
- {
- v += X.Get(n, y+j-_Offset, x+i-_Offset, c, Xdata, X.dataLength) * K.Get(j, i, c, k, WBKdata, WBK.dataLength);
- }
- }
- }
- O.Set(n, y, x, k, v, Odata, O.dataLength);
- }
- }
-}
-
-[numthreads(1,1,1)]
-void Conv2D(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- uint k = 1*groupID.x + groupThreadID.x;
- uint n = 1*groupID.y + groupThreadID.y;
- uint y = 1*groupID.z + groupThreadID.z;
-
- //for (int y = _FilterSize; y < X.height - _FilterSize; ++y)
- {
- for (uint x = 0; x < X.width - _Border; ++x)
- {
- float v = B.Get(0, 0, k, 0, WBKdata, WBK.dataLength);
- for (uint j = 0; j < 2*_FilterSize+1; ++j)
- {
- if (y+j < _Offset) continue;
- if (y+j-_Offset >= X.height) continue;
-
- for (uint i = 0; i < 2*_FilterSize+1; ++i)
- {
- if (x+i < _Offset) continue;
- if (x+i-_Offset >= X.width) continue;
-
- for (uint c = 0; c < X.channels; ++c)
- {
- v += X.Get(n, y+j-_Offset, x+i-_Offset, c, Xdata, X.dataLength) * K.Get(j, i, c, k, WBKdata, WBK.dataLength);
- }
- }
- }
- O.Set(n, y, x, k, v, Odata, O.dataLength);
- }
- }
-}
-
-#if 0
-
-#define MAX_TILE_WIDTH 16
-#define KERNEL_COUNT 4
-#define KERNEL_SIZE 3
-#define KERNEL_RADIUS 1 //(KERNEL_SIZE-1)/2
-groupshared float XCcache[MAX_TILE_WIDTH+KERNEL_SIZE-1][MAX_TILE_WIDTH+KERNEL_SIZE-1];
-groupshared float Kcache[KERNEL_SIZE][KERNEL_SIZE][KERNEL_COUNT];
-
-#undef TILE_WIDTH
-#define TILE_WIDTH 13
-[numthreads(TILE_WIDTH,TILE_WIDTH,KERNEL_COUNT)]
-void Conv2DTiled14x14_Kernel3x3(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- uint tx = groupThreadID.x;
- uint ty = groupThreadID.y;
- uint tk = groupThreadID.z;
- uint gx = groupID.x;
- uint gy = groupID.y;
- uint gk = groupID.z;
- uint tileCornerX = gx*TILE_WIDTH;
- uint tileCornerY = gy*TILE_WIDTH;
- uint x = tileCornerX + tx;
- uint y = tileCornerY + ty;
- uint k = gk*KERNEL_COUNT + tk;
- uint idx = ty*TILE_WIDTH + tx;
-
- for (uint b = 0; b < X.batch; ++b)
- {
- float v = B.Get(0, 0, k, 0, WBKdata, WBK.dataLength);
- for (uint c = 0; c < X.channels; ++c)
- {
- if (tk == 0)
- XCcache[ty][tx] = X.Get(b, y, x, c, Xdata);
- else if (tk == 1 && idx < TILE_WIDTH * 2)
- {
- uint yy = idx / 2;
- uint xx = idx % 2 + TILE_WIDTH;
- XCcache[yy][xx] = X.Get(b, tileCornerY+yy, tileCornerX+xx, c, Xdata);
- }
- else if (tk == 2 && idx < (TILE_WIDTH + 2) * 2)
- {
- uint yy = idx / (TILE_WIDTH + 2) + TILE_WIDTH;
- uint xx = idx % (TILE_WIDTH + 2);
- XCcache[yy][xx] = X.Get(b, tileCornerY+yy, tileCornerX+xx, c, Xdata);
- }
- if (tk == 3)
- {
- uint kk = idx / (KERNEL_SIZE * KERNEL_SIZE);
- uint kyx = idx % (KERNEL_SIZE * KERNEL_SIZE);
- if (kk < KERNEL_COUNT)
- {
- uint yy = kyx / KERNEL_SIZE;
- uint xx = kyx % KERNEL_SIZE;
- Kcache[yy][xx][kk] = K.Get(yy, xx, c, gk*KERNEL_COUNT+kk, WBKdata, WBK.dataLength);
- }
- }
- GroupMemoryBarrierWithGroupSync();
-
- for (int i = 0; i < KERNEL_SIZE; ++i)
- {
- for (int j = 0; j < KERNEL_SIZE; ++j)
- {
- v += XCcache[ty+j][tx+i] * Kcache[j][i][tk];
- }
- }
- }
- O.Set(b, y, x, k, v, Odata, O.dataLength);
- GroupMemoryBarrierWithGroupSync();
- }
-}
-
-#undef TILE_WIDTH
-#define TILE_WIDTH 12
-[numthreads(TILE_WIDTH,TILE_WIDTH,KERNEL_COUNT)]
-void Conv2DTiled13x13_Kernel3x3(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- uint tx = groupThreadID.x;
- uint ty = groupThreadID.y;
- uint tk = groupThreadID.z;
- uint gx = groupID.x;
- uint gy = groupID.y;
- uint gk = groupID.z;
- uint tileCornerX = gx*TILE_WIDTH;
- uint tileCornerY = gy*TILE_WIDTH;
- uint x = tileCornerX + tx;
- uint y = tileCornerY + ty;
- uint k = gk*KERNEL_COUNT + tk;
- uint idx = ty*TILE_WIDTH + tx;
-
- for (uint b = 0; b < X.batch; ++b)
- {
- float v = B.Get(0, 0, k, 0, WBKdata, WBK.dataLength);
- for (uint c = 0; c < X.channels; ++c)
- {
- if (tk == 0)
- XCcache[ty][tx] = X.Get(b, y, x, c, Xdata);
- else if (tk == 1 && idx < TILE_WIDTH * 2)
- {
- uint yy = idx / 2;
- uint xx = idx % 2 + TILE_WIDTH;
- XCcache[yy][xx] = X.Get(b, tileCornerY+yy, tileCornerX+xx, c, Xdata);
- }
- else if (tk == 2 && idx < (TILE_WIDTH + 2) * 2)
- {
- uint yy = idx / (TILE_WIDTH + 2) + TILE_WIDTH;
- uint xx = idx % (TILE_WIDTH + 2);
- XCcache[yy][xx] = X.Get(b, tileCornerY+yy, tileCornerX+xx, c, Xdata);
- }
- if (tk == 3)
- {
- uint kk = idx / (KERNEL_SIZE * KERNEL_SIZE);
- uint kyx = idx % (KERNEL_SIZE * KERNEL_SIZE);
- if (kk < KERNEL_COUNT)
- {
- uint yy = kyx / KERNEL_SIZE;
- uint xx = kyx % KERNEL_SIZE;
- Kcache[yy][xx][kk] = K.Get(yy, xx, c, gk*KERNEL_COUNT+kk, WBKdata, WBK.dataLength);
- }
- }
- GroupMemoryBarrierWithGroupSync();
-
- for (int i = 0; i < KERNEL_SIZE; ++i)
- {
- for (int j = 0; j < KERNEL_SIZE; ++j)
- {
- v += XCcache[ty+j][tx+i] * Kcache[j][i][tk];
- }
- }
- }
- O.Set(b, y, x, k, v, Odata, O.dataLength);
- GroupMemoryBarrierWithGroupSync();
- }
-}
-
-/*
-#undef TILE_WIDTH
-#define TILE_WIDTH 12
-[numthreads(TILE_WIDTH,TILE_WIDTH,KERNEL_COUNT)]
-void Conv2DTiled12x12_Kernel3x3(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- uint tx = groupThreadID.x;
- uint ty = groupThreadID.y;
- uint tk = groupThreadID.z;
- uint gx = groupID.x;
- uint gy = groupID.y;
- uint gk = groupID.z;
- uint tileCornerX = gx*TILE_WIDTH;
- uint tileCornerY = gy*TILE_WIDTH;
- uint x = tileCornerX + tx;
- uint y = tileCornerY + ty;
- uint k = gk*KERNEL_COUNT + tk;
- uint idx = ty*TILE_WIDTH + tx;
-
- for (uint b = 0; b < X.batch; ++b)
- {
- float v = B.Get(0, 0, k, 0, WBKdata, WBK.dataLength);
- for (uint c = 0; c < X.channels; ++c)
- {
- if (gk == 0)
- XCcache[ty][tx] = X.Get(b, y, x, c, Xdata);
- else if (gk == 1 && idx < TILE_WIDTH * 2)
- {
- uint yy = idx / 2;
- uint xx = idx % 2 + TILE_WIDTH;
- XCcache[yy][xx] = X.Get(b, tileCornerY+yy, tileCornerX+xx, c, Xdata);
- }
- else if (gk == 2 && idx < (TILE_WIDTH + 2) * 2)
- {
- uint yy = idx / (TILE_WIDTH + 2) + TILE_WIDTH;
- uint xx = idx % (TILE_WIDTH + 2);
- XCcache[yy][xx] = X.Get(b, tileCornerY+yy, tileCornerX+xx, c, Xdata);
- }
- else if (gk == 3 && ty < KERNEL_SIZE && tx < KERNEL_SIZE)
- Kcache[ty][tx][tk] = K.Get(ty, tx, c, k, WBKdata, WBK.dataLength);
- GroupMemoryBarrierWithGroupSync();
-
- for (int i = 0; i < KERNEL_SIZE; ++i)
- {
- for (int j = 0; j < KERNEL_SIZE; ++j)
- {
- v += XCcache[ty+j][tx+i] * Kcache[j][i][tk];
- }
- }
- }
- O.Set(b, y-KERNEL_RADIUS, x-KERNEL_RADIUS, k, v, Odata, O.dataLength);
- GroupMemoryBarrierWithGroupSync();
- }
-}
-*/
-
-// %TODO: only supports up to 32 channels now
-#undef KERNEL_COUNT
-#undef CHANNEL_COUNT
-#define KERNEL_COUNT 16
-#define CHANNEL_COUNT 32
-groupshared float K2cache[CHANNEL_COUNT][KERNEL_COUNT][9];
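-// Stages the full 3x3 weights for a CHANNEL_COUNT x KERNEL_COUNT block into K2cache;
-// each thread then sweeps the valid output region for one (batch n, kernel k) pair,
-// reading X directly.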
-[numthreads(KERNEL_COUNT,CHANNEL_COUNT,1)]
-void Conv2D_Kernel3x3_32Channel_Valid(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- uint tk = groupThreadID.x;
- uint k = KERNEL_COUNT*groupID.x + tk;
- uint n = CHANNEL_COUNT*groupID.y + groupThreadID.y;
-
- for (uint q = 0; q < 9; ++q)
- {
- uint tc = n % CHANNEL_COUNT;
- K2cache[tc][tk][q] = K.Get(q/3, q%3, tc, k, WBKdata, WBK.dataLength);
- }
- GroupMemoryBarrierWithGroupSync();
-
- for (uint y = 0; y < X.height - _FilterSize*2; ++y)
- {
- for (uint x = 0; x < X.width - _FilterSize*2; ++x)
- {
- float v = B.Get(0, 0, k, 0, WBKdata, WBK.dataLength);
- for (uint q = 0; q < 9; ++q)
- for (uint c = 0; c < CHANNEL_COUNT; c += 4)
- {
- //K.Get(q/3, q%3, c, k, WBKdata, WBK.dataLength);
- v += X.Get(n, y+q/3, x+q%3, c+0, Xdata, X.dataLength) * K2cache[c+0][tk][q];
- v += X.Get(n, y+q/3, x+q%3, c+1, Xdata, X.dataLength) * K2cache[c+1][tk][q];
- v += X.Get(n, y+q/3, x+q%3, c+2, Xdata, X.dataLength) * K2cache[c+2][tk][q];
- v += X.Get(n, y+q/3, x+q%3, c+3, Xdata, X.dataLength) * K2cache[c+3][tk][q];
- }
- O.Set(n, y, x, k, v, Odata, O.dataLength);
- }
- }
-}
-
-[numthreads(KERNEL_COUNT,CHANNEL_COUNT,1)]
-void Conv2D_Kernel3x3_32Channel(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- uint tk = groupThreadID.x;
- uint k = KERNEL_COUNT*groupID.x + tk;
- uint n = CHANNEL_COUNT*groupID.y + groupThreadID.y;
-
- for (uint q = 0; q < 9; ++q)
- {
- uint tc = n % CHANNEL_COUNT;
- K2cache[tc][tk][q] = K.Get(q/3, q%3, tc, k, WBKdata, WBK.dataLength);
- }
- GroupMemoryBarrierWithGroupSync();
-
- for (uint y = 0; y < X.height - _Border; ++y)
- {
- for (uint x = 0; x < X.width - _Border; ++x)
- {
- float v = B.Get(0, 0, k, 0, WBKdata, WBK.dataLength);
- for (uint dy = 0; dy < 3; ++dy)
- {
- if (y+dy < _Offset) continue;
- if (y+dy-_Offset >= X.height) continue;
- for (uint dx = 0; dx < 3; ++dx)
- {
- if (x+dx < _Offset) continue;
- if (x+dx-_Offset >= X.width) continue;
-
- uint q = dy*3+dx;
- for (uint c = 0; c < CHANNEL_COUNT; c += 4)
- {
- //K.Get(q/3, q%3, c, k, WBKdata, WBK.dataLength);
- v += X.Get(n, y+dy-_Offset, x+dx-_Offset, c+0, Xdata, X.dataLength) * K2cache[c+0][tk][q];
- v += X.Get(n, y+dy-_Offset, x+dx-_Offset, c+1, Xdata, X.dataLength) * K2cache[c+1][tk][q];
- v += X.Get(n, y+dy-_Offset, x+dx-_Offset, c+2, Xdata, X.dataLength) * K2cache[c+2][tk][q];
- v += X.Get(n, y+dy-_Offset, x+dx-_Offset, c+3, Xdata, X.dataLength) * K2cache[c+3][tk][q];
- }
- }
- }
- O.Set(n, y, x, k, v, Odata, O.dataLength);
- }
- }
-}
-
-groupshared float X2cache[2][CHANNEL_COUNT][KERNEL_COUNT];
-[numthreads(KERNEL_COUNT,CHANNEL_COUNT,1)]
-void Conv2D_Kernel3x3_32Channel_(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- uint tk = groupThreadID.x;
- uint tn = groupThreadID.y;
- uint k = KERNEL_COUNT*groupID.x + tk;
- uint n = CHANNEL_COUNT*groupID.y + tn;
-
- for (uint q = 0; q < 9; ++q)
- {
- uint tc = n % CHANNEL_COUNT;
- K2cache[q][tc][tk] = K.Get(q/3, q%3, tc, k, WBKdata, WBK.dataLength);
- }
- //GroupMemoryBarrierWithGroupSync(); <-- unnecessary, we have one inside the loop
-
- for (uint y = 0; y < X.height - _FilterSize*2; ++y)
- {
- for (uint x = 0; x < X.width - _FilterSize*2; ++x)
- {
- float v = B.Get(0, 0, k, 0, WBKdata, WBK.dataLength);
- for (uint cBlock = 0; cBlock < CHANNEL_COUNT; cBlock += KERNEL_COUNT)
- {
- for (uint q = 0; q < 9; ++q)
- {
- uint tc = k % KERNEL_COUNT;
- X2cache[q%2][tn][tc] = X.Get(n, y+q/3, x+q%3, cBlock+tc, Xdata, X.dataLength);
- GroupMemoryBarrierWithGroupSync();
-
- for (tc = 0; tc < KERNEL_COUNT; ++tc)
- v += X2cache[q%2][tn][tc] * K2cache[q][cBlock+tc][tk];
- }
- }
- O.Set(n, y, x, k, v, Odata, O.dataLength);
- }
- }
-}
-
-// 16x8 => 0.101
-// 32x4 => 0.114
-// 8x8 => 0.131
-
-#define PARAM_X 16
-#define PARAM_Y 8
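-// Direct (uncached) 3x3 convolution: one thread per (kernel k, batch n) pair sweeps
-// the whole output, with the channel loop unrolled in steps of 4.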
-[numthreads(PARAM_X, PARAM_Y, 1)]
-void Conv2D_Kernel3x3_Kmod16_Cmod4_KN(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- uint k = PARAM_X * groupID.x + groupThreadID.x;
- uint n = PARAM_Y * groupID.y + groupThreadID.y;
-
- for (uint y = 0; y < X.height - _Border; ++y)
- {
- for (uint x = 0; x < X.width - _Border; ++x)
- {
- float v = B.Get(0, 0, k, 0, WBKdata, WBK.dataLength);
- for (uint dy = 0; dy < 3; ++dy)
- {
- if (y+dy < _Offset) continue;
- if (y+dy-_Offset >= X.height) continue;
- for (uint dx = 0; dx < 3; ++dx)
- {
- if (x+dx < _Offset) continue;
- if (x+dx-_Offset >= X.width) continue;
-
- for (uint c = 0; c < X.channels; c += 4)
- {
- v += X.Get(n, y+dy-_Offset, x+dx-_Offset, c+0, Xdata, X.dataLength) * K.Get(dy, dx, c+0, k, WBKdata, WBK.dataLength);
- v += X.Get(n, y+dy-_Offset, x+dx-_Offset, c+1, Xdata, X.dataLength) * K.Get(dy, dx, c+1, k, WBKdata, WBK.dataLength);
- v += X.Get(n, y+dy-_Offset, x+dx-_Offset, c+2, Xdata, X.dataLength) * K.Get(dy, dx, c+2, k, WBKdata, WBK.dataLength);
- v += X.Get(n, y+dy-_Offset, x+dx-_Offset, c+3, Xdata, X.dataLength) * K.Get(dy, dx, c+3, k, WBKdata, WBK.dataLength);
- }
- }
- }
- O.Set(n, y, x, k, v, Odata, O.dataLength);
- }
- }
-}
-#undef PARAM_X
-#undef PARAM_Y
-#define PARAM_X 16
-#define PARAM_Y 8
-
-// 16x8 => 0.096
-// 8x8 => 0.117
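-// Same as the kernel above, except the y thread index enumerates a flattened
-// (n, y, x) output coordinate rather than just the batch index n.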
-[numthreads(PARAM_X, PARAM_Y, 1)]
-void Conv2D_Kernel3x3_Kmod16_Cmod4_KNyx(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- uint k = PARAM_X * groupID.x + groupThreadID.x;
- uint nyx = PARAM_Y * groupID.y + groupThreadID.y;
-
- uint width = X.width - _Border;
- uint height = X.height - _Border;
-
- uint x = nyx % width;
- uint ny = nyx / width;
- uint y = ny % height;
- uint n = ny / height;
-
- //for (uint y = 0; y < X.height - _Border; ++y)
- //{
- // for (uint x = 0; x < X.width - _Border; ++x)
- // {
- float v = B.Get(0, 0, k, 0, WBKdata, WBK.dataLength);
- for (uint dy = 0; dy < 3; ++dy)
- {
- if (y+dy < _Offset) continue;
- if (y+dy-_Offset >= X.height) continue;
- for (uint dx = 0; dx < 3; ++dx)
- {
- if (x+dx < _Offset) continue;
- if (x+dx-_Offset >= X.width) continue;
-
- for (uint c = 0; c < X.channels; c += 4)
- {
- v += X.Get(n, y+dy-_Offset, x+dx-_Offset, c+0, Xdata, X.dataLength) * K.Get(dy, dx, c+0, k, WBKdata, WBK.dataLength);
- v += X.Get(n, y+dy-_Offset, x+dx-_Offset, c+1, Xdata, X.dataLength) * K.Get(dy, dx, c+1, k, WBKdata, WBK.dataLength);
- v += X.Get(n, y+dy-_Offset, x+dx-_Offset, c+2, Xdata, X.dataLength) * K.Get(dy, dx, c+2, k, WBKdata, WBK.dataLength);
- v += X.Get(n, y+dy-_Offset, x+dx-_Offset, c+3, Xdata, X.dataLength) * K.Get(dy, dx, c+3, k, WBKdata, WBK.dataLength);
- }
- }
- }
- O.Set(n, y, x, k, v, Odata, O.dataLength);
- // }
- //}
-}
-
-#undef CTILE
-#define CTILE 16
-
-#undef PARAM_X
-#undef PARAM_Y
-#define PARAM_X CTILE
-#define PARAM_Y CTILE
-
-#define TYPE float
-
-groupshared TYPE Conv_XcacheT[CTILE][CTILE];
-groupshared TYPE Conv_KcacheT[CTILE][CTILE];
-
-[numthreads(PARAM_X, PARAM_Y, 1)]
-void Conv2D_Kernel3x3_Cache_KCmod16_KNyx_(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- #define X_ Conv_XcacheT
- #define K_ Conv_KcacheT
-
- uint gx = groupThreadID.x;
- uint gy = groupThreadID.y;
-
- uint k = PARAM_X * groupID.x + groupThreadID.x;
- uint nyx = PARAM_Y * groupID.y + groupThreadID.y;
-
- uint width = X.width - _Border;
- uint height = X.height - _Border;
-
- uint x = nyx % width;
- uint ny = nyx / width;
- uint y = ny % height;
- uint n = ny / height;
-
- //half v = B.Get(0, 0, k, 0, WBKdata, WBK.dataLength);
- TYPE v = WBKdata[k + B.offset];
- for (uint dy = 0; dy < 3; ++dy)
- {
- bool mask = true;
-
- if (y+dy < _Offset) mask = false;
- if (y+dy-_Offset >= X.height) mask = false;
-
- for (uint dx = 0; dx < 3; ++dx)
- {
- if (x+dx < _Offset) mask = false;
- if (x+dx-_Offset >= X.width) mask = false;
-
- int Xi = (( n * X.height +
- y+dy-_Offset ) * X.width +
- x+dx-_Offset ) * X.channels +
- gx;
-
- int Ki = (( dy * K.height +
- dx ) * K.width +
- /*m*CTILE +*/ gy ) * K.channels +
- k + K.offset;
-
- for (uint m = 0; m < X.channels/CTILE; ++m)
- {
- if (mask)
- {
- //X_[gy][gx] = X.Get(n, y+dy-_Offset, x+dx-_Offset, m*CTILE + gx, Xdata);
- X_[gy][gx] = Xdata[Xi + m*CTILE];
- }
- else
- {
- X_[gy][gx] = 0;
- }
- //K_[gy][gx] = K.Get(dy, dx, m*CTILE + gy, k, WBKdata, WBK.dataLength);
- //K_[gy][gx] = WBKdata[((
- // dy * K.height +
- // dx ) * K.width +
- // m*CTILE + gy ) * K.channels +
- // k + K.offset];
- //K_[gy][gx] = WBKdata[Ki + m*CTILE * K.channels];
- K_[gy][gx] = WBKdata[Ki + m*CTILE * K.channels];
- GroupMemoryBarrierWithGroupSync();
-
- for (uint i = 0; i < CTILE;)
- {
- /*
- // can unroll up to CTILE
- half4 x4 = ((half4[CTILE][CTILE/4])(X_))[gy][i];
- half4 k4 = ((half4[CTILE][CTILE/4])(K_))[gx][i];
-
- v += dot(x4, k4); ++i;
- v += dot(x4, k4); ++i;
- */
-
- v += X_[gy][i] * K_[i][gx]; ++i;
- v += X_[gy][i] * K_[i][gx]; ++i;
- v += X_[gy][i] * K_[i][gx]; ++i;
- v += X_[gy][i] * K_[i][gx]; ++i;
- v += X_[gy][i] * K_[i][gx]; ++i;
- v += X_[gy][i] * K_[i][gx]; ++i;
- v += X_[gy][i] * K_[i][gx]; ++i;
- v += X_[gy][i] * K_[i][gx]; ++i;
-
- }
- }
- }
- }
- //O.Set(n, y, x, k, v, Odata, O.dataLength);
- Odata[((
- n * O.height +
- y ) * O.width +
- x ) * O.channels +
- k] = v;
-
- #undef X_
- #undef K_
-}
-
-#undef CTILE
-#define CTILE 16
-groupshared float Conv_XcacheA[4][CTILE][CTILE];
-groupshared float Conv_Kcache0[CTILE][CTILE];
-groupshared float Conv_Kcache1[CTILE][CTILE];
-groupshared float Conv_Kcache2[CTILE][CTILE];
-groupshared float Conv_Kcache3[CTILE][CTILE];
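-// Variant of Conv2D_Kernel3x3_Cache_KCmod32_KNyx below, using four separate K caches
-// (K_0..K_3) instead of a single four-plane array.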
-[numthreads(CTILE, CTILE, 1)]
-void Conv2D_Kernel3x3_Cache_KCmod32_KNyx____(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- #define X_ Conv_XcacheA
- #define K_0 Conv_Kcache0
- #define K_1 Conv_Kcache1
- #define K_2 Conv_Kcache2
- #define K_3 Conv_Kcache3
-
-
-
- uint gx = groupThreadID.x;
- uint gy = groupThreadID.y;
-
- uint k = CTILE * groupID.x + groupThreadID.x;
- uint nyx = CTILE * groupID.y + groupThreadID.y;
-
- uint width = X.width - _Border;
- uint height = X.height - _Border;
-
- uint x = nyx % width;
- uint ny = nyx / width;
- uint y = ny % height;
- uint n = ny / height;
-
- float b0 = B.Get(0, 0, k*2+0, 0, WBKdata, WBK.dataLength);
- float b1 = B.Get(0, 0, k*2+1, 0, WBKdata, WBK.dataLength);
- float4 v = float4(b0, b1,
- b0, b1);
-
- for (uint dy = 0; dy < 3; ++dy)
- {
- bool mask = true;
-
- if (y+dy < _Offset) mask = false;
- if (y+dy-_Offset >= X.height) mask = false;
-
- for (uint dx = 0; dx < 3; ++dx)
- {
- if (x+dx < _Offset) mask = false;
- if (x+dx-_Offset >= X.width) mask = false;
-
- for (uint m = 0; m < X.channels/(CTILE*2); ++m)
- {
- float x0 = 0;
- float x1 = 0;
- float x2 = 0;
- float x3 = 0;
-
- if (mask)
- {
- x0 = X.Get(n*2+0, y+dy-_Offset, x+dx-_Offset, (m*CTILE + gx)*2+0, Xdata);
- x1 = X.Get(n*2+0, y+dy-_Offset, x+dx-_Offset, (m*CTILE + gx)*2+1, Xdata);
- x2 = X.Get(n*2+1, y+dy-_Offset, x+dx-_Offset, (m*CTILE + gx)*2+0, Xdata);
- x3 = X.Get(n*2+1, y+dy-_Offset, x+dx-_Offset, (m*CTILE + gx)*2+1, Xdata);
- }
-
- float k0 = K.Get(dy, dx, (m*CTILE + gy)*2+0, k*2+0, WBKdata, WBK.dataLength);
- float k1 = K.Get(dy, dx, (m*CTILE + gy)*2+0, k*2+1, WBKdata, WBK.dataLength);
- float k2 = K.Get(dy, dx, (m*CTILE + gy)*2+1, k*2+0, WBKdata, WBK.dataLength);
- float k3 = K.Get(dy, dx, (m*CTILE + gy)*2+1, k*2+1, WBKdata, WBK.dataLength);
-
- //X_[gy][gx] = float4(x0, x1,
- // x2, x3);
- //K_[gy][gx] = float4(k0, k1,
- // k2, k3);
- X_[0][gy][gx] = x0;
- X_[1][gy][gx] = x1;
- X_[2][gy][gx] = x2;
- X_[3][gy][gx] = x3;
-
- K_0[gy][gx] = k0;
- K_1[gy][gx] = k1;
- K_2[gy][gx] = k2;
- K_3[gy][gx] = k3;
-
- GroupMemoryBarrierWithGroupSync();
-
- [unroll]
- for (uint i = 0; i < CTILE; ++i)
- {
- float4 x = //X_[gy][i];
- float4( X_[0][gy][i],
- X_[1][gy][i],
- X_[2][gy][i],
- X_[3][gy][i]);
- //float4 k = //K_[i][gx];
- // float4( K_0[i][gx],
- // K_1[i][gx],
- // K_2[i][gx],
- // K_3[i][gx]);
- k0 = K_0[i][gx];
- k1 = K_1[i][gx];
- k2 = K_2[i][gx];
- k3 = K_3[i][gx];
-
- v.x = mad(k0, x.x, v.x);
- v.x = mad(k2, x.y, v.x);
-
- v.y = mad(k1, x.x, v.y);
- v.y = mad(k3, x.y, v.y);
-
- v.z = mad(k0, x.z, v.z);
- v.z = mad(k2, x.w, v.z);
-
- v.w = mad(k1, x.z, v.w);
- v.w = mad(k3, x.w, v.w);
-
- //v.x += k.x*x.x + k.z*x.y;
- //v.y += k.y*x.x + k.w*x.y;
- //v.z += k.x*x.z + k.z*x.w;
- //v.w += k.y*x.z + k.w*x.w;
- }
-
- GroupMemoryBarrierWithGroupSync();
- }
- }
- }
-
- //Odata[nyx * O.channels + k] = v;
-
- /*Odata[((
- n * O.height +
- y ) * O.width +
- x ) * O.channels +
- k] = v;
- */
-
- O.Set(n*2+0, y, x, k*2+0, v.x, Odata);
- O.Set(n*2+0, y, x, k*2+1, v.y, Odata);
- O.Set(n*2+1, y, x, k*2+0, v.z, Odata);
- O.Set(n*2+1, y, x, k*2+1, v.w, Odata);
-
- #undef X_
- #undef K_
-}
-
-
-#undef CTILE
-#define CTILE 16
-groupshared float Conv_Xcache[4][CTILE][CTILE];
-groupshared float Conv_Kcache[4][CTILE][CTILE];
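-// Shared-memory tiled matrix-multiply style convolution: a CTILE x CTILE thread group
-// stages X values and K weights into groupshared memory; each thread accumulates a
-// 2 (batch) x 2 (kernel) block of outputs in the float4 v.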
-[numthreads(CTILE, CTILE, 1)]
-void Conv2D_Kernel3x3_Cache_KCmod32_KNyx(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- #define X_ Conv_Xcache
- #define K_ Conv_Kcache
-
- uint gx = groupThreadID.x;
- uint gy = groupThreadID.y;
-
- uint k = CTILE * groupID.x + groupThreadID.x;
- uint nyx = CTILE * groupID.y + groupThreadID.y;
-
- uint width = X.width - _Border;
- uint height = X.height - _Border;
-
- uint x = nyx % width;
- uint ny = nyx / width;
- uint y = ny % height;
- uint n = ny / height;
-
- float b0 = B.Get(0, 0, k*2+0, 0, WBKdata, WBK.dataLength);
- float b1 = B.Get(0, 0, k*2+1, 0, WBKdata, WBK.dataLength);
- float4 v = float4(b0, b1,
- b0, b1);
-
- for (uint dy = 0; dy < 3; ++dy)
- {
- bool mask = true;
-
- if (y+dy < _Offset) mask = false;
- if (y+dy-_Offset >= X.height) mask = false;
-
- for (uint dx = 0; dx < 3; ++dx)
- {
- if (x+dx < _Offset) mask = false;
- if (x+dx-_Offset >= X.width) mask = false;
-
- for (uint m = 0; m < X.channels/(CTILE*2); ++m)
- {
- float x0 = 0;
- float x1 = 0;
- float x2 = 0;
- float x3 = 0;
-
- if (mask)
- {
- x0 = X.Get(n*2+0, y+dy-_Offset, x+dx-_Offset, (m*CTILE + gx)*2+0, Xdata);
- x1 = X.Get(n*2+0, y+dy-_Offset, x+dx-_Offset, (m*CTILE + gx)*2+1, Xdata);
- x2 = X.Get(n*2+1, y+dy-_Offset, x+dx-_Offset, (m*CTILE + gx)*2+0, Xdata);
- x3 = X.Get(n*2+1, y+dy-_Offset, x+dx-_Offset, (m*CTILE + gx)*2+1, Xdata);
- }
-
- float k0 = K.Get(dy, dx, (m*CTILE + gy)*2+0, k*2+0, WBKdata, WBK.dataLength);
- float k1 = K.Get(dy, dx, (m*CTILE + gy)*2+0, k*2+1, WBKdata, WBK.dataLength);
- float k2 = K.Get(dy, dx, (m*CTILE + gy)*2+1, k*2+0, WBKdata, WBK.dataLength);
- float k3 = K.Get(dy, dx, (m*CTILE + gy)*2+1, k*2+1, WBKdata, WBK.dataLength);
-
- //X_[gy][gx] = float4(x0, x1,
- // x2, x3);
- //K_[gy][gx] = float4(k0, k1,
- // k2, k3);
- X_[0][gy][gx] = x0;
- X_[1][gy][gx] = x1;
- X_[2][gy][gx] = x2;
- X_[3][gy][gx] = x3;
-
- K_[0][gy][gx] = k0;
- K_[1][gy][gx] = k1;
- K_[2][gy][gx] = k2;
- K_[3][gy][gx] = k3;
-
- GroupMemoryBarrierWithGroupSync();
-
- [unroll]
- for (uint i = 0; i < CTILE; ++i)
- {
- float4 x = //X_[gy][i];
- float4( X_[0][gy][i],
- X_[1][gy][i],
- X_[2][gy][i],
- X_[3][gy][i]);
- float4 k = //K_[i][gx];
- float4( K_[0][i][gx],
- K_[1][i][gx],
- K_[2][i][gx],
- K_[3][i][gx]);
-
- v.x = mad(k.x, x.x, v.x);
- v.x = mad(k.z, x.y, v.x);
-
- v.y = mad(k.y, x.x, v.y);
- v.y = mad(k.w, x.y, v.y);
-
- v.z = mad(k.x, x.z, v.z);
- v.z = mad(k.z, x.w, v.z);
-
- v.w = mad(k.y, x.z, v.w);
- v.w = mad(k.w, x.w, v.w);
-
- //v.x += k.x*x.x + k.z*x.y;
- //v.y += k.y*x.x + k.w*x.y;
- //v.z += k.x*x.z + k.z*x.w;
- //v.w += k.y*x.z + k.w*x.w;
- }
-
- GroupMemoryBarrierWithGroupSync();
- }
- }
- }
-
- //Odata[nyx * O.channels + k] = v;
-
- /*Odata[((
- n * O.height +
- y ) * O.width +
- x ) * O.channels +
- k] = v;
- */
-
- O.Set(n*2+0, y, x, k*2+0, v.x, Odata);
- O.Set(n*2+0, y, x, k*2+1, v.y, Odata);
- O.Set(n*2+1, y, x, k*2+0, v.z, Odata);
- O.Set(n*2+1, y, x, k*2+1, v.w, Odata);
-
- #undef X_
- #undef K_
-}
-
-#if 0 // =====================================================================================================
-
-#undef CTILE
-#define CTILE 16
-#define RTILE 4
-groupshared float Conv_XcacheR[RTILE*RTILE][CTILE*CTILE];
-groupshared float Conv_KcacheR[RTILE*RTILE][CTILE*CTILE];
-[numthreads(CTILE, CTILE, 1)]
-void Conv2D_Kernel3x3_Cache_KCmod64_KNyx(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- #define X_ Conv_XcacheR
- #define K_ Conv_KcacheR
-
- uint gx = groupThreadID.x;
- uint gy = groupThreadID.y;
-
- uint k = CTILE * groupID.x + groupThreadID.x;
- uint nyx = CTILE * groupID.y + groupThreadID.y;
-
- uint width = X.width - _Border;
- uint height = X.height - _Border;
-
- uint x = nyx % width;
- uint ny = nyx / width;
- uint y = ny % height;
- uint n = ny / height;
-
- float v[RTILE][RTILE];
- for (uint xxxx = 0; xxxx < RTILE; ++xxxx)
- {
- float b = B.Get(0, 0, k*RTILE+xxxx, 0, WBKdata, WBK.dataLength);
- for (uint yyyy = 0; yyyy < RTILE; ++yyyy)
- v[yyyy][xxxx] = b;
- }
-
- for (uint dy = 0; dy < 3; ++dy)
- {
- bool mask = true;
-
- if (y+dy < _Offset) mask = false;
- if (y+dy-_Offset >= X.height) mask = false;
-
- for (uint dx = 0; dx < 3; ++dx)
- {
- if (x+dx < _Offset) mask = false;
- if (x+dx-_Offset >= X.width) mask = false;
-
- for (uint m = 0; m < X.channels/(CTILE*RTILE); ++m)
- {
- for (uint yy = 0; yy < RTILE; ++yy)
- for (uint xx = 0; xx < RTILE; ++xx)
- {
- if (mask)
- X_[yy*RTILE+xx][gy*CTILE+gx] = X.Get(n*RTILE+yy, y+dy-_Offset, x+dx-_Offset, (m*CTILE + gx)*RTILE+xx, Xdata);
- else
- X_[yy*RTILE+xx][gy*CTILE+gx] = 0;
- K_[yy*RTILE+xx][gy*CTILE+gx] = K.Get(dy, dx, (m*CTILE + gy)*RTILE+yy, k*RTILE+xx, WBKdata, WBK.dataLength);
- }
-
- GroupMemoryBarrierWithGroupSync();
-
- for (uint ii = 0; ii < CTILE; ++ii)
- {
- float x[RTILE][RTILE];
- float k[RTILE][RTILE];
-
- [unroll]
- for (uint yy = 0; yy < RTILE; ++yy)
- {
- [unroll]
- for (uint xx = 0; xx < RTILE; ++xx)
- {
- x[yy][xx] = X_[yy*RTILE+xx][gy*CTILE+ii];
- k[yy][xx] = K_[yy*RTILE+xx][ii*CTILE+gx];
- }
- }
-
-
- [unroll]
- for (uint yyy = 0; yyy < RTILE; ++yyy)
- {
- [unroll]
- for (uint xxx = 0; xxx < RTILE; ++xxx)
- {
- [unroll]
- for (uint i = 0; i < RTILE; ++i)
- {
- v[yyy][xxx] = mad(x[yyy][i], k[i][xxx], v[yyy][xxx]);
- }
- }
- }
- }
-
- GroupMemoryBarrierWithGroupSync();
- }
- }
- }
-
- for (uint yy = 0; yy < RTILE; ++yy)
- for (uint xx = 0; xx < RTILE; ++xx)
- O.Set(n*RTILE+yy, y, x, k*RTILE+xx, v[yy][xx], Odata);
-
- #undef X_
- #undef K_
-}
-
-#elif 1 // =====================================================================================================
-
-#undef CTILE
-#define CTILE 16
-groupshared float2 Conv_KcacheR[8][CTILE*CTILE];
-groupshared float2 Conv_XcacheR[8][CTILE*CTILE];
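-// KCmod64 variant with float2 staging buffers: each thread accumulates a
-// 4 (batch) x 4 (kernel) block of outputs in v0..v3.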
-[numthreads(CTILE, CTILE, 1)]
-void Conv2D_Kernel3x3_Cache_KCmod64_KNyx(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- #define X_ Conv_XcacheR
- #define K_ Conv_KcacheR
-
- uint gx = groupThreadID.x;
- uint gy = groupThreadID.y;
-
- uint k = CTILE * groupID.x + groupThreadID.x;
- uint nyx = CTILE * groupID.y + groupThreadID.y;
-
- uint width = X.width - _Border;
- uint height = X.height - _Border;
-
- uint x = nyx % width;
- uint ny = nyx / width;
- uint y = ny % height;
- uint n = ny / height;
-
- float b0 = B.Get(0, 0, k*4+0, 0, WBKdata, WBK.dataLength);
- float b1 = B.Get(0, 0, k*4+1, 0, WBKdata, WBK.dataLength);
- float b2 = B.Get(0, 0, k*4+2, 0, WBKdata, WBK.dataLength);
- float b3 = B.Get(0, 0, k*4+3, 0, WBKdata, WBK.dataLength);
-
- float4 v0, v1, v2, v3;
- v0 = v1 = v2 = v3 = float4(b0, b1, b2, b3);
-
- for (uint dy = 0; dy < 3; ++dy)
- {
- bool mask = true;
-
- if (y+dy < _Offset) mask = false;
- if (y+dy-_Offset >= X.height) mask = false;
-
- for (uint dx = 0; dx < 3; ++dx)
- {
- if (x+dx < _Offset) mask = false;
- if (x+dx-_Offset >= X.width) mask = false;
-
- for (uint m = 0; m < X.channels/(CTILE*4); ++m)
- {
- for (uint yy = 0; yy < 4; ++yy)
- for (uint xx = 0; xx < 2; ++xx)
- {
- // 111ms
- if (mask)
- {
- X_[yy*2+xx][gy*CTILE+gx].x = X.Get(n*4+yy, y+dy-_Offset, x+dx-_Offset, (m*CTILE + gx)*4+xx*2+0, Xdata);
- X_[yy*2+xx][gy*CTILE+gx].y = X.Get(n*4+yy, y+dy-_Offset, x+dx-_Offset, (m*CTILE + gx)*4+xx*2+1, Xdata);
- }
- else
- {
- X_[yy*2+xx][gy*CTILE+gx].x = 0;
- X_[yy*2+xx][gy*CTILE+gx].y = 0;
- }
-
- K_[yy*2+xx][gy*CTILE+gx].x = K.Get(dy, dx, (m*CTILE + gy)*4+yy, k*4+xx*2+0, WBKdata, WBK.dataLength);
- K_[yy*2+xx][gy*CTILE+gx].y = K.Get(dy, dx, (m*CTILE + gy)*4+yy, k*4+xx*2+1, WBKdata, WBK.dataLength);
- }
-
- GroupMemoryBarrierWithGroupSync();
-
- for (uint i = 0; i < CTILE; ++i)
- {
- #if 1 // ----------------------------------------------------------
-
- float2 x[8];
- float2 k[8];
-
- // 109ms
- // dcl_temps 29
- for (uint regs = 0; regs < 8; ++regs)
- {
- x[regs] = X_[regs][gy*CTILE+i];
- k[regs] = K_[regs][i*CTILE+gx];
- }
-
- for (uint q = 0; q < 4; ++q)
- {
- float
- k0 = k[q*2+0].x,
- k1 = k[q*2+0].y,
- k2 = k[q*2+1].x,
- k3 = k[q*2+1].y;
- float
- x0 = x[0+q/2].x,
- x1 = x[2+q/2].x,
- x2 = x[4+q/2].x,
- x3 = x[6+q/2].x;
-
- v0.x = mad(x0, k0, v0.x); //--
- v1.x = mad(x1, k0, v1.x);
- v2.x = mad(x2, k0, v2.x);
- v3.x = mad(x3, k0, v3.x);
- v0.y = mad(x0, k1, v0.y); //--
- v1.y = mad(x1, k1, v1.y);
- v2.y = mad(x2, k1, v2.y);
- v3.y = mad(x3, k1, v3.y);
- v0.z = mad(x0, k2, v0.z); //--
- v1.z = mad(x1, k2, v1.z);
- v2.z = mad(x2, k2, v2.z);
- v3.z = mad(x3, k2, v3.z);
- v0.w = mad(x0, k3, v0.w); //--
- v1.w = mad(x1, k3, v1.w);
- v2.w = mad(x2, k3, v2.w);
- v3.w = mad(x3, k3, v3.w);
-
- ++q;
-
- k0 = k[q*2+0].x;
- k1 = k[q*2+0].y;
- k2 = k[q*2+1].x;
- k3 = k[q*2+1].y;
-
- x0 = x[0+q/2].y;
- x1 = x[2+q/2].y;
- x2 = x[4+q/2].y;
- x3 = x[6+q/2].y;
-
- v0.x = mad(x0, k0, v0.x); //--
- v1.x = mad(x1, k0, v1.x);
- v2.x = mad(x2, k0, v2.x);
- v3.x = mad(x3, k0, v3.x);
- v0.y = mad(x0, k1, v0.y); //--
- v1.y = mad(x1, k1, v1.y);
- v2.y = mad(x2, k1, v2.y);
- v3.y = mad(x3, k1, v3.y);
- v0.z = mad(x0, k2, v0.z); //--
- v1.z = mad(x1, k2, v1.z);
- v2.z = mad(x2, k2, v2.z);
- v3.z = mad(x3, k2, v3.z);
- v0.w = mad(x0, k3, v0.w); //--
- v1.w = mad(x1, k3, v1.w);
- v2.w = mad(x2, k3, v2.w);
- v3.w = mad(x3, k3, v3.w);
- }
-
- #endif // ----------------------------------------------------------
- }
-
- GroupMemoryBarrierWithGroupSync();
- }
- }
- }
-
- #if 1 // ----------------------------------------------------------
-
- // 117ms
- O.Set(n*4+0, y, x, k*4+0, v0.x, Odata);
- O.Set(n*4+0, y, x, k*4+1, v0.y, Odata);
- O.Set(n*4+0, y, x, k*4+2, v0.z, Odata);
- O.Set(n*4+0, y, x, k*4+3, v0.w, Odata);
-
- O.Set(n*4+1, y, x, k*4+0, v1.x, Odata);
- O.Set(n*4+1, y, x, k*4+1, v1.y, Odata);
- O.Set(n*4+1, y, x, k*4+2, v1.z, Odata);
- O.Set(n*4+1, y, x, k*4+3, v1.w, Odata);
-
- O.Set(n*4+2, y, x, k*4+0, v2.x, Odata);
- O.Set(n*4+2, y, x, k*4+1, v2.y, Odata);
- O.Set(n*4+2, y, x, k*4+2, v2.z, Odata);
- O.Set(n*4+2, y, x, k*4+3, v2.w, Odata);
-
- O.Set(n*4+3, y, x, k*4+0, v3.x, Odata);
- O.Set(n*4+3, y, x, k*4+1, v3.y, Odata);
- O.Set(n*4+3, y, x, k*4+2, v3.z, Odata);
- O.Set(n*4+3, y, x, k*4+3, v3.w, Odata);
-
- #else // ----------------------------------------------------------
-
- // 118ms
- O.Set(n*4+0, y, x, k*4+0, v0.x, Odata);
- O.Set(n*4+1, y, x, k*4+0, v1.x, Odata);
- O.Set(n*4+2, y, x, k*4+0, v2.x, Odata);
- O.Set(n*4+3, y, x, k*4+0, v3.x, Odata);
-
- O.Set(n*4+0, y, x, k*4+1, v0.y, Odata);
- O.Set(n*4+1, y, x, k*4+1, v1.y, Odata);
- O.Set(n*4+2, y, x, k*4+1, v2.y, Odata);
- O.Set(n*4+3, y, x, k*4+1, v3.y, Odata);
-
- O.Set(n*4+0, y, x, k*4+2, v0.z, Odata);
- O.Set(n*4+1, y, x, k*4+2, v1.z, Odata);
- O.Set(n*4+2, y, x, k*4+2, v2.z, Odata);
- O.Set(n*4+3, y, x, k*4+2, v3.z, Odata);
-
- O.Set(n*4+0, y, x, k*4+3, v0.w, Odata);
- O.Set(n*4+1, y, x, k*4+3, v1.w, Odata);
- O.Set(n*4+2, y, x, k*4+3, v2.w, Odata);
- O.Set(n*4+3, y, x, k*4+3, v3.w, Odata);
-
- #endif // ----------------------------------------------------------
-
-
- #undef X_
- #undef K_
-}
-
-#elif 1 // =====================================================================================================
-
-#undef CTILE
-#define CTILE 16
-groupshared float Conv_KcacheR[16][CTILE*CTILE];
-groupshared float Conv_XcacheR[16][CTILE*CTILE];
-[numthreads(CTILE, CTILE, 1)]
-void Conv2D_Kernel3x3_Cache_KCmod64_KNyx(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- #define X_ Conv_XcacheR
- #define K_ Conv_KcacheR
-
- uint gx = groupThreadID.x;
- uint gy = groupThreadID.y;
-
- uint k = CTILE * groupID.x + groupThreadID.x;
- uint nyx = CTILE * groupID.y + groupThreadID.y;
-
- uint width = X.width - _Border;
- uint height = X.height - _Border;
-
- uint x = nyx % width;
- uint ny = nyx / width;
- uint y = ny % height;
- uint n = ny / height;
-
- float b0 = B.Get(0, 0, k*4+0, 0, WBKdata, WBK.dataLength);
- float b1 = B.Get(0, 0, k*4+1, 0, WBKdata, WBK.dataLength);
- float b2 = B.Get(0, 0, k*4+2, 0, WBKdata, WBK.dataLength);
- float b3 = B.Get(0, 0, k*4+3, 0, WBKdata, WBK.dataLength);
-
- float4 v0, v1, v2, v3;
- v0 = v1 = v2 = v3 = float4(b0, b1, b2, b3);
-
- for (uint dy = 0; dy < 3; ++dy)
- {
- bool mask = true;
-
- if (y+dy < _Offset) mask = false;
- if (y+dy-_Offset >= X.height) mask = false;
-
- for (uint dx = 0; dx < 3; ++dx)
- {
- if (x+dx < _Offset) mask = false;
- if (x+dx-_Offset >= X.width) mask = false;
-
- for (uint m = 0; m < X.channels/(CTILE*4); ++m)
- {
- for (uint yy = 0; yy < 4; ++yy)
- for (uint xx = 0; xx < 4; ++xx)
- {
- #if 1 // ----------------------------------------------------------
-
- // 111ms
- if (mask)
- X_[yy*4+xx][gy*CTILE+gx] = X.Get(n*4+yy, y+dy-_Offset, x+dx-_Offset, (m*CTILE + gx)*4+xx, Xdata);
- else
- X_[yy*4+xx][gy*CTILE+gx] = 0;
- K_[yy*4+xx][gy*CTILE+gx] = K.Get(dy, dx, (m*CTILE + gy)*4+yy, k*4+xx, WBKdata, WBK.dataLength);
-
- #else // ----------------------------------------------------------
-
- // 122ms
- if (mask)
- X_[yy*4+(gx%4)][gy*CTILE+xx*4+(gx/4)] = X.Get(n*4+yy, y+dy-_Offset, x+dx-_Offset, m*CTILE*4 + xx*CTILE + gx, Xdata);
- else
- X_[yy*4+(gx%4)][gy*CTILE+xx*4+(gx/4)] = 0;
- K_[yy*4+(k%4)][gy*CTILE+xx*4+(gx/4)] = K.Get(dy, dx, (m*CTILE + gy)*4+yy, CTILE*groupID.x*4 + xx*CTILE + gx, WBKdata, WBK.dataLength);
-
- #endif // ----------------------------------------------------------
- }
-
- GroupMemoryBarrierWithGroupSync();
-
- for (uint i = 0; i < CTILE; ++i)
- {
-
- #if 0 // ----------------------------------------------------------
-
- float x[16];
- float k[16];
-
- k[0] = K_[0][i*CTILE+gx];
- x[0] = X_[0][gy*CTILE+i];
- x[4] = X_[4][gy*CTILE+i];
- x[8] = X_[8][gy*CTILE+i];
- x[12] = X_[12][gy*CTILE+i];
-
- for (uint q = 0; q < 3; ++q)
- {
- k[q*4+1] = K_[q*4+1][i*CTILE+gx];
- v0.x = mad(x[0*4+q], k[q*4+0], v0.x); //--
- v1.x = mad(x[1*4+q], k[q*4+0], v1.x);
- x[0*4+q+1] = X_[0*4+q+1][gy*CTILE+i];
- v2.x = mad(x[2*4+q], k[q*4+0], v2.x);
- v3.x = mad(x[3*4+q], k[q*4+0], v3.x);
- k[q*4+2] = K_[q*4+2][i*CTILE+gx];
- v0.y = mad(x[0*4+q], k[q*4+1], v0.y); //--
- v1.y = mad(x[1*4+q], k[q*4+1], v1.y);
- x[1*4+q+1] = X_[1*4+q+1][gy*CTILE+i];
- v2.y = mad(x[2*4+q], k[q*4+1], v2.y);
- v3.y = mad(x[3*4+q], k[q*4+1], v3.y);
- k[q*4+3] = K_[q*4+3][i*CTILE+gx];
- v0.z = mad(x[0*4+q], k[q*4+2], v0.z); //--
- v1.z = mad(x[1*4+q], k[q*4+2], v1.z);
- x[2*4+q+1] = X_[2*4+q+1][gy*CTILE+i];
- v2.z = mad(x[2*4+q], k[q*4+2], v2.z);
- v3.z = mad(x[3*4+q], k[q*4+2], v3.z);
- k[q*4+4] = K_[q*4+4][i*CTILE+gx];
- v0.w = mad(x[0*4+q], k[q*4+3], v0.w); //--
- v1.w = mad(x[1*4+q], k[q*4+3], v1.w);
- x[3*4+q+1] = X_[3*4+q+1][gy*CTILE+i];
- v2.w = mad(x[2*4+q], k[q*4+3], v2.w);
- v3.w = mad(x[3*4+q], k[q*4+3], v3.w);
- }
- {
- k[q*4+1] = K_[q*4+1][i*CTILE+gx];
- v0.x = mad(x[0*4+q], k[q*4+0], v0.x); //--
- v1.x = mad(x[1*4+q], k[q*4+0], v1.x);
- v2.x = mad(x[2*4+q], k[q*4+0], v2.x);
- v3.x = mad(x[3*4+q], k[q*4+0], v3.x);
- k[q*4+2] = K_[q*4+2][i*CTILE+gx];
- v0.y = mad(x[0*4+q], k[q*4+1], v0.y); //--
- v1.y = mad(x[1*4+q], k[q*4+1], v1.y);
- v2.y = mad(x[2*4+q], k[q*4+1], v2.y);
- v3.y = mad(x[3*4+q], k[q*4+1], v3.y);
- k[q*4+3] = K_[q*4+3][i*CTILE+gx];
- v0.z = mad(x[0*4+q], k[q*4+2], v0.z); //--
- v1.z = mad(x[1*4+q], k[q*4+2], v1.z);
- v2.z = mad(x[2*4+q], k[q*4+2], v2.z);
- v3.z = mad(x[3*4+q], k[q*4+2], v3.z);
- v0.w = mad(x[0*4+q], k[q*4+3], v0.w); //--
- v1.w = mad(x[1*4+q], k[q*4+3], v1.w);
- v2.w = mad(x[2*4+q], k[q*4+3], v2.w);
- v3.w = mad(x[3*4+q], k[q*4+3], v3.w);
- }
-
- #elif 0 // ----------------------------------------------------------
-
- //float x[4];
- //float k[4];
-
- float k0 = K_[0*4+0][i*CTILE+gx];
- float x0 = X_[0*4+0][gy*CTILE+i];
- float x1 = X_[1*4+0][gy*CTILE+i];
- float x2 = X_[2*4+0][gy*CTILE+i];
- float x3 = X_[3*4+0][gy*CTILE+i];
-
- float k1, k2, k3;
- float x0p, x1p, x2p, x3p;
-
- uint q = 0;
- //for (uint q = 0; q < 4;)
- {
- //x[regs] = X_[regs][gy*CTILE+i];
-
- k1 = K_[q*4+1][i*CTILE+gx];
- v0.x = mad(x0, k0, v0.x); //--
- v1.x = mad(x1, k0, v1.x);
- x0p = X_[0*4+q+1][gy*CTILE+i];
- v2.x = mad(x2, k0, v2.x);
- v3.x = mad(x3, k0, v3.x);
-
- k2 = K_[q*4+2][i*CTILE+gx];
- v0.y = mad(x0, k1, v0.y); //--
- v1.y = mad(x1, k1, v1.y);
- x1p = X_[1*4+q+1][gy*CTILE+i];
- v2.y = mad(x2, k1, v2.y);
- v3.y = mad(x3, k1, v3.y);
-
- k3 = K_[q*4+3][i*CTILE+gx];
- v0.z = mad(x0, k2, v0.z); //--
- v1.z = mad(x1, k2, v1.z);
- x2p = X_[2*4+q+1][gy*CTILE+i];
- v2.z = mad(x2, k2, v2.z);
- v3.z = mad(x3, k2, v3.z);
-
- k0 = K_[q*4+4][i*CTILE+gx];
- v0.w = mad(x0, k3, v0.w); //--
- v1.w = mad(x1, k3, v1.w);
- x3p = X_[3*4+q+1][gy*CTILE+i];
- v2.w = mad(x2, k3, v2.w);
- v3.w = mad(x3, k3, v3.w);
-
- ++q;
-
- k1 = K_[q*4+1][i*CTILE+gx];
- v0.x = mad(x0p, k0, v0.x); //--
- v1.x = mad(x1p, k0, v1.x);
- x0 = X_[0*4+q+1][gy*CTILE+i];
- v2.x = mad(x2p, k0, v2.x);
- v3.x = mad(x3p, k0, v3.x);
-
- k2 = K_[q*4+2][i*CTILE+gx];
- v0.y = mad(x0p, k1, v0.y); //--
- v1.y = mad(x1p, k1, v1.y);
- x1 = X_[1*4+q+1][gy*CTILE+i];
- v2.y = mad(x2p, k1, v2.y);
- v3.y = mad(x3p, k1, v3.y);
-
- k3 = K_[q*4+3][i*CTILE+gx];
- v0.z = mad(x0p, k2, v0.z); //--
- v1.z = mad(x1p, k2, v1.z);
- x2 = X_[2*4+q+1][gy*CTILE+i];
- v2.z = mad(x2p, k2, v2.z);
- v3.z = mad(x3p, k2, v3.z);
-
- k0 = K_[q*4+4][i*CTILE+gx];
- v0.w = mad(x0p, k3, v0.w); //--
- v1.w = mad(x1p, k3, v1.w);
- x3 = X_[3*4+q+1][gy*CTILE+i];
- v2.w = mad(x2p, k3, v2.w);
- v3.w = mad(x3p, k3, v3.w);
-
- ++q;
-
- k1 = K_[q*4+1][i*CTILE+gx];
- v0.x = mad(x0, k0, v0.x); //--
- v1.x = mad(x1, k0, v1.x);
- x0p = X_[0*4+q+1][gy*CTILE+i];
- v2.x = mad(x2, k0, v2.x);
- v3.x = mad(x3, k0, v3.x);
-
- k2 = K_[q*4+2][i*CTILE+gx];
- v0.y = mad(x0, k1, v0.y); //--
- v1.y = mad(x1, k1, v1.y);
- x1p = X_[1*4+q+1][gy*CTILE+i];
- v2.y = mad(x2, k1, v2.y);
- v3.y = mad(x3, k1, v3.y);
-
- k3 = K_[q*4+3][i*CTILE+gx];
- v0.z = mad(x0, k2, v0.z); //--
- v1.z = mad(x1, k2, v1.z);
- x2p = X_[2*4+q+1][gy*CTILE+i];
- v2.z = mad(x2, k2, v2.z);
- v3.z = mad(x3, k2, v3.z);
-
- k0 = K_[q*4+4][i*CTILE+gx];
- v0.w = mad(x0, k3, v0.w); //--
- v1.w = mad(x1, k3, v1.w);
- x3p = X_[3*4+q+1][gy*CTILE+i];
- v2.w = mad(x2, k3, v2.w);
- v3.w = mad(x3, k3, v3.w);
-
- ++q;
-
- k1 = K_[q*4+1][i*CTILE+gx];
- v0.x = mad(x0p, k0, v0.x); //--
- v1.x = mad(x1p, k0, v1.x);
- //x0p = X_[0*4+q][gy*CTILE+i];
- v2.x = mad(x2p, k0, v2.x);
- v3.x = mad(x3p, k0, v3.x);
-
- k2 = K_[q*4+2][i*CTILE+gx];
- v0.y = mad(x0p, k1, v0.y); //--
- v1.y = mad(x1p, k1, v1.y);
- //x1p = X_[1*4+q][gy*CTILE+i];
- v2.y = mad(x2p, k1, v2.y);
- v3.y = mad(x3p, k1, v3.y);
-
- k3 = K_[q*4+3][i*CTILE+gx];
- v0.z = mad(x0p, k2, v0.z); //--
- v1.z = mad(x1p, k2, v1.z);
- //x2p = X_[2*4+q][gy*CTILE+i];
- v2.z = mad(x2p, k2, v2.z);
- v3.z = mad(x3p, k2, v3.z);
-
- //k0 = K_[(q+1)*4][i*CTILE+gx];
- v0.w = mad(x0p, k3, v0.w); //--
- v1.w = mad(x1p, k3, v1.w);
- //x3p = X_[3*4+q][gy*CTILE+i];
- v2.w = mad(x2p, k3, v2.w);
- v3.w = mad(x3p, k3, v3.w);
-
- ++q;
- }
-
-
- #elif 1 // ----------------------------------------------------------
-
- float x[16];
- float k[16];
-
- // 109ms
- // dcl_temps 29
- for (uint regs = 0; regs < 16; ++regs)
- {
- x[regs] = X_[regs][gy*CTILE+i];
- k[regs] = K_[regs][i*CTILE+gx];
- }
-
- for (uint q = 0; q < 4; ++q)
- {
- v0.x = mad(x[0*4+q], k[q*4+0], v0.x); //--
- v1.x = mad(x[1*4+q], k[q*4+0], v1.x);
- v2.x = mad(x[2*4+q], k[q*4+0], v2.x);
- v3.x = mad(x[3*4+q], k[q*4+0], v3.x);
- v0.y = mad(x[0*4+q], k[q*4+1], v0.y); //--
- v1.y = mad(x[1*4+q], k[q*4+1], v1.y);
- v2.y = mad(x[2*4+q], k[q*4+1], v2.y);
- v3.y = mad(x[3*4+q], k[q*4+1], v3.y);
- v0.z = mad(x[0*4+q], k[q*4+2], v0.z); //--
- v1.z = mad(x[1*4+q], k[q*4+2], v1.z);
- v2.z = mad(x[2*4+q], k[q*4+2], v2.z);
- v3.z = mad(x[3*4+q], k[q*4+2], v3.z);
- v0.w = mad(x[0*4+q], k[q*4+3], v0.w); //--
- v1.w = mad(x[1*4+q], k[q*4+3], v1.w);
- v2.w = mad(x[2*4+q], k[q*4+3], v2.w);
- v3.w = mad(x[3*4+q], k[q*4+3], v3.w);
- }
-
- #elif 1 // ----------------------------------------------------------
-
- // 111ms
- // dcl_temps 34
- [unroll]
- for (uint regs = 0; regs < 16; ++regs)
- {
- x[regs] = X_[regs][gy*CTILE+i];
- k[regs] = K_[regs][i*CTILE+gx];
- }
- v0.x = mad(x[0*4+0], k[0*4+0], v0.x); //--
- v1.x = mad(x[1*4+0], k[0*4+0], v1.x);
- v2.x = mad(x[2*4+0], k[0*4+0], v2.x);
- v3.x = mad(x[3*4+0], k[0*4+0], v3.x);
- v0.y = mad(x[0*4+0], k[0*4+1], v0.y); //--
- v1.y = mad(x[1*4+0], k[0*4+1], v1.y);
- v2.y = mad(x[2*4+0], k[0*4+1], v2.y);
- v3.y = mad(x[3*4+0], k[0*4+1], v3.y);
- v0.z = mad(x[0*4+0], k[0*4+2], v0.z); //--
- v1.z = mad(x[1*4+0], k[0*4+2], v1.z);
- v2.z = mad(x[2*4+0], k[0*4+2], v2.z);
- v3.z = mad(x[3*4+0], k[0*4+2], v3.z);
- v0.w = mad(x[0*4+0], k[0*4+3], v0.w); //--
- v1.w = mad(x[1*4+0], k[0*4+3], v1.w);
- v2.w = mad(x[2*4+0], k[0*4+3], v2.w);
- v3.w = mad(x[3*4+0], k[0*4+3], v3.w);
-
- v0.x = mad(x[0*4+1], k[1*4+0], v0.x); //--
- v1.x = mad(x[1*4+1], k[1*4+0], v1.x);
- v2.x = mad(x[2*4+1], k[1*4+0], v2.x);
- v3.x = mad(x[3*4+1], k[1*4+0], v3.x);
- v0.y = mad(x[0*4+1], k[1*4+1], v0.y); //--
- v1.y = mad(x[1*4+1], k[1*4+1], v1.y);
- v2.y = mad(x[2*4+1], k[1*4+1], v2.y);
- v3.y = mad(x[3*4+1], k[1*4+1], v3.y);
- v0.z = mad(x[0*4+1], k[1*4+2], v0.z); //--
- v1.z = mad(x[1*4+1], k[1*4+2], v1.z);
- v2.z = mad(x[2*4+1], k[1*4+2], v2.z);
- v3.z = mad(x[3*4+1], k[1*4+2], v3.z);
- v0.w = mad(x[0*4+1], k[1*4+3], v0.w); //--
- v1.w = mad(x[1*4+1], k[1*4+3], v1.w);
- v2.w = mad(x[2*4+1], k[1*4+3], v2.w);
- v3.w = mad(x[3*4+1], k[1*4+3], v3.w);
-
- v0.x = mad(x[0*4+2], k[2*4+0], v0.x); //--
- v1.x = mad(x[1*4+2], k[2*4+0], v1.x);
- v2.x = mad(x[2*4+2], k[2*4+0], v2.x);
- v3.x = mad(x[3*4+2], k[2*4+0], v3.x);
- v0.y = mad(x[0*4+2], k[2*4+1], v0.y); //--
- v1.y = mad(x[1*4+2], k[2*4+1], v1.y);
- v2.y = mad(x[2*4+2], k[2*4+1], v2.y);
- v3.y = mad(x[3*4+2], k[2*4+1], v3.y);
- v0.z = mad(x[0*4+2], k[2*4+2], v0.z); //--
- v1.z = mad(x[1*4+2], k[2*4+2], v1.z);
- v2.z = mad(x[2*4+2], k[2*4+2], v2.z);
- v3.z = mad(x[3*4+2], k[2*4+2], v3.z);
- v0.w = mad(x[0*4+2], k[2*4+3], v0.w); //--
- v1.w = mad(x[1*4+2], k[2*4+3], v1.w);
- v2.w = mad(x[2*4+2], k[2*4+3], v2.w);
- v3.w = mad(x[3*4+2], k[2*4+3], v3.w);
-
- v0.x = mad(x[0*4+3], k[3*4+0], v0.x); //--
- v1.x = mad(x[1*4+3], k[3*4+0], v1.x);
- v2.x = mad(x[2*4+3], k[3*4+0], v2.x);
- v3.x = mad(x[3*4+3], k[3*4+0], v3.x);
- v0.y = mad(x[0*4+3], k[3*4+1], v0.y); //--
- v1.y = mad(x[1*4+3], k[3*4+1], v1.y);
- v2.y = mad(x[2*4+3], k[3*4+1], v2.y);
- v3.y = mad(x[3*4+3], k[3*4+1], v3.y);
- v0.z = mad(x[0*4+3], k[3*4+2], v0.z); //--
- v1.z = mad(x[1*4+3], k[3*4+2], v1.z);
- v2.z = mad(x[2*4+3], k[3*4+2], v2.z);
- v3.z = mad(x[3*4+3], k[3*4+2], v3.z);
- v0.w = mad(x[0*4+3], k[3*4+3], v0.w); //--
- v1.w = mad(x[1*4+3], k[3*4+3], v1.w);
- v2.w = mad(x[2*4+3], k[3*4+3], v2.w);
- v3.w = mad(x[3*4+3], k[3*4+3], v3.w);
-
- #else // ----------------------------------------------------------
-
- // 115 ms, reg dependencies
- // dcl_temps 32
- [unroll]
- for (uint regs = 0; regs < 16; ++regs)
- {
- x[regs] = X_[regs][gy*CTILE+i];
- k[regs] = K_[regs][i*CTILE+gx];
- }
-
- v0.x = mad(x[0*4+0], k[0*4+0], v0.x); //--
- v0.x = mad(x[0*4+1], k[1*4+0], v0.x);
- v0.x = mad(x[0*4+2], k[2*4+0], v0.x);
- v0.x = mad(x[0*4+3], k[3*4+0], v0.x);
- v0.y = mad(x[0*4+0], k[0*4+1], v0.y); //--
- v0.y = mad(x[0*4+1], k[1*4+1], v0.y);
- v0.y = mad(x[0*4+2], k[2*4+1], v0.y);
- v0.y = mad(x[0*4+3], k[3*4+1], v0.y);
- v0.z = mad(x[0*4+0], k[0*4+2], v0.z); //--
- v0.z = mad(x[0*4+1], k[1*4+2], v0.z);
- v0.z = mad(x[0*4+2], k[2*4+2], v0.z);
- v0.z = mad(x[0*4+3], k[3*4+2], v0.z);
- v0.w = mad(x[0*4+0], k[0*4+3], v0.w); //--
- v0.w = mad(x[0*4+1], k[1*4+3], v0.w);
- v0.w = mad(x[0*4+2], k[2*4+3], v0.w);
- v0.w = mad(x[0*4+3], k[3*4+3], v0.w);
-
- v1.x = mad(x[1*4+0], k[0*4+0], v1.x); //--
- v1.x = mad(x[1*4+1], k[1*4+0], v1.x);
- v1.x = mad(x[1*4+2], k[2*4+0], v1.x);
- v1.x = mad(x[1*4+3], k[3*4+0], v1.x);
- v1.y = mad(x[1*4+0], k[0*4+1], v1.y); //--
- v1.y = mad(x[1*4+1], k[1*4+1], v1.y);
- v1.y = mad(x[1*4+2], k[2*4+1], v1.y);
- v1.y = mad(x[1*4+3], k[3*4+1], v1.y);
- v1.z = mad(x[1*4+0], k[0*4+2], v1.z); //--
- v1.z = mad(x[1*4+1], k[1*4+2], v1.z);
- v1.z = mad(x[1*4+2], k[2*4+2], v1.z);
- v1.z = mad(x[1*4+3], k[3*4+2], v1.z);
- v1.w = mad(x[1*4+0], k[0*4+3], v1.w); //--
- v1.w = mad(x[1*4+1], k[1*4+3], v1.w);
- v1.w = mad(x[1*4+2], k[2*4+3], v1.w);
- v1.w = mad(x[1*4+3], k[3*4+3], v1.w);
-
- v2.x = mad(x[2*4+0], k[0*4+0], v2.x); //--
- v2.x = mad(x[2*4+1], k[1*4+0], v2.x);
- v2.x = mad(x[2*4+2], k[2*4+0], v2.x);
- v2.x = mad(x[2*4+3], k[3*4+0], v2.x);
- v2.y = mad(x[2*4+0], k[0*4+1], v2.y); //--
- v2.y = mad(x[2*4+1], k[1*4+1], v2.y);
- v2.y = mad(x[2*4+2], k[2*4+1], v2.y);
- v2.y = mad(x[2*4+3], k[3*4+1], v2.y);
- v2.z = mad(x[2*4+0], k[0*4+2], v2.z); //--
- v2.z = mad(x[2*4+1], k[1*4+2], v2.z);
- v2.z = mad(x[2*4+2], k[2*4+2], v2.z);
- v2.z = mad(x[2*4+3], k[3*4+2], v2.z);
- v2.w = mad(x[2*4+0], k[0*4+3], v2.w); //--
- v2.w = mad(x[2*4+1], k[1*4+3], v2.w);
- v2.w = mad(x[2*4+2], k[2*4+3], v2.w);
- v2.w = mad(x[2*4+3], k[3*4+3], v2.w);
-
- v3.x = mad(x[3*4+0], k[0*4+0], v3.x); //--
- v3.x = mad(x[3*4+1], k[1*4+0], v3.x);
- v3.x = mad(x[3*4+2], k[2*4+0], v3.x);
- v3.x = mad(x[3*4+3], k[3*4+0], v3.x);
- v3.y = mad(x[3*4+0], k[0*4+1], v3.y); //--
- v3.y = mad(x[3*4+1], k[1*4+1], v3.y);
- v3.y = mad(x[3*4+2], k[2*4+1], v3.y);
- v3.y = mad(x[3*4+3], k[3*4+1], v3.y);
- v3.z = mad(x[3*4+0], k[0*4+2], v3.z); //--
- v3.z = mad(x[3*4+1], k[1*4+2], v3.z);
- v3.z = mad(x[3*4+2], k[2*4+2], v3.z);
- v3.z = mad(x[3*4+3], k[3*4+2], v3.z);
- v3.w = mad(x[3*4+0], k[0*4+3], v3.w); //--
- v3.w = mad(x[3*4+1], k[1*4+3], v3.w);
- v3.w = mad(x[3*4+2], k[2*4+3], v3.w);
- v3.w = mad(x[3*4+3], k[3*4+3], v3.w);
-
- #endif // ----------------------------------------------------------
- }
-
- GroupMemoryBarrierWithGroupSync();
- }
- }
- }
-
- #if 1 // ----------------------------------------------------------
-
- // 117ms
- O.Set(n*4+0, y, x, k*4+0, v0.x, Odata);
- O.Set(n*4+0, y, x, k*4+1, v0.y, Odata);
- O.Set(n*4+0, y, x, k*4+2, v0.z, Odata);
- O.Set(n*4+0, y, x, k*4+3, v0.w, Odata);
-
- O.Set(n*4+1, y, x, k*4+0, v1.x, Odata);
- O.Set(n*4+1, y, x, k*4+1, v1.y, Odata);
- O.Set(n*4+1, y, x, k*4+2, v1.z, Odata);
- O.Set(n*4+1, y, x, k*4+3, v1.w, Odata);
-
- O.Set(n*4+2, y, x, k*4+0, v2.x, Odata);
- O.Set(n*4+2, y, x, k*4+1, v2.y, Odata);
- O.Set(n*4+2, y, x, k*4+2, v2.z, Odata);
- O.Set(n*4+2, y, x, k*4+3, v2.w, Odata);
-
- O.Set(n*4+3, y, x, k*4+0, v3.x, Odata);
- O.Set(n*4+3, y, x, k*4+1, v3.y, Odata);
- O.Set(n*4+3, y, x, k*4+2, v3.z, Odata);
- O.Set(n*4+3, y, x, k*4+3, v3.w, Odata);
-
- #else // ----------------------------------------------------------
-
- // 118ms
- O.Set(n*4+0, y, x, k*4+0, v0.x, Odata);
- O.Set(n*4+1, y, x, k*4+0, v1.x, Odata);
- O.Set(n*4+2, y, x, k*4+0, v2.x, Odata);
- O.Set(n*4+3, y, x, k*4+0, v3.x, Odata);
-
- O.Set(n*4+0, y, x, k*4+1, v0.y, Odata);
- O.Set(n*4+1, y, x, k*4+1, v1.y, Odata);
- O.Set(n*4+2, y, x, k*4+1, v2.y, Odata);
- O.Set(n*4+3, y, x, k*4+1, v3.y, Odata);
-
- O.Set(n*4+0, y, x, k*4+2, v0.z, Odata);
- O.Set(n*4+1, y, x, k*4+2, v1.z, Odata);
- O.Set(n*4+2, y, x, k*4+2, v2.z, Odata);
- O.Set(n*4+3, y, x, k*4+2, v3.z, Odata);
-
- O.Set(n*4+0, y, x, k*4+3, v0.w, Odata);
- O.Set(n*4+1, y, x, k*4+3, v1.w, Odata);
- O.Set(n*4+2, y, x, k*4+3, v2.w, Odata);
- O.Set(n*4+3, y, x, k*4+3, v3.w, Odata);
-
- #endif // ----------------------------------------------------------
-
-
- #undef X_
- #undef K_
-}
-
-#else // =====================================================================================================
-
-#undef CTILE
-#define CTILE 16
-#define RTILE 4
-groupshared float Conv_XcacheR[RTILE*RTILE][CTILE*CTILE];
-groupshared float Conv_KcacheR[RTILE*RTILE][CTILE*CTILE];
-[numthreads(CTILE, CTILE, 1)]
-void Conv2D_Kernel3x3_Cache_KCmod64_KNyx(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- #define X_ Conv_XcacheR
- #define K_ Conv_KcacheR
-
- uint gx = groupThreadID.x;
- uint gy = groupThreadID.y;
-
- uint k = CTILE * groupID.x + groupThreadID.x;
- uint nyx = CTILE * groupID.y + groupThreadID.y;
-
- uint width = X.width - _Border;
- uint height = X.height - _Border;
-
- uint x = nyx % width;
- uint ny = nyx / width;
- uint y = ny % height;
- uint n = ny / height;
-
- float v[RTILE*RTILE];
- for (uint xxxx = 0; xxxx < RTILE; ++xxxx)
- {
- float b = B.Get(0, 0, k*RTILE+xxxx, 0, WBKdata, WBK.dataLength);
- for (uint yyyy = 0; yyyy < RTILE; ++yyyy)
- v[yyyy*RTILE+xxxx] = b;
- }
-
- for (uint dy = 0; dy < 3; ++dy)
- {
- bool mask = true;
-
- if (y+dy < _Offset) mask = false;
- if (y+dy-_Offset >= X.height) mask = false;
-
- for (uint dx = 0; dx < 3; ++dx)
- {
- if (x+dx < _Offset) mask = false;
- if (x+dx-_Offset >= X.width) mask = false;
-
- for (uint m = 0; m < X.channels/(CTILE*RTILE); ++m)
- {
-
- for (uint yy = 0; yy < RTILE; ++yy)
- for (uint xx = 0; xx < RTILE; ++xx)
- {
- if (mask)
- X_[yy*RTILE+xx][gy*CTILE+gx] = X.Get(n*RTILE+yy, y+dy-_Offset, x+dx-_Offset, (m*CTILE + gx)*RTILE+xx, Xdata);
- else
- X_[yy*RTILE+xx][gy*CTILE+gx] = 0;
- K_[yy*RTILE+xx][gy*CTILE+gx] = K.Get(dy, dx, (m*CTILE + gy)*RTILE+yy, k*RTILE+xx, WBKdata, WBK.dataLength);
- }
-
- GroupMemoryBarrierWithGroupSync();
-
- for (uint ii = 0; ii < CTILE; ++ii)
- {
- float x[RTILE*RTILE];
- float k[RTILE*RTILE];
-
- [unroll]
- for (uint iii = 0; iii < RTILE*RTILE; ++iii)
- {
- x[iii] = X_[iii][gy*CTILE+ii];
- k[iii] = K_[iii][ii*CTILE+gx];
- }
-
- [unroll]
- for (uint r = 0; r < RTILE*RTILE; ++r)
- {
- [unroll]
- for (uint i = 0; i < RTILE; ++i)
- {
- uint xxx = r % RTILE;
- v[r] = mad(x[r], k[i*RTILE+xxx], v[r]);
-
- //v[yyy][xxx] += x[yyy][i] * k[i][xxx];
- }
- }
-
- }
-
- GroupMemoryBarrierWithGroupSync();
- }
- }
- }
-
- for (uint yy = 0; yy < RTILE; ++yy)
- for (uint xx = 0; xx < RTILE; ++xx)
- O.Set(n*RTILE+yy, y, x, k*RTILE+xx, v[yy*RTILE+xx], Odata);
-
- #undef X_
- #undef K_
-}
-#endif
-
-[numthreads(CTILE, CTILE, 1)]
-void Conv2D_Kernel3x3_Cache_KCmod16_KNyx_TEMPLATE(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- #define X_ Conv_XcacheT
- #define K_ Conv_KcacheT
-
- uint gx = groupThreadID.x;
- uint gy = groupThreadID.y;
-
- uint k = CTILE * groupID.x + groupThreadID.x;
- uint nyx = CTILE * groupID.y + groupThreadID.y;
-
- uint width = X.width - _Border;
- uint height = X.height - _Border;
-
- uint x = nyx % width;
- uint ny = nyx / width;
- uint y = ny % height;
- uint n = ny / height;
-
- float v = B.Get(0, 0, k, 0, WBKdata, WBK.dataLength);
- for (uint dy = 0; dy < 3; ++dy)
- {
- bool mask = true;
-
- if (y+dy < _Offset) mask = false;
- if (y+dy-_Offset >= X.height) mask = false;
-
- for (uint dx = 0; dx < 3; ++dx)
- {
- if (x+dx < _Offset) mask = false;
- if (x+dx-_Offset >= X.width) mask = false;
-
- //for (uint m = 0; m < (9*128)/CTILE; ++m)
- for (uint m = 0; m < X.channels/CTILE; ++m)
- {
- if (mask)
- X_[gy][gx] = X.Get(n, y+dy-_Offset, x+dx-_Offset, m*CTILE + gx, Xdata);
- else
- X_[gy][gx] = 0;
- K_[gy][gx] = K.Get(dy, dx, m*CTILE + gy, k, WBKdata, WBK.dataLength);
- GroupMemoryBarrierWithGroupSync();
-
- [unroll]
- for (uint i = 0; i < CTILE; ++i)
- {
- float x = X_[gy][i];
- float k = .25; // K_[i][gx];
- v += x * k;
- }
- }
- }
- }
-
- //Odata[nyx * O.channels + k] = v;
-
- Odata[((
- n * O.height +
- y ) * O.width +
- x ) * O.channels +
- k] = v;
-
- #undef X_
- #undef K_
-}
-// %TODO: only supports up to 51 kernels (51 = 16*16*2/(9 kernel weights + 1 bias)) for now. Add a loop to handle more!
-/*
-groupshared float K1cache[KERNEL_SIZE][KERNEL_SIZE][32];
-groupshared float B1cache[32];
-[numthreads(16,16,2)]
-void Conv2D_Kernel3x3_1Channel(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- uint k = 16*groupID.x + groupThreadID.x;
- uint n = 16*groupID.y + groupThreadID.y;
- uint y = 2*groupID.z + groupThreadID.z + _FilterSize;
-
- uint idx = 16*16*groupThreadID.z + 16*groupThreadID.y + groupThreadID.x;
- if (idx < 9 * K.channels)
- {
- uint kx = idx / K.channels;
- uint kk = idx % K.channels;
- K1cache[kx/3][kx%3][kk] = K.Get(kx/3, kx%3, 0, kk, WBKdata, WBK.dataLength);
- }
- else if (idx < 10 * K.channels)
- {
- uint kk = idx % K.channels;
- B1cache[kk] = B.Get(0, 0, kk, 0, WBKdata, WBK.dataLength);
- }
- GroupMemoryBarrierWithGroupSync();
-
- for (uint x = _FilterSize; x < X.width - _FilterSize; ++x)
- {
- float v = B1cache[k];//B.Get(0, 0, k, 0, WBKdata, WBK.dataLength);
- for (int i = -_FilterSize; i < _FilterSize + 1; ++i)
- {
- for (int j = -_FilterSize; j < _FilterSize + 1; ++j)
- {
- v += X.Get(n, y+j, x+i, 0, Xdata, X.dataLength) * K1cache[_FilterSize+j][_FilterSize+i][k];
- }
- }
- O.Set(n, y-_FilterSize, x-_FilterSize, k, v, Odata, O.dataLength);
- }
-}
-*/
-
-groupshared float K1cache[32][9];
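-// Single-input-channel 3x3 convolution: the 3x3 weights of 32 kernels are staged in
-// K1cache, then each thread sweeps the output for one (batch n, kernel k) pair with
-// all nine taps unrolled.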
-[numthreads(32,16,1)]
-void Conv2D_Kernel3x3_1Channel(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- uint tk = groupThreadID.x;
- uint k = 32*groupID.x + tk;
- uint n = 16*groupID.y + groupThreadID.y;
-
- //for (uint q = 0; q < 9; ++q)
- {
- uint q = n % 9;
- K1cache[tk][q] = K.Get(q/3, q%3, 0, k, WBKdata, WBK.dataLength);
- }
- GroupMemoryBarrierWithGroupSync();
-
- for (uint y = 0; y < X.height - _FilterSize*2; ++y)
- {
- for (uint x = 0; x < X.width - _FilterSize*2; ++x)
- {
- float v = B.Get(0, 0, k, 0, WBKdata, WBK.dataLength);
- //for (uint q = 0; q < 9; ++q)
- // v += X.Get(n, y+q/3, x+q%3, 0, Xdata, X.dataLength) * K1cache[tk][q];
- v += X.Get(n, y+0, x+0, 0, Xdata, X.dataLength) * K1cache[tk][0];
- v += X.Get(n, y+0, x+1, 0, Xdata, X.dataLength) * K1cache[tk][1];
- v += X.Get(n, y+0, x+2, 0, Xdata, X.dataLength) * K1cache[tk][2];
-
- v += X.Get(n, y+1, x+0, 0, Xdata, X.dataLength) * K1cache[tk][3];
- v += X.Get(n, y+1, x+1, 0, Xdata, X.dataLength) * K1cache[tk][4];
- v += X.Get(n, y+1, x+2, 0, Xdata, X.dataLength) * K1cache[tk][5];
-
- v += X.Get(n, y+2, x+0, 0, Xdata, X.dataLength) * K1cache[tk][6];
- v += X.Get(n, y+2, x+1, 0, Xdata, X.dataLength) * K1cache[tk][7];
- v += X.Get(n, y+2, x+2, 0, Xdata, X.dataLength) * K1cache[tk][8];
-
- O.Set(n, y, x, k, v, Odata, O.dataLength);
- }
- }
-}
-
-float fillValue;
-
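-// Fills every channel of O at (b, h, w) with fillValue; one single-thread group is
-// dispatched per output position.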
-[numthreads(1,1,1)]
-void Fill(uint3 groupID : SV_GroupID)
-{
- uint b = groupID.x;
- uint h = groupID.y;
- uint w = groupID.z;
- for (uint ch = 0; ch < O.channels; ++ch)
- O.Set(b, h, w, ch, fillValue, Odata, O.dataLength);
-}
-#endif
-
-
-/*
-cbuffer consts {
- uint n;
- uint dispatchDim_x;
-};
-#define groupDim_x 512
-groupshared float Accumulate_sharedMem[groupDim_x * channels];
-[numthreads(groupDim_x, 1, 1)]
-void Accumulate(uint tid: SV_GroupIndex, uint3 groupIdx: groupID)
-{
- #define sharedMem Accumulate_sharedMem
- unsigned int i = groupIdx.x * (groupDim_x * 2) + tid;
- unsigned int dispatchSize = (groupDim_x * 2) * dispatchDim_x;
- sharedMem[tid] = 0;
- do {
- sharedMem[tid] += g_idata[i] + g_idata[i+groupDim_x];
- i += dispatchSize;
- } while (i < n);
- GroupMemoryBarrierWithGroupSync();
-
- if (groupDim_x >= 256)
- {
- if (tid < 128) { sharedMem[tid] += sharedMem[tid + 128 * channels]; }
- GroupMemoryBarrierWithGroupSync();
- }
-
- if (groupDim_x >= 128)
- {
- if (tid < 64) { sharedMem[tid] += sharedMem[tid + 64]; }
- GroupMemoryBarrierWithGroupSync();
- }
-
- if (tid < 32)
- {
- if (groupDim_x >= 64) sharedMem[tid] += sharedMem[tid + 32* channels];
- if (groupDim_x >= 32) sharedMem[tid] += sharedMem[tid + 16* channels];
- if (groupDim_x >= 16) sharedMem[tid] += sharedMem[tid + 8* channels];
- if (groupDim_x >= 8) sharedMem[tid] += sharedMem[tid + 4* channels];
- if (groupDim_x >= 4) sharedMem[tid] += sharedMem[tid + 2* channels];
- if (groupDim_x >= 2) sharedMem[tid] += sharedMem[tid + 1* channels];
- }
-
- if (tid == 0) g_odata[groupIdx.x] = sharedMem[0];
-
- #undef sharedMem
-}
-*/
- /*
-// Could do to reduce across NxN patch fitting within a group, HW <= HW / N
-// Repeat, until HW == 1
-
-// Alternatively reduce across Y axis, then X
-
-#undef MAX_CHANNELS
-#define MAX_CHANNELS 2048
-groupshared float GlobalAvgPool2D_AccumulatorPerChannel[MAX_CHANNELS];
-[numthreads(4,8,8)]
-void GlobalAvgPool2D(uint3 dispatchThreadID : SV_DispatchThreadID, uint threadID : SV_ThreadID)
-{
- // NOTE: dispatched over X (not O)
- DISPATCH_ARGS(X.channels, X.width, X.height);
- TENSOR_ARGS2(X, O);
-
- uint c = dispatchThreadID.x;
- uint x = dispatchThreadID.y;
- uint y = dispatchThreadID.z;
-
- if (c >= X.channels || c >= MAX_CHANNELS) return;
- if (x >= X.width) return;
- if (y >= X.height) return;
-
- // Accumulate
- for (uint n = 0; n < X.batch; ++n)
- {
- // Clear accumulator
- // @TODO: ThreadID
- //uint threadID = groupThreadID.x * 4 + groupThreadID.y * 8 + groupThreadID.z * 8;
- if (threadID < MAX_CHANNELS)
- GlobalAvgPool2D_AccumulatorPerChannel[threadID] = 0;
- GroupMemoryBarrierWithGroupSync();
-
- GlobalAvgPool2D_AccumulatorPerChannel[c] += X.Get(n, y, x, c);
- // @TODO: atomicAdd?
-
- GroupMemoryBarrierWithGroupSync();
- if (threadID < MAX_CHANNELS)
- {
- float v = GlobalAvgPool2D_AccumulatorPerChannel[threadID];
- O.Set(n, 0, 0, c, v / (X.width * X.height));
- }
- }
-}*/
-
-
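-// Register-blocked variant: each thread computes a 2x2 patch of output pixels for one
-// kernel k, accumulating in a single float4.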
-[numthreads(64,2,2)]
-void Conv2D_Reg2x2(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(K.kernelCount, O.width, O.height);
- TENSOR_SHARED2_ARGS4(X, K, B, WBK, O);
-
- uint k = dispatchThreadID.x;
- uint x = dispatchThreadID.y;
- uint y = dispatchThreadID.z;
-
- if (k >= K.channels) return;
- if (x*2 >= O.width) return;
- if (y*2 >= O.height) return;
-
- uint2 leftCorner = _Pad.xy;
- uint2 rightCorner = uint2(X.width, X.height) + _Pad.xy;
- for (uint n = 0; n < O.batch; ++n)
- {
- float4 acc = B.Get(k);
- for (uint dy = 0; dy < K.GetKernelHeight(); ++dy)
- {
- for (uint dx = 0; dx < K.GetKernelWidth(); ++dx)
- {
- uint2 pos0 = uint2(x*2+0, y*2+0) * _Stride.xy + uint2(dx, dy);
- uint2 pos1 = uint2(x*2+1, y*2+0) * _Stride.xy + uint2(dx, dy);
- uint2 pos2 = uint2(x*2+0, y*2+1) * _Stride.xy + uint2(dx, dy);
- uint2 pos3 = uint2(x*2+1, y*2+1) * _Stride.xy + uint2(dx, dy);
-
- for (uint c = 0; c < X.channels; ++c)
- {
- if (all(pos0 >= leftCorner) && all(pos0 < rightCorner))
- acc.x = fastfma(X.Get(n, pos0 - leftCorner, c), K.Get(dy, dx, c, k), acc.x);
- if (all(pos1 >= leftCorner) && all(pos1 < rightCorner))
- acc.y = fastfma(X.Get(n, pos1 - leftCorner, c), K.Get(dy, dx, c, k), acc.y);
- if (all(pos2 >= leftCorner) && all(pos2 < rightCorner))
- acc.z = fastfma(X.Get(n, pos2 - leftCorner, c), K.Get(dy, dx, c, k), acc.z);
- if (all(pos3 >= leftCorner) && all(pos3 < rightCorner))
- acc.w = fastfma(X.Get(n, pos3 - leftCorner, c), K.Get(dy, dx, c, k), acc.w);
- }
- }
- }
-
- O.Set(n, y*2+0, x*2+0, k, acc.x);
- O.Set(n, y*2+0, x*2+1, k, acc.y);
- O.Set(n, y*2+1, x*2+0, k, acc.z);
- O.Set(n, y*2+1, x*2+1, k, acc.w);
- }
-}
-
-#define SIZE 2
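-// Generalization of Conv2D_Reg2x2 to a SIZE x SIZE output patch per thread, using
-// unrolled loops over the patch instead of hand-written per-component code.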
-[numthreads(64, 2, 2)]
-void Conv2D_Reg_Loop(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(K.kernelCount, O.width, O.height);
- TENSOR_SHARED2_ARGS4(X, K, B, WBK, O);
-
- uint k = dispatchThreadID.x;
- uint x = dispatchThreadID.y;
- uint y = dispatchThreadID.z;
-
- if (k >= K.channels) return;
- if (x*SIZE >= O.width) return;
- if (y*SIZE >= O.height) return;
-
- uint2 leftCorner = _Pad.xy;
- uint2 rightCorner = uint2(X.width, X.height) + _Pad.xy;
- for (uint n = 0; n < O.batch; ++n)
- {
- float acc[SIZE*SIZE];
- [unroll]
- for (uint q = 0; q < SIZE*SIZE; ++q)
- acc[q] = B.Get(k);
- for (uint dy = 0; dy < K.GetKernelHeight(); ++dy)
- {
- for (uint dx = 0; dx < K.GetKernelWidth(); ++dx)
- {
- uint2 pos[SIZE*SIZE];
- [unroll]
- for (uint q = 0; q < SIZE*SIZE; ++q)
- pos[q] = uint2(x*SIZE+(q%SIZE), y*SIZE+(q/SIZE)) * _Stride.xy + uint2(dx, dy);
-
- // @TODO: investigate
- // WARNING: had to move both the x and y checks into the inner loop (as opposed to checking y in the parent loop) due to a potential bug in the Metal compiler
-
- for (uint c = 0; c < X.channels; ++c)
- [unroll]
- for (q = 0; q < SIZE*SIZE; ++q)
- if (all(pos[q] >= leftCorner) && all(pos[q] < rightCorner))
- acc[q] = fastfma(X.Get(n, pos[q] - leftCorner, c), K.Get(dy, dx, c, k), acc[q]);
- }
- }
-
- [unroll]
- for (q = 0; q < SIZE*SIZE; ++q)
- O.Set(n, y*SIZE+(q/SIZE), x*SIZE+(q%SIZE), k, acc[q]);
- }
-}
-
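-// Reference implementation: no caching or register blocking, one thread per output
-// element, padding handled via X.SafeGet.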
-NUMTHREADS((16,4,4), (8,4,4), (16,2,2))
-//[numthreads(64, 1, 1)]
-void Conv2D_safe(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(K.kernelCount, O.width, O.height);
- TENSOR_SHARED2_ARGS4(X, K, B, WBK, O);
-
- uint k = dispatchThreadID.x;
- uint x = dispatchThreadID.y;
- uint y = dispatchThreadID.z;
-
- if (k >= K.channels) return;
- if (x >= O.width) return;
- if (y >= O.height) return;
-
- for (uint n = 0; n < O.batch; ++n)
- {
- float acc = B.Get(k);
- for (uint dy = 0; dy < K.GetKernelHeight(); ++dy)
- {
- for (uint dx = 0; dx < K.GetKernelWidth(); ++dx)
- {
- uint2 pos = uint2(x, y) * _Stride.xy + uint2(dx, dy);
-
- for (uint c = 0; c < X.channels; ++c)
- acc = fastfma(X.SafeGet(n, pos, c, _Pad.xy), K.Get(dy, dx, c, k), acc);
- }
- }
-
- O.Set(n, y, x, k, acc);
- }
-}
-
-
-#undef L1CACHESIZE
-#define L1CACHESIZE 32
-groupshared float Conv2D_L1Cached32_X[L1CACHESIZE];
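-// Stages L1CACHESIZE consecutive input channels into groupshared memory so the whole
-// thread group can reuse them while each thread works on a different kernel k.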
-[numthreads(L1CACHESIZE, 1, 1)]
-void Conv2D_L1Cached32(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- DISPATCH_ARGS(K.kernelCount, O.width, O.height);
- TENSOR_SHARED2_ARGS4(X, K, B, WBK, O);
-
- #define X_ Conv2D_L1Cached32_X
-
- uint k = L1CACHESIZE * groupID.x + groupThreadID.x;
- uint x = groupID.y;
- uint y = groupID.z;
-
- if (x >= O.width) return;
- if (y >= O.height) return;
-
- for (uint n = 0; n < O.batch; ++n)
- {
- float acc = B.SafeGet(k);
- for (uint dy = 0; dy < K.GetKernelHeight(); ++dy)
- {
- for (uint dx = 0; dx < K.GetKernelWidth(); ++dx)
- {
- uint2 pos = uint2(x,y) * _Stride.xy + uint2(dx,dy);
-
- for (uint c = 0; c < X.channels; c += L1CACHESIZE)
- {
- // Cache X
- X_[groupThreadID.x] = X.SafeGet(n, pos, c + groupThreadID.x, _Pad.xy);
- GroupMemoryBarrierWithGroupSync();
-
- // X * K
- if (k < K.channels)
- {
- for (uint dc = 0; dc < L1CACHESIZE; ++dc)
- acc = fastfma(X_[dc], K.Get(dy, dx, c + dc, k), acc);
- }
- GroupMemoryBarrierWithGroupSync();
- }
- }
- }
-
- O.Set(n, y, x, k, acc);
- }
-
- #undef X_
-}
-
-#undef L1CACHESIZE
-#define L1CACHESIZE 64
-groupshared float Conv2D_L1Cached64_X[L1CACHESIZE];
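-// Same scheme as Conv2D_L1Cached32 above, with a 64-wide channel cache.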
-[numthreads(L1CACHESIZE, 1, 1)]
-void Conv2D_L1Cached64(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- DISPATCH_ARGS(K.kernelCount, O.width, O.height);
- TENSOR_SHARED2_ARGS4(X, K, B, WBK, O);
-
- #define X_ Conv2D_L1Cached64_X
-
- uint k = L1CACHESIZE * groupID.x + groupThreadID.x;
- uint x = groupID.y;
- uint y = groupID.z;
-
- if (x >= O.width) return;
- if (y >= O.height) return;
-
- for (uint n = 0; n < O.batch; ++n)
- {
- float acc = B.SafeGet(k);
- for (uint dy = 0; dy < K.GetKernelHeight(); ++dy)
- {
- for (uint dx = 0; dx < K.GetKernelWidth(); ++dx)
- {
- uint2 pos = uint2(x,y) * _Stride.xy + uint2(dx,dy);
- for (uint c = 0; c < X.channels; c += L1CACHESIZE)
- {
- // Cache X
- X_[groupThreadID.x] = X.SafeGet(n, pos, c + groupThreadID.x, _Pad.xy);
- GroupMemoryBarrierWithGroupSync();
-
- // X * K
- if (k < K.channels)
- {
- for (uint dc = 0; dc < L1CACHESIZE; ++dc)
- acc = fastfma(X_[dc], K.Get(dy, dx, c + dc, k), acc);
- }
- GroupMemoryBarrierWithGroupSync();
- }
- }
- }
-
- O.Set(n, y, x, k, acc);
- }
-
- #undef X_
-}
-
-
-#undef SIZE
-#define SIZE 2
-[numthreads(64, 2, 2)]
-void Conv2D_Reg_Loop_safe(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(K.kernelCount, O.width, O.height);
- TENSOR_SHARED2_ARGS4(X, K, B, WBK, O);
-
- uint k = dispatchThreadID.x;
- uint x = dispatchThreadID.y;
- uint y = dispatchThreadID.z;
-
- if (k >= K.channels) return;
- if (x*SIZE >= O.width) return;
- if (y*SIZE >= O.height) return;
-
- uint2 leftCorner = _Pad.xy;
- uint2 rightCorner = uint2(X.width, X.height) + _Pad.xy;
- for (uint n = 0; n < O.batch; ++n)
- {
- float acc[SIZE*SIZE];
- [unroll]
- for (uint q = 0; q < SIZE*SIZE; ++q)
- acc[q] = B.Get(k);
- for (uint dy = 0; dy < K.GetKernelHeight(); ++dy)
- {
- for (uint dx = 0; dx < K.GetKernelWidth(); ++dx)
- {
- uint2 pos[SIZE*SIZE];
- [unroll]
- for (uint q = 0; q < SIZE*SIZE; ++q)
- pos[q] = uint2(x*SIZE+(q%SIZE), y*SIZE+(q/SIZE)) * _Stride.xy + uint2(dx, dy);
-
- // @TODO: investigate
- // WARNING: had to move both y check into the loop (as opposed to checking y in parent loop) - due to potential bug in Metal compiler
-
- for (uint c = 0; c < X.channels; ++c)
- [unroll]
- for (q = 0; q < SIZE*SIZE; ++q)
- acc[q] = fastfma(X.SafeGet(n, pos[q], c, _Pad.xy), K.Get(dy, dx, c, k), acc[q]);
- }
- }
-
- [unroll]
- for (q = 0; q < SIZE*SIZE; ++q)
- O.Set(n, y*SIZE+(q/SIZE), x*SIZE+(q%SIZE), k, acc[q]);
- }
-}
-
-
-#undef L1CACHESIZE
-#define L1CACHESIZE 64
-#undef SIZE
-#define SIZE 2
-groupshared float Conv2D_L1Cached64_Reg_Loop2x2_X[SIZE*SIZE][L1CACHESIZE];
-[numthreads(L1CACHESIZE, 1, 1)]
-void Conv2D_L1Cached64_Reg_Loop2x2(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- DISPATCH_ARGS(K.kernelCount, O.width, O.height);
- TENSOR_SHARED2_ARGS4(X, K, B, WBK, O);
-
- #define X_ Conv2D_L1Cached64_Reg_Loop2x2_X
-
- uint k = L1CACHESIZE * groupID.x + groupThreadID.x;
- uint x = groupID.y;
- uint y = groupID.z;
-
- // need all threads to load channels, thus will do late check against kernel count
- if (x*SIZE >= O.width) return;
- if (y*SIZE >= O.height) return;
-
- for (uint n = 0; n < O.batch; ++n)
- {
- float acc[SIZE*SIZE];
- [unroll]
- for (uint q = 0; q < SIZE*SIZE; ++q)
- acc[q] = B.SafeGet(k);
-
- for (uint dy = 0; dy < K.GetKernelHeight(); ++dy)
- {
- for (uint dx = 0; dx < K.GetKernelWidth(); ++dx)
- {
- uint2 pos[SIZE*SIZE];
- [unroll]
- for (uint q = 0; q < SIZE*SIZE; ++q)
- pos[q] = uint2(x*SIZE+(q%SIZE), y*SIZE+(q/SIZE)) * _Stride.xy + uint2(dx, dy);
-
- for (uint c = 0; c < X.channels; c += L1CACHESIZE)
- {
- // Cache X
- uint dc = groupThreadID.x;
- [unroll]
- for (q = 0; q < SIZE*SIZE; ++q)
- X_[q][dc] = X.SafeGet(n, pos[q], c + dc, _Pad.xy);
- GroupMemoryBarrierWithGroupSync();
-
- // X * K
- if (k < K.channels) // need all threads to load channels, thus late check against kernel count
- {
- uint kIndex = K.Index(dy, dx, c, k);
- for (dc = 0; dc < L1CACHESIZE; ++dc)
- {
- for (q = 0; q < SIZE*SIZE; ++q)
- acc[q] = fastfma(X_[q][dc], K.data[kIndex], acc[q]); //K.Get(dy, dx, c + dc, k);
- kIndex += K.channels;
- }
- }
- GroupMemoryBarrierWithGroupSync();
- }
- }
- }
-
- if (k < K.channels) // need all threads to load channels, thus late check against kernel count
- [unroll]
- for (q = 0; q < SIZE*SIZE; ++q)
- O.Set(n, y*SIZE+(q/SIZE), x*SIZE+(q%SIZE), k, acc[q]);
- }
-
- #undef X_
-}
-
-
-#undef L1CACHESIZE
-#define L1CACHESIZE 64
-#undef SIZE
-#define SIZE 4
-groupshared float Conv2D_L1Cached64_Reg_Loop_X[SIZE*SIZE][L1CACHESIZE];
-[numthreads(L1CACHESIZE, 1, 1)]
-void Conv2D_L1Cached64_Reg_Loop(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- DISPATCH_ARGS(K.kernelCount, O.width, O.height);
- TENSOR_SHARED2_ARGS4(X, K, B, WBK, O);
-
- #define X_ Conv2D_L1Cached64_Reg_Loop_X
-
- uint k = L1CACHESIZE * groupID.x + groupThreadID.x;
- uint x = groupID.y;
- uint y = groupID.z;
-
- // need all threads to load channels, thus will do late check against kernel count
- if (x*SIZE >= O.width) return;
- if (y*SIZE >= O.height) return;
-
- for (uint n = 0; n < O.batch; ++n)
- {
- float acc[SIZE*SIZE];
- [unroll]
- for (uint q = 0; q < SIZE*SIZE; ++q)
- acc[q] = B.SafeGet(k);
-
- for (uint dy = 0; dy < K.GetKernelHeight(); ++dy)
- {
- for (uint dx = 0; dx < K.GetKernelWidth(); ++dx)
- {
- uint2 pos[SIZE*SIZE];
- [unroll]
- for (uint q = 0; q < SIZE*SIZE; ++q)
- pos[q] = uint2(x*SIZE+(q%SIZE), y*SIZE+(q/SIZE)) * _Stride.xy + uint2(dx, dy);
-
- for (uint c = 0; c < X.channels; c += L1CACHESIZE)
- {
- // Cache X
- uint dc = groupThreadID.x;
- [unroll]
- for (q = 0; q < SIZE*SIZE; ++q)
- X_[q][dc] = X.SafeGet(n, pos[q], c + dc, _Pad.xy);
- GroupMemoryBarrierWithGroupSync();
-
- // X * K
- if (k < K.channels) // need all threads to load channels, thus late check against kernel count
- {
- uint kIndex = K.Index(dy, dx, c, k);
- for (dc = 0; dc < L1CACHESIZE; ++dc)
- {
- for (q = 0; q < SIZE*SIZE; ++q)
- acc[q] = fastfma(X_[q][dc], K.data[kIndex], acc[q]);//K.Get(dy, dx, c + dc, k);
- kIndex += K.channels;
- }
- }
- GroupMemoryBarrierWithGroupSync();
- }
- }
- }
-
- if (k < K.channels) // need all threads to load channels, thus late check against kernel count
- [unroll]
- for (q = 0; q < SIZE*SIZE; ++q)
- O.Set(n, y*SIZE+(q/SIZE), x*SIZE+(q%SIZE), k, acc[q]);
- }
-
- #undef X_
-}
-
-
-#undef L1CACHESIZE
-#define L1CACHESIZE 64
-#define SIZE_W 4
-#define SIZE_H 2
-groupshared float Conv2D_L1Cached64_Reg_Loop_safe__X[SIZE_H*SIZE_W][L1CACHESIZE];
-[numthreads(L1CACHESIZE, 1, 1)]
-void Conv2D_L1Cached64_Reg_Loop_safe_(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- DISPATCH_ARGS(K.kernelCount, O.width, O.height);
- TENSOR_SHARED2_ARGS4(X, K, B, WBK, O);
-
- #define X_ Conv2D_L1Cached64_Reg_Loop_safe__X
-
- uint k = L1CACHESIZE * groupID.x + groupThreadID.x;
- uint x = groupID.y;
- uint y = groupID.z;
-
- // need all threads to load channels, thus will do late check against kernel count
- if (x*SIZE_W >= O.width) return;
- if (y*SIZE_H >= O.height) return;
-
- for (uint n = 0; n < O.batch; ++n)
- {
- float acc[SIZE_H*SIZE_W];
- [unroll]
- for (uint q = 0; q < SIZE_H*SIZE_W; ++q)
- acc[q] = B.SafeGet(k);
-
- for (uint dy = 0; dy < K.GetKernelHeight(); ++dy)
- {
- for (uint dx = 0; dx < K.GetKernelWidth(); ++dx)
- {
- uint2 pos[SIZE_H*SIZE_W];
- [unroll]
- for (uint q = 0; q < SIZE_H*SIZE_W; ++q)
- pos[q] = uint2(x*SIZE_W+(q%SIZE_W), y*SIZE_H+(q/SIZE_W)) * _Stride.xy + uint2(dx, dy);
-
- for (uint c = 0; c < X.channels; c += L1CACHESIZE)
- {
- // Cache X
- uint dc = groupThreadID.x;
- [unroll]
- for (q = 0; q < SIZE_H*SIZE_W; ++q)
- X_[q][dc] = X.SafeGet(n, pos[q], c + dc, _Pad.xy);
- GroupMemoryBarrierWithGroupSync();
-
- // X * K
- if (k < K.channels) // need all threads to load channels, thus late check against kernel count
- {
- uint kIndex = K.Index(dy, dx, c, k);
- for (dc = 0; dc < L1CACHESIZE; ++dc)
- {
- [unroll]
- for (q = 0; q < SIZE_H*SIZE_W; ++q)
- acc[q] = fastfma(X_[q][dc], K.data[kIndex], acc[q]);
- kIndex += K.channels;
- }
- }
- GroupMemoryBarrierWithGroupSync();
- }
- }
- }
-
- if (k < K.channels) // need all threads to load channels, thus late check against kernel count
- [unroll]
- for (q = 0; q < SIZE_H*SIZE_W; ++q)
- {
- uint ox = x*SIZE_W+(q%SIZE_W);
- uint oy = y*SIZE_H+(q/SIZE_W);
- if (ox < O.width && oy < O.height)
- O.Set(n, oy, ox, k, acc[q]);
- }
- }
-
- #undef X_
-}
-#undef SIZE_H
-#undef SIZE_W
-
-
-/*
-#undef L1CACHESIZE
-#define L1CACHESIZE 32
-#define SIZE_W 4
-#define SIZE_H 2
-groupshared float Conv2D_L1Cached64_Reg_Loop_safe__X[SIZE_H*SIZE_W][L1CACHESIZE];
-[numthreads(L1CACHESIZE, SIZE_W, SIZE_H)]
-void Conv2D_L1Cached64_Reg_Loop_safe_(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- DISPATCH_ARGS(K.kernelCount, O.width, O.height);
- TENSOR_SHARED2_ARGS4(X, K, B, WBK, O);
-
- #define X_ Conv2D_L1Cached64_Reg_Loop_safe__X
-
- uint k = L1CACHESIZE * groupID.x + groupThreadID.x;
- uint x = SIZE_W * groupID.y + groupThreadID.y;
- uint y = SIZE_H * groupID.z + groupThreadID.z;
-
- // need all threads to load channels, thus will do late check against kernel count
- //if (x*SIZE_W >= O.width) return;
- //if (y*SIZE_H >= O.height) return;
-
- for (uint n = 0; n < O.batch; ++n)
- {
- float acc[SIZE_H*SIZE_W];
- [unroll]
- for (uint q = 0; q < SIZE_H*SIZE_W; ++q)
- acc[q] = B.SafeGet(k);
-
- for (uint dy = 0; dy < K.GetKernelHeight(); ++dy)
- {
- for (uint dx = 0; dx < K.GetKernelWidth(); ++dx)
- {
- //uint2 pos[SIZE_H*SIZE_W];
- //[unroll]
- //for (uint q = 0; q < SIZE_H*SIZE_W; ++q)
- // pos[q] = uint2(x*SIZE_W+(q%SIZE_W), y*SIZE_H+(q/SIZE_W)) * _Stride.xy + uint2(dx, dy);
-
- for (uint c = 0; c < X.channels; c += L1CACHESIZE)
- {
- // Cache X
- uint dc = groupThreadID.x;
- uint gx = groupThreadID.y;
- uint gy = groupThreadID.z;
- //[unroll]
- //for (q = 0; q < SIZE_H*SIZE_W; ++q)
- //{
- uint2 pos = uint2(x*SIZE_W+gx, y*SIZE_H+gy) * _Stride.xy + uint2(dx, dy);
- X_[SIZE_W*gy+gx][dc] = X.SafeGet(n, pos, c + dc, _Pad.xy);
- //}
- GroupMemoryBarrierWithGroupSync();
-
- // X * K
- if (k < K.channels &&
- x*SIZE_W < O.width &&
- y*SIZE_H < O.height) // need all threads to load channels, thus late check against kernel count
- {
- uint kIndex = K.Index(dy, dx, c, k);
- for (dc = 0; dc < L1CACHESIZE; ++dc)
- {
- [unroll]
- for (q = 0; q < SIZE_H*SIZE_W; ++q)
- acc[q] += X_[q][dc] * K.data[kIndex];//K.Get(dy, dx, c + dc, k);
- kIndex += K.channels;
- }
- }
- GroupMemoryBarrierWithGroupSync();
- }
- }
- }
-
- if (k < K.channels) // need all threads to load channels, thus late check against kernel count
- [unroll]
- for (q = 0; q < SIZE_H*SIZE_W; ++q)
- {
- uint ox = x*SIZE_W+(q%SIZE_W);
- uint oy = y*SIZE_H+(q/SIZE_W);
- if (ox < O.width && oy < O.height)
- O.Set(n, oy, ox, k, acc[q]);
- }
- }
-
- #undef X_
-}
-#undef SIZE_H
-#undef SIZE_W
-*/
-
-/*
-#undef L1CACHESIZE
-#define L1CACHESIZE 64
-groupshared float Conv2D_RegCached_X[4][L1CACHESIZE];
-[numthreads(L1CACHESIZE, 1, 1)]
-void Conv2D_RegCached(uint3 dispatchThreadID : SV_DispatchThreadID, uint3 groupThreadID : SV_GroupThreadID)
-{
- DISPATCH_ARGS(K.kernelCount, O.width, O.height);
- TENSOR_SHARED2_ARGS4(X, K, B, WBK, O);
-
- #define X_ Conv2D_RegCached_X
-
- uint k = dispatchThreadID.x;
- uint x = dispatchThreadID.y;
- uint y = dispatchThreadID.z;
-
- if (x*2 >= O.width) return;
- if (y*2 >= O.height) return;
-
- for (uint n = 0; n < O.batch; ++n)
- {
- float4 acc = B.SafeGet(k);
- for (uint dy = 0; dy < K.GetKernelHeight(); ++dy)
- {
- for (uint dx = 0; dx < K.GetKernelWidth(); ++dx)
- {
- uint2 pos0 = uint2(x*2+0,y*2+0) * _Stride + uint2(dx,dy);
- uint2 pos1 = uint2(x*2+1,y*2+0) * _Stride + uint2(dx,dy);
- uint2 pos2 = uint2(x*2+0,y*2+1) * _Stride + uint2(dx,dy);
- uint2 pos3 = uint2(x*2+1,y*2+1) * _Stride + uint2(dx,dy);
-
- // Cache X
- uint c_ = groupThreadID.x;
- if (c_ < X.channels)
- {
- X_[0][c_] = X.SafeGet(n, pos0, c_, _Pad.xy);
- X_[1][c_] = X.SafeGet(n, pos1, c_, _Pad.xy);
- X_[2][c_] = X.SafeGet(n, pos2, c_, _Pad.xy);
- X_[3][c_] = X.SafeGet(n, pos3, c_, _Pad.xy);
- }
- GroupMemoryBarrierWithGroupSync();
-
- // X * K
- if (k < K.channels)
- for (uint c = 0; c < X.channels; ++c)
- {
- acc.x += X_[0][c] * K.Get(dy, dx, c, k);
- acc.y += X_[1][c] * K.Get(dy, dx, c, k);
- acc.z += X_[2][c] * K.Get(dy, dx, c, k);
- acc.w += X_[3][c] * K.Get(dy, dx, c, k);
- }
- GroupMemoryBarrierWithGroupSync();
- }
- }
-
- O.Set(n, y*2+0, x*2+0, k, acc.x);
- O.Set(n, y*2+0, x*2+1, k, acc.y);
- O.Set(n, y*2+1, x*2+0, k, acc.z);
- O.Set(n, y*2+1, x*2+1, k, acc.w);
- }
-}
-*/
-
-/*
-[numthreads(16,4,4)]
-void Conv2DTrans(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(K.kernelCount, O.width, O.height);
- TENSOR_SHARED2_ARGS4(X, K, B, WBK, O);
-
- uint k = dispatchThreadID.x;
- uint x = dispatchThreadID.y;
- uint y = dispatchThreadID.z;
-
- if (k >= K.channels) return;
- if (x >= O.width) return;
- if (y >= O.height) return;
-
- uint2 strideMask = _Stride.xy - 1;
-
- for (uint n = 0; n < O.batch; ++n)
- {
- float acc = B.Get(k);
- for (uint dy = 0; dy < K.GetKernelHeight(); dy += _Stride.y)
- {
- for (uint dx = 0; dx < K.GetKernelWidth(); dx += _Stride.x)
- {
- uint dxShifted = dx + (x&strideMask.x);
- uint dyShifted = dy + (y&strideMask.y);
-
- uint xx = x + dxShifted;
- uint yy = y + dyShifted;
-
- uint oy = (yy - _Pad.y) / _Stride.y;
- uint ox = (xx - _Pad.x) / _Stride.x;
-
- bool mask = xx >= _Pad.x && yy >= _Pad.y && ox < X.width && oy < X.height;
- if (!mask) continue;
-
- // [unroll] - crashes metal compiler
- for (uint c = 0; c < X.channels; ++c)
- {
- acc += X.Get(n, oy, ox, c) * K.Get( K.GetKernelHeight() - 1 - dyShifted,
- K.GetKernelWidth() - 1 - dxShifted, c, k);
- }
- }
- }
-
- O.Set(n, y, x, k, acc);
- }
-}
-*/
-
-
-
-#undef SIZE
-#define SIZE 4
-[numthreads(16, 4, 4)]
-void Conv2DTrans_Reg_Loop_safe(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(K.kernelCount, O.width, O.height);
- TENSOR_SHARED2_ARGS4(X, K, B, WBK, O);
-
- uint k = dispatchThreadID.x;
- uint x = dispatchThreadID.y;
- uint y = dispatchThreadID.z;
-
- if (k >= K.channels) return;
- if (x*SIZE >= O.width) return;
- if (y*SIZE >= O.height) return;
-
- uint2 strideMask = _Stride.xy - 1;
-
- uint2 pad = _Pad.xy / _Stride.xy;
- for (uint n = 0; n < O.batch; ++n)
- {
- float acc[SIZE*SIZE];
- [unroll]
- for (uint q = 0; q < SIZE*SIZE; ++q)
- acc[q] = B.Get(k);
-
- for (uint dy = 0; dy < K.GetKernelHeight(); dy += _Stride.y)
- {
- for (uint dx = 0; dx < K.GetKernelWidth(); dx += _Stride.x)
- {
- uint2 kernelPos[SIZE*SIZE];
- uint2 pos[SIZE*SIZE];
-
- [unroll]
- for (uint q = 0; q < SIZE*SIZE; ++q)
- {
- uint2 xy = uint2(x*SIZE+(q%SIZE), y*SIZE+(q/SIZE));
- kernelPos[q] = uint2(dx, dy) + (xy & strideMask);
- pos[q] = (xy + kernelPos[q]) / _Stride.xy;
-
- // transpose
- kernelPos[q] = uint2(K.GetKernelWidth(), K.GetKernelHeight()) - 1 - kernelPos[q];
- }
-
- for (uint c = 0; c < X.channels; ++c)
- [unroll]
- for (q = 0; q < SIZE*SIZE; ++q)
- acc[q] = fastfma(X.SafeGet(n, pos[q], c, pad.xy), K.Get(kernelPos[q].y, kernelPos[q].x, c, k), acc[q]);
- //acc[q] += X.SafeGet(n, pos[q], c, pad.xy) * K.Get(kernelPos[q].y, kernelPos[q].x, c, k);
- }
- }
-
- [unroll]
- for (q = 0; q < SIZE*SIZE; ++q)
- O.Set(n, y*SIZE+(q/SIZE), x*SIZE+(q%SIZE), k, acc[q]);
- }
-}
-
-
-
-#undef L1CACHESIZE
-#define L1CACHESIZE 64
-#define SIZE_W 4
-#define SIZE_H 2
-groupshared float Conv2DTrans_L1Cached64_Reg_Loop_safe__X[SIZE_H*SIZE_W][L1CACHESIZE];
-[numthreads(L1CACHESIZE, 1, 1)]
-void Conv2DTrans_L1Cached64_Reg_Loop_safe_(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- DISPATCH_ARGS(K.kernelCount, O.width, O.height);
- TENSOR_SHARED2_ARGS4(X, K, B, WBK, O);
-
- #define X_ Conv2DTrans_L1Cached64_Reg_Loop_safe__X
-
- uint k = L1CACHESIZE * groupID.x + groupThreadID.x;
- uint x = groupID.y;
- uint y = groupID.z;
-
- // need all threads to load channels, thus will do late check against kernel count
- if (x*SIZE_W >= O.width) return;
- if (y*SIZE_H >= O.height) return;
-
- uint2 strideMask = _Stride.xy - 1;
- uint2 pad = _Pad.xy / _Stride.xy;
-
- for (uint n = 0; n < O.batch; ++n)
- {
- float acc[SIZE_H*SIZE_W];
- [unroll]
- for (uint q = 0; q < SIZE_H*SIZE_W; ++q)
- acc[q] = B.SafeGet(k);
-
- for (uint dy = 0; dy < K.GetKernelHeight(); dy += _Stride.y)
- {
- for (uint dx = 0; dx < K.GetKernelWidth(); dx += _Stride.x)
- {
- uint2 kernelPos[SIZE_H*SIZE_W];
- uint2 pos[SIZE_H*SIZE_W];
-
- [unroll]
- for (uint q = 0; q < SIZE_H*SIZE_W; ++q)
- {
- uint2 xy = uint2(x*SIZE_W+(q%SIZE_W), y*SIZE_H+(q/SIZE_W));
- kernelPos[q] = uint2(dx, dy) + (xy & strideMask);
- pos[q] = (xy + kernelPos[q]) / _Stride.xy;
-
- // transpose
- kernelPos[q] = uint2(K.GetKernelWidth(), K.GetKernelHeight()) - 1 - kernelPos[q];
- }
-
- for (uint c = 0; c < X.channels; c += L1CACHESIZE)
- {
- // Cache X
- uint dc = groupThreadID.x;
- [unroll]
- for (q = 0; q < SIZE_H*SIZE_W; ++q)
- X_[q][dc] = X.SafeGet(n, pos[q], c + dc, pad.xy);
- GroupMemoryBarrierWithGroupSync();
-
- // X * K
- if (k < K.channels) // need all threads to load channels, thus late check against kernel count
- {
- for (dc = 0; dc < L1CACHESIZE; ++dc)
- {
- [unroll]
- for (q = 0; q < SIZE_H*SIZE_W; ++q)
- acc[q] = fastfma(X_[q][dc], K.Get(kernelPos[q].y, kernelPos[q].x, c + dc, k), acc[q]);
- //acc[q] += X_[q][dc] * K.Get(kernelPos[q].y, kernelPos[q].x, c + dc, k);
- }
- }
- GroupMemoryBarrierWithGroupSync();
- }
- }
- }
-
- if (k < K.channels) // need all threads to load channels, thus late check against kernel count
- [unroll]
- for (q = 0; q < SIZE_H*SIZE_W; ++q)
- {
- uint ox = x*SIZE_W+(q%SIZE_W);
- uint oy = y*SIZE_H+(q/SIZE_W);
- if (ox < O.width && oy < O.height)
- O.Set(n, oy, ox, k, acc[q]);
- }
- }
-
- #undef X_
-}
-#undef SIZE_H
-#undef SIZE_W
-
-
-/*
-#undef L1CACHESIZE
-#define L1CACHESIZE 64
-groupshared float Conv2DTrans_L1Cached64_Reg_Loop_safe_X[L1CACHESIZE];
-[numthreads(L1CACHESIZE, 1, 1)]
-void Conv2DTrans_L1Cached64_Reg_Loop_safe(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- DISPATCH_ARGS(K.kernelCount, X.width, X.height);
- TENSOR_SHARED2_ARGS4(X, K, B, WBK, O);
-
- #define X_ Conv2DTrans_L1Cached64_Reg_Loop_safe_X
-
- uint k = L1CACHESIZE * groupID.x + groupThreadID.x;
- uint x = groupID.y;
- uint y = groupID.z;
-
- // need all threads to load channels, thus will do late check against kernel count
- if (x >= X.width) return;
- if (y >= X.height) return;
-
- uint2 pad = _Pad.xy / _Stride.xy;
-
- for (uint n = 0; n < O.batch; ++n)
- {
- for (uint sy = 0; sy < _Stride.y; ++sy)
- {
- for (uint sx = 0; sx < _Stride.x; ++sx)
- {
- float acc = B.SafeGet(k);
-
- for (uint dy = sy; dy < K.GetKernelHeight(); dy += _Stride.y)
- {
- for (uint dx = sx; dx < K.GetKernelWidth(); dx += _Stride.x)
- {
- uint2 pos = uint2(x, y) + uint2(sx + dx, sy + dy) / _Stride.xy;
-
- for (uint c = 0; c < X.channels; c += L1CACHESIZE)
- {
- // Cache X
- uint dc = groupThreadID.x;
- X_[dc] = X.SafeGet(n, pos, c + dc, pad);
- GroupMemoryBarrierWithGroupSync();
-
- // X * K
- if (k < K.channels) // need all threads to load channels, thus late check against kernel count
- {
- for (dc = 0; dc < L1CACHESIZE; ++dc)
- {
- acc = fastfma( X_[dc],
- K.Get( K.GetKernelHeight() - 1 - dy,
- K.GetKernelWidth() - 1 - dx, c + dc, k),
- acc);
- }
- }
- GroupMemoryBarrierWithGroupSync();
- }
- }
- }
-
- uint oy = y * _Stride.y + sy;
- uint ox = x * _Stride.x + sx;
- if (oy < O.height && ox < O.width && k < K.channels)
- O.Set(n, oy, ox, k, acc);
- }
- }
- }
-
- #undef X_
-}
-*/
-#endif
-
-
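The Conv2D_L1Cached* and *_Reg_Loop* kernels deleted above all rely on the scheduling subtlety their inline comments hint at: every thread in the group must take part in the groupshared loads and reach both GroupMemoryBarrierWithGroupSync calls, so threads whose kernel index k lands past K.channels cannot return early; they only skip the multiply-accumulate and the final store. A minimal, hypothetical sketch of that pattern follows (Input, Weights, Output, channelCount and kernelCount are illustrative names, not part of the original file; bounds handling is omitted where the real kernels use SafeGet):

```hlsl
// Minimal sketch of the "late check against kernel count" pattern used by the cached Conv2D kernels above.
#define CACHE 64

StructuredBuffer<float> Input;      // channelCount values (illustrative)
StructuredBuffer<float> Weights;    // channelCount * kernelCount values, kernel index innermost (illustrative)
RWStructuredBuffer<float> Output;   // kernelCount values (illustrative)
uint channelCount;
uint kernelCount;

groupshared float inputCache[CACHE];

[numthreads(CACHE, 1, 1)]
void LateCheckSketch(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
{
    uint k = CACHE * groupID.x + groupThreadID.x; // output index owned by this thread
    float acc = 0;

    for (uint c = 0; c < channelCount; c += CACHE)
    {
        // Every thread loads one cached value; returning early before the barrier would
        // leave the rest of the group waiting, so out-of-range threads must stay.
        inputCache[groupThreadID.x] = Input[c + groupThreadID.x];
        GroupMemoryBarrierWithGroupSync();

        if (k < kernelCount) // the "late" check: only the math is skipped
            for (uint dc = 0; dc < CACHE; ++dc)
                acc += inputCache[dc] * Weights[(c + dc) * kernelCount + k];

        GroupMemoryBarrierWithGroupSync();
    }

    if (k < kernelCount) // the store is skipped as well
        Output[k] = acc;
}
```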
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/Experimental.compute.meta b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/Experimental.compute.meta
deleted file mode 100644
index 49e7b42da5..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/Experimental.compute.meta
+++ /dev/null
@@ -1,9 +0,0 @@
-fileFormatVersion: 2
-guid: 299ca130202014274b506123e830c52d
-timeCreated: 1506672486
-licenseType: Pro
-ComputeShaderImporter:
- currentAPIMask: 196608
- userData:
- assetBundleName:
- assetBundleVariant:
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/FastNV.compute b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/FastNV.compute
deleted file mode 100644
index 00f077e362..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/FastNV.compute
+++ /dev/null
@@ -1,188 +0,0 @@
-//#pragma kernel Dense64
-//#pragma kernel Conv2D_Kernel3x3_64
-
-#include "Tensor.cginc"
-
-TENSOR_DECL(X)
-TENSOR_DECL(W)
-TENSOR_DECL(K)
-TENSOR_DECL(B)
-TENSOR_DECL(WBK)
-TENSOR_DECL_RW(O)
-
-uint4 _Pad;
-uint4 _Stride;
-
-#undef THREAD_COUNT
-#define THREAD_COUNT 64 // ATM support only 8x8
-
-#undef BLOCK_WIDTH
-#define BLOCK_WIDTH 8
-
-#undef LOAD_WIDTH
-#define LOAD_WIDTH THREAD_COUNT
-
-#undef LOAD_DEPTH
-#define LOAD_DEPTH BLOCK_WIDTH
-
-groupshared float DenseTiled_XcacheR[LOAD_DEPTH][LOAD_WIDTH];
-groupshared float DenseTiled_WcacheR[LOAD_DEPTH][LOAD_WIDTH];
-
-[numthreads(THREAD_COUNT, 1, 1)]
-void Dense64(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- // @TODO: DISPATCH_ARGS(...)
- TENSOR_SHARED2_ARGS4(X, W, B, WBK, O);
-
- #define X_ DenseTiled_XcacheR
- #define W_ DenseTiled_WcacheR
-
- uint id = groupThreadID.x;
- uint bx = groupID.x;
- uint by = groupID.y;
-
- uint bbx = id % BLOCK_WIDTH;
- uint bby = id / BLOCK_WIDTH;
-
- float v[BLOCK_WIDTH][BLOCK_WIDTH];
- for (uint yy = 0; yy < BLOCK_WIDTH; ++yy)
- for (uint xx = 0; xx < BLOCK_WIDTH; ++xx)
- {
- float bias = B.Get(bx*LOAD_WIDTH + bbx*BLOCK_WIDTH + xx);
- v[yy][xx] = bias;
- }
-
- for (uint m = 0; m < X.GetFlatWidth()/LOAD_DEPTH; ++m)
- {
- for (uint q = 0; q < LOAD_DEPTH; ++q)
- {
- X_[q][id] = X.Get(by*LOAD_WIDTH + id, m*LOAD_DEPTH + q);
- W_[q][id] = W.Get(m*LOAD_DEPTH + q, bx*LOAD_WIDTH + id);
- }
-
- GroupMemoryBarrierWithGroupSync();
-
- for (uint yyy = 0; yyy < BLOCK_WIDTH; ++yyy)
- [unroll] for (uint xxx = 0; xxx < BLOCK_WIDTH; ++xxx)
- [unroll] for (uint i = 0; i < LOAD_DEPTH; ++i)
- {
- v[yyy][xxx] = mad(X_[i][bby*BLOCK_WIDTH + yyy], W_[i][bbx*BLOCK_WIDTH + xxx], v[yyy][xxx]);
- }
-
- GroupMemoryBarrierWithGroupSync();
- }
-
- for (uint yyy = 0; yyy < BLOCK_WIDTH; ++yyy)
- for (uint xxx = 0; xxx < BLOCK_WIDTH; ++xxx)
- O.Set(by*LOAD_WIDTH + bby*BLOCK_WIDTH + yyy, bx*LOAD_WIDTH + bbx*BLOCK_WIDTH + xxx, v[yyy][xxx]);
-
- #undef X_
- #undef W_
-}
-
-
-#undef THREAD_COUNT
-#define THREAD_COUNT 64 // ATM support only 8x8
-
-#undef BLOCK_WIDTH
-#define BLOCK_WIDTH 8
-
-#undef LOAD_WIDTH
-#define LOAD_WIDTH THREAD_COUNT
-
-#undef LOAD_DEPTH
-#define LOAD_DEPTH BLOCK_WIDTH
-
-groupshared float Conv_KcacheR[LOAD_DEPTH][LOAD_WIDTH];
-groupshared float Conv_XcacheR[LOAD_DEPTH][LOAD_WIDTH];
-[numthreads(THREAD_COUNT, 1, 1)]
-void Conv2D_Kernel3x3_64(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
-{
- // @TODO: DISPATCH_ARGS(...)
- TENSOR_SHARED2_ARGS4(X, K, B, WBK, O);
-
- #define X_ Conv_XcacheR
- #define K_ Conv_KcacheR
-
- uint id = groupThreadID.x;
- uint bx = groupID.x;
- uint by = groupID.y;
-
- uint bbx = id % BLOCK_WIDTH;
- uint bby = id / BLOCK_WIDTH;
-
- uint width = O.width;
- uint height = O.height;
-
- // ASSERT(LOAD_WIDTH == THREAD_COUNT)
- uint loadNYX = by*LOAD_WIDTH + id; // only works for 8x8
- uint loadX = loadNYX % width;
- uint loadNY = loadNYX / width;
- uint loadY = loadNY % height;
- uint loadN = loadNY / height;
-
- // @TODO: validate that _Stride works, added the following 2 lines without testing
- loadX *= _Stride.x;
- loadY *= _Stride.y;
-
- float v[BLOCK_WIDTH][BLOCK_WIDTH];
- [unroll] for (uint yy = 0; yy < BLOCK_WIDTH; ++yy)
- [unroll] for (uint xx = 0; xx < BLOCK_WIDTH; ++xx)
- {
- float bias = B.Get(bx*LOAD_WIDTH + bbx*BLOCK_WIDTH + xx);
- v[yy][xx] = bias;
- }
-
- for (uint dy = 0; dy < 3; ++dy)
- {
- bool mask = true;
-
- if (loadY+dy < _Pad.y) mask = false;
- if (loadY+dy - _Pad.w >= X.height) mask = false;
-
- for (uint dx = 0; dx < 3; ++dx)
- {
- if (loadX+dx < _Pad.x) mask = false;
- if (loadX+dx - _Pad.z >= X.width) mask = false;
-
- for (uint m = 0; m < X.channels/LOAD_DEPTH; ++m)
- {
- for (uint q = 0; q < LOAD_DEPTH; ++q)
- {
- if (mask)
- X_[q][id] = X.Get(loadN, loadY+dy-_Pad.y, loadX+dx-_Pad.x, m*LOAD_DEPTH + q);
- else
- X_[q][id] = 0;
- K_[q][id] = K.Get(dy, dx, m*LOAD_DEPTH + q, bx*LOAD_WIDTH + id);
- }
-
- GroupMemoryBarrierWithGroupSync();
-
- for (uint yyy = 0; yyy < BLOCK_WIDTH; ++yyy)
- [unroll] for (uint xxx = 0; xxx < BLOCK_WIDTH; ++xxx)
- [unroll] for (uint i = 0; i < LOAD_DEPTH; ++i)
- {
- v[yyy][xxx] += X_[i][bby*BLOCK_WIDTH + yyy] * K_[i][bbx*BLOCK_WIDTH + xxx];
- }
-
- GroupMemoryBarrierWithGroupSync();
- }
- }
- }
-
- [unroll] for (uint yyy = 0; yyy < BLOCK_WIDTH; ++yyy)
- [unroll] for (uint xxx = 0; xxx < BLOCK_WIDTH; ++xxx)
- {
- uint saveNYX = by*LOAD_WIDTH + bby*BLOCK_WIDTH + yyy;
- uint saveX = saveNYX % width;
- uint saveNY = saveNYX / width;
- uint saveY = saveNY % height;
- uint saveN = saveNY / height;
-
- uint saveK = bx*LOAD_WIDTH + bbx*BLOCK_WIDTH + xxx;
- O.Set(saveN, saveY, saveX, saveK, v[yyy][xxx]);
- }
-
- #undef X_
- #undef K_
-}
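For reference, the Dense64 and Conv2D_Kernel3x3_64 kernels deleted above both treat a 64-thread group as an 8x8 grid of threads, with each thread accumulating an 8x8 block of outputs in registers while input and weight tiles are staged through groupshared memory; one group therefore covers a 64x64 output tile. A small sketch of that thread-to-output mapping, with illustrative names only:

```hlsl
// Illustrative sketch of the thread-to-output mapping in the 8x8 register-blocked kernels above.
#define BLOCK_WIDTH 8                            // 8x8 threads per group, 8x8 outputs per thread
#define LOAD_WIDTH (BLOCK_WIDTH * BLOCK_WIDTH)   // 64 outputs per group along each axis

// Returns the top-left output element owned by a given thread of a given group.
uint2 OwnedBlockOrigin(uint threadId, uint2 groupId)
{
    uint bbx = threadId % BLOCK_WIDTH;           // column of this thread inside the 8x8 thread grid
    uint bby = threadId / BLOCK_WIDTH;           // row of this thread inside the 8x8 thread grid
    return groupId * LOAD_WIDTH + uint2(bbx, bby) * BLOCK_WIDTH;
    // the thread then accumulates outputs from this origin to origin + BLOCK_WIDTH - 1 on both axes
}
```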
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/FastNV.compute.meta b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/FastNV.compute.meta
deleted file mode 100644
index 91a842521a..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/FastNV.compute.meta
+++ /dev/null
@@ -1,9 +0,0 @@
-fileFormatVersion: 2
-guid: c7c673db45e6845d5abaed4ed5ef42e1
-timeCreated: 1507294253
-licenseType: Pro
-ComputeShaderImporter:
- currentAPIMask: 196608
- userData:
- assetBundleName:
- assetBundleVariant:
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/Generic.compute b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/Generic.compute
deleted file mode 100644
index fc3dc82793..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/Generic.compute
+++ /dev/null
@@ -1,463 +0,0 @@
-#pragma kernel ScaleBias
-#pragma kernel ScaleBias_CNyx
-#pragma kernel ScaleBias_CNyx2
-#pragma kernel ScaleBias_Flat
-#pragma kernel ScaleBias_Loop
-#pragma kernel Upsample2D
-#pragma kernel AvgPool2D
-#pragma kernel MaxPool2D
-#pragma kernel AvgPool2D_NoPads
-#pragma kernel MaxPool2D_NoPads
-//#pragma kernel MaxPool2D_Pool2x2_NoPads
-#pragma kernel GlobalAvgPool2D
-#pragma kernel InstanceNorm
-#pragma kernel InstanceNormTail_CNyx2
-#pragma kernel InstanceNormTail_Flat
-#pragma kernel Copy
-
-/*
-ScaleBias_Flat+ScaleBias_CNyx2 (NEW) vs ScaleBias+ScaleBias_CNyx
-Compute Precompiled
-
-MOBILENET@4
-<<= O.channels) return;
- if (x >= O.width) return;
- if (y >= O.height) return;
-
- float bias = B.Get(0, 0, 0, c);
- float scale = W.Get(0, 0, 0, c);
-
- for (uint n = 0; n < X.batch; ++n)
- {
- float v = X.Get(n, y, x, c);
- v = v * scale + bias;
- O.Set(n, y, x, c, v);
- }
-}
-
-NUMTHREADS((16,16,1), (16,8,1), (16,4,1))
-void ScaleBias_CNyx(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, O.batch * O.height * O.width, 1);
- TENSOR_SHARED2_ARGS4(X, W, B, WBK, O);
-
- uint c = dispatchThreadID.x;
- uint nyx = dispatchThreadID.y;
-
- uint x = nyx % X.width;
- uint ny = nyx / X.width;
- uint y = ny % X.height;
- uint n = ny / X.height;
-
- if (c >= X.channels) return;
- if (n >= X.batch) return;
-
- float bias = B.Get(0, 0, 0, c);
- float scale = W.Get(0, 0, 0, c);
-
- float v = X.Get(n, y, x, c);
- v = v * scale + bias;
- O.Set(n, y, x, c, v);
-}
-
-NUMTHREADS((256,1,1), (128,1,1), (64,1,1))
-void ScaleBias_Flat(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.length, 1, 1);
- TENSOR_SHARED2_ARGS4(X, W, B, WBK, O);
-
- uint i = dispatchThreadID.x;
- if (i > O.GetLength()) return;
-
- uint c = i % X.channels;
- float bias = B.Get(c);
- float scale = W.Get(c);
-
- float v = X.Get(i);
- v = v * scale + bias;
- O.Set(i, v);
-}
-
-NUMTHREADS((256,1,1), (128,1,1), (64,1,1))
-void ScaleBias_Loop(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.length, 1, 1);
- TENSOR_SHARED2_ARGS4(X, W, B, WBK, O);
-
- uint i = dispatchThreadID.x;
- uint len = O.GetLength();
-
- while (i < len)
- {
- uint c = i % X.channels;
- float bias = B.Get(c);
- float scale = W.Get(c);
-
- float v = X.Get(i);
- v = v * scale + bias;
- O.Set(i, v);
-
- i += _LoopStride;
- }
-}
-
-NUMTHREADS((32,4,1), (32,2,1), (16,2,1))
-void ScaleBias_CNyx2(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, O.batch * O.height * O.width, 1);
- TENSOR_SHARED2_ARGS4(X, W, B, WBK, O);
-
- uint c = dispatchThreadID.x;
- uint i = dispatchThreadID.y * X.channels + c;
-
- if (c >= X.channels) return;
- if (i >= X.GetLength()) return;
-
- float bias = B.Get(c);
- float scale = W.Get(c);
-
- float v = X.Get(i);
- v = v * scale + bias;
- O.Set(i, v);
-}
-
-NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
-void Upsample2D(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- // NOTE: dispatched over X (not O)
- DISPATCH_ARGS(X.channels, X.width, X.height);
- TENSOR_ARGS2(X, O);
-
- uint c = dispatchThreadID.x;
- uint x = dispatchThreadID.y;
- uint y = dispatchThreadID.z;
-
- if (c >= X.channels) return;
- if (x >= X.width) return;
- if (y >= X.height) return;
-
- for (uint n = 0; n < O.batch; ++n)
- {
- float v = X.Get(n, y, x, c);
-
- for (uint dy = 0; dy < _Pool.y; ++dy)
- for (uint dx = 0; dx < _Pool.x; ++dx)
- {
- uint oy = y * _Pool.y + dy;
- uint ox = x * _Pool.x + dx;
- O.Set(n, oy, ox, c, v);
- }
- }
-}
-
-NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
-void MaxPool2D(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, O.width, O.height);
- TENSOR_ARGS2(X, O);
-
- uint c = dispatchThreadID.x;
- uint x = dispatchThreadID.y;
- uint y = dispatchThreadID.z;
-
- if (c >= O.channels) return;
- if (x >= O.width) return;
- if (y >= O.height) return;
-
- for (uint n = 0; n < X.batch; ++n)
- {
- float maxV = -FLT_MAX;
- for (uint dy = 0; dy < _Pool.y; ++dy)
- for (uint dx = 0; dx < _Pool.x; ++dx)
- {
- uint oy = y * _Stride.y + dy;
- uint ox = x * _Stride.x + dx;
-
- bool mask = (oy >= _Pad.y) && (ox >= _Pad.x) && (oy - _Pad.y < X.height) && (ox - _Pad.x < X.width);
- float v = (mask)? X.Get(n, oy - _Pad.y, ox - _Pad.x, c): 0;
-
- maxV = max(v, maxV);
- }
-
- O.Set(n, y, x, c, maxV);
- }
-}
-
-NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
-void AvgPool2D(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, O.width, O.height);
- TENSOR_ARGS2(X, O);
-
- uint c = dispatchThreadID.x;
- uint x = dispatchThreadID.y;
- uint y = dispatchThreadID.z;
-
- if (c >= O.channels) return;
- if (x >= O.width) return;
- if (y >= O.height) return;
-
- for (uint n = 0; n < X.batch; ++n)
- {
- float acc = 0;
- float counter = 0;
- for (uint dy = 0; dy < _Pool.y; ++dy)
- for (uint dx = 0; dx < _Pool.x; ++dx)
- {
- uint oy = y * _Stride.y + dy;
- uint ox = x * _Stride.x + dx;
-
- bool mask = (oy >= _Pad.y) && (ox >= _Pad.x) && (oy - _Pad.y < X.height) && (ox - _Pad.x < X.width);
- acc += (mask)? X.Get(n, oy - _Pad.y, ox - _Pad.x, c): 0;
- counter += (mask)? 1: 0;
- }
-
- acc /= counter;
- O.Set(n, y, x, c, acc);
- }
-}
-
-NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
-void MaxPool2D_NoPads(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, O.width, O.height);
- TENSOR_ARGS2(X, O);
-
- uint c = dispatchThreadID.x;
- uint x = dispatchThreadID.y;
- uint y = dispatchThreadID.z;
-
- if (c >= O.channels) return;
- if (x >= O.width) return;
- if (y >= O.height) return;
-
- for (uint n = 0; n < X.batch; ++n)
- {
- float maxV = -FLT_MAX;
- for (uint dy = 0; dy < _Pool[1]; ++dy)
- for (uint dx = 0; dx < _Pool[0]; ++dx)
- {
- float v = X.Get(n, y * _Stride[1] + dy, x * _Stride[0] + dx, c);
- maxV = max(v, maxV);
- }
-
- O.Set(n, y, x, c, maxV);
- }
-}
-
-NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
-void AvgPool2D_NoPads(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, O.width, O.height);
- TENSOR_ARGS2(X, O);
-
- uint c = dispatchThreadID.x;
- uint x = dispatchThreadID.y;
- uint y = dispatchThreadID.z;
-
- if (c >= O.channels) return;
- if (x >= O.width) return;
- if (y >= O.height) return;
-
- float invPoolSize = 1.0f / (_Pool[0] * _Pool[1]);
- for (uint n = 0; n < X.batch; ++n)
- {
- float v = 0;
- for (uint dy = 0; dy < _Pool[1]; ++dy)
- for (uint dx = 0; dx < _Pool[0]; ++dx)
- v += X.Get(n, y * _Stride[1] + dy, x * _Stride[0] + dx, c) * invPoolSize;
-
- O.Set(n, y, x, c, v);
- }
-}
-
-NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
-//NUMTHREADS((16,4,4), (16,4,2), (16,2,2))
-void MaxPool2D_Pool2x2_NoPads(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, O.width, O.height);
- TENSOR_ARGS2(X, O);
-
- uint c = dispatchThreadID.x;
- uint x = dispatchThreadID.y;
- uint y = dispatchThreadID.z;
-
- if (c >= O.channels) return;
- if (x >= O.width) return;
- if (y >= O.height) return;
-
- for (uint n = 0; n < X.batch; ++n)
- {
- float v0 = X.Get(n, y*2, x*2, c);
- float v1 = X.Get(n, y*2+1, x*2, c);
- float v2 = X.Get(n, y*2, x*2+1, c);
- float v3 = X.Get(n, y*2+1, x*2+1, c);
- float v = max(v0, max(v1, max(v2, v3)));
-
- O.Set(n, y, x, c, v);
- }
-}
-
-[numthreads(32,1,1)]
-void GlobalAvgPool2D(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, 1, 1);
- TENSOR_ARGS2(X, O);
-
- uint c = dispatchThreadID.x;
- if (c >= O.channels) return;
- //ASSERT(X.batch == O.batch)
-
- for (uint n = 0; n < X.batch; ++n)
- {
- float v = 0;
- for (uint y = 0; y < X.height; ++y)
- for (uint x = 0; x < X.width; ++x)
- v += X.Get(n, y, x, c);
-
- v /= (X.height * X.width);
- O.Set(n, 0, 0, c, v);
- }
-}
-
-[numthreads(64,1,1)]
-void InstanceNorm(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, 1, 1);
- TENSOR_SHARED2_ARGS4(X, W, B, WBK, O);
-
- uint c = dispatchThreadID.x;
- if (c >= O.channels) return;
- //ASSERT(X.shape == O.shape)
-
- float gamma = W.Get(0, 0, 0, c);
- float beta = B.Get(0, 0, 0, c);
-
- for (uint n = 0; n < O.batch; ++n)
- {
- uint x, y;
- // calc mean
- float acc = 0;
- for (y = 0; y < O.height; ++y)
- for (x = 0; x < O.width; ++x)
- acc += X.Get(n, y, x, c);
- float mean = acc / (O.width * O.height);
-
- // calc variance
- acc = 0;
- for (y = 0; y < O.height; ++y)
- for (x = 0; x < O.width; ++x)
- {
- float delta = X.Get(n, y, x, c) - mean;
- acc += delta * delta;
- }
- float var = acc / (O.width * O.height);
-
- // normalization factor
- float invNormFactor = 1 / sqrt(var + FLT_EPSILON);
-
- float scale = gamma * invNormFactor;
- float bias = beta - gamma * mean * invNormFactor;
-
- // apply normalization
- for (y = 0; y < O.height; ++y)
- for (x = 0; x < O.width; ++x)
- {
- float v = X.Get(n, y, x, c);
- v = v * scale + bias;
- O.Set(n, y, x, c, v);
- }
- }
-}
-
-NUMTHREADS((256,1,1), (128,1,1), (64,1,1))
-void InstanceNormTail_Flat(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.length, 1, 1);
- TENSOR_ARGS4(X, W, B, O);
-
- uint i = dispatchThreadID.x;
- if (i > O.GetLength()) return;
-
- uint c = i % X.channels;
-
- float variance = W.Get(c);
- float mean = B.Get(c);
- // normalization factor
- float invNormFactor = 1 / sqrt(variance + FLT_EPSILON);
-
- float v = X.Get(i);
- //v = gamma * (v * invNormFactor - mean * invNormFactor) + beta
- v = v * invNormFactor - mean * invNormFactor;
-
- O.Set(i, v);
-}
-
-NUMTHREADS((32,4,1), (32,2,1), (16,2,1))
-void InstanceNormTail_CNyx2(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- DISPATCH_ARGS(O.channels, O.batch * O.height * O.width, 1);
- TENSOR_ARGS4(X, W, B, O);
-
- uint c = dispatchThreadID.x;
- uint i = dispatchThreadID.y * X.channels + c;
-
- if (c >= X.channels) return;
- if (i >= X.GetLength()) return;
-
- float variance = W.Get(c);
- float mean = B.Get(c);
- // normalization factor
- float invNormFactor = 1 / sqrt(variance + FLT_EPSILON);
-
- float v = X.Get(i);
- //v = gamma * (v * invNormFactor - mean * invNormFactor) + beta
- v = v * invNormFactor - mean * invNormFactor;
-
- O.Set(i, v);
-}
-
-NUMTHREADS((4,8,8), (4,8,4), (4,4,4))
-void Copy(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- // NOTE: dispatched over X (not O)
- DISPATCH_ARGS(X.channels, X.width, X.height);
- TENSOR_ARGS2(X, O);
-
- uint c = dispatchThreadID.x; uint x = dispatchThreadID.y; uint y = dispatchThreadID.z;
- if (c >= X.channels) return; if (x >= X.width) return; if (y >= X.height) return;
-
- for (uint n = 0; n < X.batch; ++n)
- {
- float v = X.Get(n, y, x, c);
- O.Set(n + _Pad[0], y + _Pad[1], x + _Pad[2], c + _Pad[3], v);
- }
-}
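One detail worth calling out in the InstanceNorm kernel deleted above: after computing the per-channel mean and variance, it folds the usual normalization y = gamma * (x - mean) / sqrt(var + eps) + beta into a single multiply-add per element by precomputing a scale and bias. A minimal sketch of that algebra (the helper name is hypothetical):

```hlsl
// Scale/bias folding used by the InstanceNorm kernel above:
//   gamma * (x - mean) * invStd + beta  ==  x * scale + bias
// with scale = gamma * invStd and bias = beta - gamma * mean * invStd.
float InstanceNormFolded(float x, float mean, float variance, float gamma, float beta, float epsilon)
{
    float invStd = rsqrt(variance + epsilon);   // 1 / sqrt(var + eps)
    float scale  = gamma * invStd;
    float bias   = beta - gamma * mean * invStd;
    return x * scale + bias;                    // one multiply-add per element in the inner loop
}
```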
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/Generic.compute.meta b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/Generic.compute.meta
deleted file mode 100644
index 47cf35156c..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/Generic.compute.meta
+++ /dev/null
@@ -1,9 +0,0 @@
-fileFormatVersion: 2
-guid: 62f5efacd43b24dd38ead3ce0d80cc34
-timeCreated: 1495527718
-licenseType: Pro
-ComputeShaderImporter:
- currentAPIMask: 196608
- userData:
- assetBundleName:
- assetBundleVariant:
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/Random.cginc b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/Random.cginc
deleted file mode 100644
index 2263f68713..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/Random.cginc
+++ /dev/null
@@ -1,70 +0,0 @@
-
-// Based on: https://stackoverflow.com/questions/5149544/can-i-generate-a-random-number-inside-a-pixel-shader
-// Output: Random number: [0,1), that is between 0.0 and 0.999999... inclusive.
-// Author: Michael Pohoreski
-// Copyright: Copyleft 2012 :-)
-float RandomUsingCos(float4 seed)
-{
- float4 K1 = float4( // Transcendental numbers:
- 0.64341054629, // (Cahen's constant)
- 23.14069263277926, // e^pi (Gelfond's constant)
- 2.665144142690225, // 2^sqrt(2) (Gelfond-Schneider constant)
- 3.14159265359 // pi
- );
- return frac(cos(dot(seed, K1)) * 12345.6789);
-}
-
-// Based on: https://stackoverflow.com/questions/4200224/random-noise-functions-for-glsl
-// Author: Spatial
-// 05 July 2013
-
-// A single iteration of Bob Jenkins' One-At-A-Time hashing algorithm.
-uint hash(uint x)
-{
- x += ( x << 10u );
- x ^= ( x >> 6u );
- x += ( x << 3u );
- x ^= ( x >> 11u );
- x += ( x << 15u );
- return x;
-}
-uint hash( uint2 v ) { return hash( v.x ^ hash(v.y) ); }
-uint hash( uint3 v ) { return hash( v.x ^ hash(v.y) ^ hash(v.z) ); }
-uint hash( uint4 v ) { return hash( v.x ^ hash(v.y) ^ hash(v.z) ^ hash(v.w) ); }
-
-// Construct a float with half-open range [0:1] using low 23 bits.
-// All zeroes yields 0.0, all ones yields the next smallest representable value below 1.0.
-float floatConstruct(uint m)
-{
- const uint ieeeMantissa = 0x007FFFFFu; // binary32 mantissa bitmask
- const uint ieeeOne = 0x3F800000u; // 1.0 in IEEE binary32
-
- m &= ieeeMantissa; // Keep only mantissa bits (fractional part)
- m |= ieeeOne; // Add fractional part to 1.0
-
- float f = asfloat(m); // Range [1:2]
- return f - 1.0; // Range [0:1]
-}
-
-// Pseudo-random value in half-open range [0:1].
-float RandomUsingHash(float4 seed)
-{
- return floatConstruct(hash(asuint(seed)));
-}
-
-
-// More alternatives:
-// https://github.com/ashima/webgl-noise
-// https://www.shadertoy.com/view/4djSRW
-
-// ------------------------------------------------------------------------------------------
-
-float Random(float4 seed)
-{
- return RandomUsingCos(seed);
-}
-
-float Bernoulli(float4 seed, float p)
-{
- return Random(seed) <= p ? 1: 0;
-}
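The helpers in the file above are stateless: all variation comes from the seed, so a kernel typically builds the seed from per-element coordinates plus a per-dispatch value. A purely hypothetical usage sketch, assuming the (deleted) include above and a caller-supplied _Seed constant; none of these names come from the original file:

```hlsl
#include "Random.cginc"

RWStructuredBuffer<float> Noise;   // illustrative output buffer
float _Seed;                       // uploaded by the caller, changed every dispatch (assumption)
uint _Width;

[numthreads(8, 8, 1)]
void FillNoise(uint3 id : SV_DispatchThreadID)
{
    // Seed from coordinates plus a per-dispatch offset so values differ per element and per call.
    float4 seed = float4(id.x, id.y, id.z, _Seed);
    Noise[id.y * _Width + id.x] = Random(seed); // uniform in [0,1)
}
```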
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/Random.cginc.meta b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/Random.cginc.meta
deleted file mode 100644
index 572d47b4cd..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/Random.cginc.meta
+++ /dev/null
@@ -1,10 +0,0 @@
-fileFormatVersion: 2
-guid: 5a17e0b3943a74564a02a8ed0a41228b
-timeCreated: 1520855309
-licenseType: Pro
-ShaderImporter:
- externalObjects: {}
- defaultTextures: []
- userData:
- assetBundleName:
- assetBundleVariant:
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/Tensor.cginc b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/Tensor.cginc
deleted file mode 100644
index a829458a99..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/Tensor.cginc
+++ /dev/null
@@ -1,327 +0,0 @@
-#define BARRACUDA_MAX_THREAD_COUNT 64
-#if (BARRACUDA_MAX_THREAD_COUNT>=256)
-#define NUMTHREADS(t256,t128,t64) [numthreads t256]
-#define NUMTHREAD(t256, t128, t64) t256
-#elif (BARRACUDA_MAX_THREAD_COUNT>=128)
-#define NUMTHREADS(t256,t128,t64) [numthreads t128]
-#define NUMTHREAD(t256,t128,t64) t128
-#elif (BARRACUDA_MAX_THREAD_COUNT>=64)
-#define NUMTHREADS(t256,t128,t64) [numthreads t64]
-#define NUMTHREAD(t256,t128,t64) t64
-#endif
-
-struct Tensor
-{
- // @TODO: actually uint seems not like a good idea anymore, consider going to int
- uint batch, height, width, channels;
-
- void Init(uint4 nhwc)
- {
- batch = nhwc.x;
- height = nhwc.y;
- width = nhwc.z;
- channels = nhwc.w;
- }
-
- uint4 Dims()
- {
- return uint4(batch, height, width, channels);
- }
- uint GetFlatHeight()
- {
- return batch;
- }
- uint GetFlatWidth()
- {
- return height * width * channels;
- }
- uint GetKernelHeight()
- {
- // kernels storage: {kernel_width * kernel_height * kernel_channels * kernel_count}
- uint kernelHeight = batch;
- return kernelHeight;
- }
- uint GetKernelWidth()
- {
- // kernels storage: {kernel_width * kernel_height * kernel_channels * kernel_count}
- uint kernelWidth = height;
- return kernelWidth;
- }
- uint GetKernelDepth()
- {
- // kernels storage: {kernel_width * kernel_height * kernel_channels * kernel_count}
- uint kernelDepth = width;
- return kernelDepth;
- }
- uint GetKernelCount()
- {
- // kernels storage: {kernel_width * kernel_height * kernel_channels * kernel_count}
- uint kernelCount = channels;
- return kernelCount;
- }
- uint GetLength()
- {
- return batch * height * width * channels;
- }
-
- uint Index(uint b, uint h, uint w, uint ch)
- {
- uint index =
- b * height * width * channels +
- h * width * channels +
- w * channels +
- ch;
- return index;
- }
-
- uint Index(uint b, uint i)
- {
- uint index =
- b * height * width * channels +
- i;
- return index;
- }
-};
-
-struct ReadonlyTensor : Tensor
-{
- StructuredBuffer data;
-
- void Init(uint4 nhwc, StructuredBuffer data_)
- {
- Tensor::Init(nhwc);
- data = data_;
- }
-
- float Get(uint b, uint h, uint w, uint ch)
- {
- return data[Index(b,h,w,ch)];
- }
- float Get(uint b, uint2 pos, uint ch)
- {
- return data[Index(b, pos.y, pos.x, ch)];
- }
- float Get(uint b, uint i)
- {
- return data[Index(b,i)];
- }
- float Get(uint i)
- {
- return data[i];
- }
-
- float BroadcastGet(uint b, uint h, uint w, uint ch)
- {
- return Get(b % batch, h % height, w % width, ch % channels);
- }
- float BroadcastGet(uint b, uint2 pos, uint ch)
- {
- return BroadcastGet(b, pos.y, pos.x, ch);
- }
- float BroadcastGet(uint b, uint i)
- {
- return Get(b % GetFlatHeight(), i % GetFlatWidth());
- }
-
- float SafeGet(uint b, uint2 pos, uint ch, uint2 pad)
- {
- if (b >= batch || ch >= channels) return 0;
-
- if (any(pos < pad)) return 0;
- if (any(pos >= uint2(width, height) + pad)) return 0;
- pos -= pad;
-
- return data[Index(b, pos.y, pos.x, ch)];
- }
- float SafeGet(uint b, uint h, uint w, uint ch, uint2 pad)
- {
- return SafeGet(b, uint2(w, h), ch, pad);
- }
- float SafeGet(uint b, uint i)
- {
- if (b >= batch || i >= height * width * channels) return 0;
- return Get(b,i);
- }
- float SafeGet(uint i)
- {
- if (i >= batch * height * width * channels) return 0;
- return Get(i);
- }
-};
-
-struct ReadWriteTensor : Tensor
-{
- RWStructuredBuffer data;
-
- void Init(int4 nhwc, RWStructuredBuffer data_)
- {
- Tensor::Init(nhwc);
- data = data_;
- }
-
- float Get(uint b, uint h, uint w, uint ch)
- {
- return data[Index(b,h,w,ch)];
- }
- float Get(uint b, uint2 pos, uint ch)
- {
- return data[Index(b, pos.y, pos.x, ch)];
- }
- float Get(uint b, uint i)
- {
- return data[Index(b,i)];
- }
- float Get(uint i)
- {
- return data[i];
- }
-
- float BroadcastGet(uint b, uint h, uint w, uint ch)
- {
- return Get(b % batch, h % height, w % width, ch % channels);
- }
- float BroadcastGet(uint b, uint2 pos, uint ch)
- {
- return BroadcastGet(b, pos.y, pos.x, ch);
- }
- float BroadcastGet(uint b, uint i)
- {
- return Get(b % GetFlatHeight(), i % GetFlatWidth());
- }
-
- float SafeGet(uint b, uint2 pos, uint ch, uint2 pad)
- {
- if (b >= batch || ch >= channels) return 0;
-
- if (any(pos < pad)) return 0;
- if (any(pos >= uint2(width, height) + pad)) return 0;
- pos -= pad;
-
- return Get(b, pos.y, pos.x, ch);
- }
- float SafeGet(uint b, uint h, uint w, uint ch, uint2 pad)
- {
- return SafeGet(b, uint2(w, h), ch, pad);
- }
- float SafeGet(uint b, uint i)
- {
- if (b >= batch || i >= height * width * channels) return 0;
- return Get(b,i);
- }
- float SafeGet(uint i)
- {
- if (i >= batch * height * width * channels) return 0;
- return Get(i);
- }
-
-
- void Set(uint b, uint h, uint w, uint ch, float v)
- {
- data[Index(b,h,w,ch)] = v;
- }
- void Set(uint y, uint x, float v)
- {
- data[Index(y,x)] = v;
- }
- void Set(uint i, float v)
- {
- data[i] = v;
- }
-};
-
-struct SharedTensor : Tensor
-{
- StructuredBuffer data;
- uint offset;
-
- void Init(uint4 nhwc, uint4 info, StructuredBuffer data_)
- {
- Tensor::Init(nhwc);
- data = data_;
- offset = info.x;
- }
-
- float Get(uint b, uint h, uint w, uint ch)
- {
- return data[Index(b,h,w,ch) + offset];
- }
- float Get(uint b, uint2 pos, uint ch)
- {
- return Get(b, pos.y, pos.x, ch);
- }
- float Get(uint b, uint i)
- {
- return data[Index(b,i) + offset];
- }
- float Get(uint i)
- {
- return data[i + offset];
- }
-
- float BroadcastGet(uint b, uint h, uint w, uint ch)
- {
- return Get(b % batch, h % height, w % width, ch % channels);
- }
- float BroadcastGet(uint b, uint2 pos, uint ch)
- {
- return BroadcastGet(b, pos.y, pos.x, ch);
- }
- float BroadcastGet(uint b, uint i)
- {
- return Get(b % GetFlatHeight(), i % GetFlatWidth());
- }
-
- float SafeGet(uint b, uint2 pos, uint ch, uint2 pad)
- {
- if (b >= batch || ch >= channels) return 0;
-
- if (any(pos < pad)) return 0;
- if (any(pos >= uint2(width, height) + pad)) return 0;
- pos -= pad;
-
- return Get(b, pos, ch);
- }
- float SafeGet(uint b, uint h, uint w, uint ch, uint2 pad)
- {
- return SafeGet(b, uint2(w, h), ch, pad);
- }
- float SafeGet(uint b, uint i)
- {
- if (b >= batch || i >= height * width * channels) return 0;
- return Get(b,i);
- }
- float SafeGet(uint i)
- {
- if (i >= batch * height * width * channels) return 0;
- return Get(i);
- }
-};
-
-#define TENSOR_DECL(X) uint4 X##decl[2]; StructuredBuffer X##data;
-#define TENSOR_DECL_RW(X) uint4 X ## decl[2]; RWStructuredBuffer X ## data;
-
-#define TENSOR_ARG(X) ReadonlyTensor X; X##.Init(X##decl[0], X##data); // readonly
-#define TENSOR_MODEL(X) SharedTensor X; X##.Init(X##decl[0], X##decl[1], X##data); // RO w offset
-#define TENSOR_ARG_RW(X) ReadWriteTensor X; X##.Init(X##decl[0], X##data);
-
-#define TENSOR_ARGS2(X, O) TENSOR_ARG(X); TENSOR_ARG_RW(O);
-#define TENSOR_ARGS3(X, A, O) TENSOR_ARG(X); TENSOR_MODEL(A); TENSOR_ARG_RW(O);
-#define TENSOR_ARGS4(X, A, B, O) TENSOR_ARG(X); TENSOR_MODEL(A); TENSOR_MODEL(B); TENSOR_ARG_RW(O);
-
-// shared model tensors
-#define TENSOR_SHARED_MODEL(X, S) SharedTensor X; X##.Init(X##decl[0], X##decl[1], S##data);
-#define TENSOR_SHARED2_ARGS4(X, A, B, S, O) TENSOR_ARG(X); TENSOR_SHARED_MODEL(A, S); TENSOR_SHARED_MODEL(B, S); TENSOR_ARG_RW(O);
-
-
-// purely informational - declares contract between caller of Dispatch() and kernel
-#define DISPATCH_ARGS(threadGroupsX, threadGroupsY, threadGroupsZ)
-
-
-// @TODO: move into more appropriate file
-#define FLT_MAX 3.402823466e+38F
-#define FLT_EPSILON 1e-6
-
-float fastfma(float a, float b, float c)
-{
- return dot(float2(a,c), float2(b, 1));
-}
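The fastfma helper at the end of the file above expresses a * b + c as a two-component dot product: dot(float2(a, c), float2(b, 1)) = a*b + c*1. The file does not state the motivation, but the usual reason for this kind of rewrite is to nudge shader compilers toward a single mad/FMA instruction. The three forms below are algebraically identical (sketch only):

```hlsl
// Algebraically equivalent ways to write a * b + c; fastfma above uses the dot() form.
float MulAddLiteral(float a, float b, float c) { return a * b + c; }
float MulAddMad(float a, float b, float c)     { return mad(a, b, c); }
float MulAddDot(float a, float b, float c)     { return dot(float2(a, c), float2(b, 1)); }
```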
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/Tensor.cginc.meta b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/Tensor.cginc.meta
deleted file mode 100644
index c611dd01f4..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/Tensor.cginc.meta
+++ /dev/null
@@ -1,9 +0,0 @@
-fileFormatVersion: 2
-guid: 5761abd87a16940b2a81aaa755787fc9
-timeCreated: 1506540305
-licenseType: Pro
-ShaderImporter:
- defaultTextures: []
- userData:
- assetBundleName:
- assetBundleVariant:
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/TexConv.compute b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/TexConv.compute
deleted file mode 100644
index e3174b1ac5..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/TexConv.compute
+++ /dev/null
@@ -1,99 +0,0 @@
-#pragma kernel TexConv2D
-
-#include "Tensor.cginc"
-
-TENSOR_DECL(X)
-TENSOR_DECL(K)
-TENSOR_DECL(B)
-TENSOR_DECL(WBK)
-TENSOR_DECL_RW(O)
-
-uint4 _Pad;
-uint4 _Stride;
-
-struct TextureAsTensor : Tensor
-{
- Texture2D tex;
- SamplerState smp;
-
- Texture2DArray texArray;
- SamplerState smpArray;
-
- void Init(uint4 nhwc, Texture2D tex_, SamplerState sampler_, Texture2DArray texArray_, SamplerState samplerArray_)
- {
- Tensor::Init(nhwc);
- tex = tex_;
- smp = sampler_;
- texArray = texArray_;
- smpArray = samplerArray_;
- }
-
- float4 Get(uint b, uint y, uint x)
- {
- float3 loc = float3((float)x / (float)width, (float)y / (float)height, b);
- if (batch > 1)
- return texArray.SampleLevel(smpArray, loc, 0);
- else
- return tex.SampleLevel(smp, loc.xy, 0);
- }
-};
-
-#define TENSOR_SHARED2_ARGS3(A, B, S, O) TENSOR_SHARED_ARG(A, S); TENSOR_SHARED_ARG(B, S); TENSOR_ARG_RW(O);
-Texture2DArray Xtex2DArray;
-Texture2D Xtex2D;
-SamplerState samplerXtex2D { Filter = MIN_MAG_LINEAR_MIP_POINT; AddressU = Clamp; AddressV = Clamp; };
-SamplerState samplerXtex2DArray { Filter = MIN_MAG_LINEAR_MIP_POINT; AddressU = Clamp; AddressV = Clamp; };
-
-#define MAX_CHANNELS 4
-
-NUMTHREADS((16,4,4), (16,4,2), (16,2,2))
-void TexConv2D(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
-// @TODO: currently it fails to compile, needs to be investigated
-#if 0
- DISPATCH_ARGS(K.kernelCount, O.width, O.height);
- TextureAsTensor X; X.Init(Xdecl[0], Xtex2D, samplerXtex2D, Xtex2DArray, samplerXtex2DArray);
-
- TENSOR_SHARED_ARG(K, WBK);
- TENSOR_SHARED_ARG(B, WBK);
- TENSOR_ARG_RW(O);
-
- // ASSERT(X.channels <= MAX_CHANNELS)
-
- uint k = dispatchThreadID.x;
- uint x = dispatchThreadID.y;
- uint y = dispatchThreadID.z;
-
- if (k >= K.channels) return;
- if (x >= O.width) return;
- if (y >= O.height) return;
-
- for (uint n = 0; n < O.batch; ++n)
- {
- float acc = B.Get(k);
- for (uint dy = 0; dy < K.GetKernelHeight(); ++dy)
- {
- for (uint dx = 0; dx < K.GetKernelWidth(); ++dx)
- {
- uint oy = y * _Stride.y + dy;
- uint ox = x * _Stride.x + dx;
-
- // @TODO: investigate
- // WARNING: had to move both y check into the loop (as opposed to checking y in parent loop) - due to potential bug in Metal compiler
- if (oy < _Pad.y) continue;
- if (oy - _Pad.w >= X.height) continue;
- if (ox < _Pad.x) continue;
- if (ox - _Pad.z >= X.width) continue;
-
- float4 in4channels = X.Get(n, oy - _Pad.y, ox - _Pad.x);
- for (uint c = 0; c < X.channels && c < MAX_CHANNELS; ++c)
- {
- acc += in4channels[c] * K.Get(dy, dx, c, k);
- }
- }
- }
-
- O.Set(n, y, x, k, acc);
- }
-#endif
-}
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/TexConv.compute.meta b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/TexConv.compute.meta
deleted file mode 100644
index 38baaf9613..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Resources/TexConv.compute.meta
+++ /dev/null
@@ -1,9 +0,0 @@
-fileFormatVersion: 2
-guid: 85d38d76f835143f797bca1481285596
-timeCreated: 1507637303
-licenseType: Pro
-ComputeShaderImporter:
- currentAPIMask: 196608
- userData:
- assetBundleName:
- assetBundleVariant:
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/LICENSE.md b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/LICENSE.md
deleted file mode 100644
index 389755fbf1..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/LICENSE.md
+++ /dev/null
@@ -1,6 +0,0 @@
-Barracuda cross-platform Neural Net engine copyright © 2018 Unity Technologies ApS
-
-Licensed under the Unity Companion License for Unity-dependent projects--see [Unity Companion License](http://www.unity3d.com/legal/licenses/Unity_Companion_License).
-
-Unless expressly provided otherwise, the Software under this license is made available strictly on an “AS IS” BASIS WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED. Please review the license for details on these and other terms and conditions.
-
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/LICENSE.md.meta b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/LICENSE.md.meta
deleted file mode 100644
index a68e6e466d..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/LICENSE.md.meta
+++ /dev/null
@@ -1,7 +0,0 @@
-fileFormatVersion: 2
-guid: dcc5ce8caa7664f8090ef0103a208c6e
-TextScriptImporter:
- externalObjects: {}
- userData:
- assetBundleName:
- assetBundleVariant:
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/ReleaseNotes.md b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/ReleaseNotes.md
deleted file mode 100644
index a195acffb2..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/ReleaseNotes.md
+++ /dev/null
@@ -1,152 +0,0 @@
-# Release notes
-
-## 0.2.4
-- Switched to 2018.4.3f1 as primary Unity version for testing.
-- Fixed ScaleBias scheduling issue with large amounts of data (reproduced with MobileNet @ 16 batch)
-- Fixed buffer overrun in ThreadGroup SharedMemory when TRANSPOSE_X and/or SHIFTED_X paths are enabled. This should fix GPU worker issues on Windows.
-- Added string cache to minimise string concat generated GC pressure.
-- Added small fixes for temp memory allocations, saves ~200B per layer.
-- Refactored inner loop workings, to avoid GC allocations for delegates.
-- Fixed input handling for layers, now inputs are not regenerated with every execution. Static model tensors are to stay forever until worker is disposed.
-- Bumped Burst version to 1.1.1.
-
-## 0.2.3
-- Rewritten Dense, Conv and some other ops on GPU. Speedup of 33% in most models with batch=1 and over 100% for batch=16.
-- Optimizations: reimplemented InstanceNormalization using pyramid approach for calculating mean and variance.
-
-## 0.2.2
-- Added support for --print-supported-ops flag for model converters, now it will print approximate list of supported operations. List of supported ops depends on converter.
-- Added Keras converter as part of distribution.
-- Now compute shaders are loaded only if GPU worker is requested.
-- Fixed bug in MaxPool and AvgPool padding. Issue discovered by Yolo faces network.
-- Fixed bug in Transpose convolution support for C# backend.
-- Fixed TF model conversion with two LSTM cells.
-- Fixed case when strided slice end overflows to zero and thus producing negative range.
-
-## 0.2.1
-- TF importer: fixed ResizeNearestNeighbor aka Upsample2D scaling factor detection.
-- TF importer: optimized node sorting. Should be faster than 0.2.0.
-- TF importer: made detection of actual output node from LSTM/GRU pattern more bullet proof by skipping Const nodes.
-- TF importer: improved InstanceNormalization handling.
-- TF importer: fixed SquareDifference pattern.
-- TF importer: fixed Conv2DBackpropInput (transpose convolution) import.
-- Fixed Conv2D performance regression on some GPUs.
-- Fixed TextureAsTensorData.Download() to work properly with InterpretDepthAs.Channels.
-- Fixed bug when identity/nop layers would reuse input as an output and later causing premature release of that tensor as part of intermediate data cleanup.
-- Added scale + bias to TensorToRenderTexture interface, useful for adjusting network output scale + bias on the fly.
-- Fixed double Dispose issue when worker gets garbage collected.
-
-## 0.2.0
-- Version bumped to 0.2.0 as it brings breaking API changes, for details look below.
-- Significantly reduced temporary memory allocations by introducing internal allocator support. Now memory is re-used between layer execution as much as possible.
-- Improved small workload performance on CSharp backend
-- Added parallel implementation for multiple activation functions on CSharp backend
-- Added `Peek()` function to `IWorker`, it retains object storage in worker's allocator, useful for quick grabbing of output. If you want to preserve content of output tensor between `Execute()` invocations, then use `Fetch()`.
-- Fixed ESRGAN model conversion (ONNX importer).
-- Fixed Tensor <-> Texture copy for textures/tensors that dimensions are not multiple of 8.
-- Added `Summary()` method to `Worker`. Currently returns allocator information.
-- Tabs to spaces! Aiming at higher salary (https://stackoverflow.blog/2017/06/15/developers-use-spaces-make-money-use-tabs/).
-- Renamed worker type enum members: `CSharp` -> `CSharpRef`, `CSharpFast` -> `CSharp`, `Compute` -> `ComputeRef`, `ComputeFast` -> `Compute`.
-- Implemented new optimized `ComputePrecompiled` worker. This worker caches Compute kernels and state beforehand to reduce CPU overhead.
-- Added `ExecuteAsync()` to `IWorker` interface, it returns `IEnumerator`, which enables you to control how many layers to schedule per frame (one iteration == one layer).
-- Added `Log` op support on Compute workers.
-- Optimized activation functions and ScaleBias by accessing tensor as continuous array. Gained ~2.0ms on 4 batch MobileNet (MBP2016).
-- Introduced _Loop version of activations to fight 65535 scheduling limit on D3D11.
-- Added .nn as Barracuda model file extension for use in Unity Editor. Also added simple editor importer. Now you can declare serializable fields as NNModel to bind them to .nn asset. ModelLoader.Load() now accepts NNModel as a source.
-- Compute: Reduce reference GPU implementation.
-- TF importer: Expanded Mean support to mean over channels, implemented Pad (as Border2D), implemented SquaredDifference, added InstanceNormalization and LeakyRelu patterns, StridedSlice implementation.
-- TF importer: sort model nodes by dependencies before processing.
-- Fixed ComputeBuffer leak when using Compute and ComputePrecompiled backends.
-- Made to use Conv2D_L1Cached64_RegisterBlock4x4 more often: improves perf ~2x on Vega 16, and ~30% on Nvidia and Intel.
-
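The 0.2.0 notes above only name the new `IWorker` calls, so here is a minimal C# sketch of how `ExecuteAsync()`, `Peek()` and `Fetch()` might be combined. The worker creation, the exact method signatures and the ownership rules are assumptions based on these notes, not taken from the Barracuda sources.

```csharp
using System.Collections;
using Barracuda;
using UnityEngine;

public class AsyncInferenceExample : MonoBehaviour
{
    public IWorker worker;   // assumed to be created elsewhere via the worker factory

    public IEnumerator RunInference(Tensor input)
    {
        // ExecuteAsync() returns an IEnumerator; each MoveNext() schedules one layer,
        // so yielding between iterations spreads the network across frames.
        var schedule = worker.ExecuteAsync(input);
        while (schedule.MoveNext())
        {
            yield return null;
        }

        // Peek() borrows the output from the worker's allocator (cheap, but the storage
        // may be reused by the next Execute()); Fetch() returns a tensor the caller owns.
        var transientOutput = worker.Peek();
        var ownedOutput = worker.Fetch();

        // ... use the outputs ...

        ownedOutput.Dispose();   // fetched tensors are the caller's responsibility
    }
}
```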
-## 0.1.6
-- Added activation type print in verbose mode
-- Added fast and parallel CPU implementation for Swish, Relu, Add, Sub, Div, Min, Max, Tanh, Exp
-- Removed duplicate profiler blocks for ops
-- Improved scheduling on CPU for small batches of data
-- Fixed compatibility with Unity 2019.2.x
-
-## 0.1.5
-- Added Transpose, MatMul and Identity layer support for models exported from ONNX.
-- Added BasicLSTM layer support for models exported from TF. Limited set of LSTM networks should work now.
-- Added DepthwiseConv2D layer support. Most of the networks based on the MobileNet should work now.
-- Added OneHot layer support for models exported from TF.
-- Added optimized path for Conv2D, Dense and Transpose layers with single batch executions. Performance gain up to 100%.
-- Fixed FMA performance issue on Metal GFX platforms.
-- Added fast optimized path for Sigmoid and Mul layers on CPU.
-- Fixed issue when worker is executed with different batch sizes.
-- Added ``pip`` requirements file for Python dependencies; check ``Tools/requirements.txt``.
-- Added proof-of-concept Docker wrappers for running model conversion inside a Docker container. Check ``Tools/docker-tensorflow-to-barracuda.sh`` and ``Tools/docker-onnx-to-barracuda.sh``. Currently tested only on a Mac host.
-- Refactored model importers for easier integration with ML Agents.
-- Fixed input shape determination for Keras sequential model.
-- Added metadata about input shapes to the model. Look for ``Model.GetShapeByName()``.
-- Added API to query constant Tensors embedded into the network. Look for ``Model.GetTensorByName()`` (see the sketch after this list).
-- Added reference implementations for Selu, Abs, Neg, Ceil, Floor, Clip, Rcp, Log layers.
-- Added support for Mean, Square, StridedSlice and Border2D layers.
-- Added support for Swish activation, now it is automatically detected in models.
-- Fixed Tanh NaN issue when large argument is passed.
-- RandomNormal and RandomUniform now support either an embedded shape constant OR the previous tensor shape for input.
-- Fixed Keras/TF/ONNX FusedBatchNorm/BatchNorm import and now it takes ``epsilon`` into account.
-- Now Barracuda will fall back to CSharpFast if compute shaders are not supported on the current platform.
-- Improved compute kernel interop on Android.
-- Implemented Pix2Pix model (.pict) importer.
-
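Since 0.1.5 only names the two model-introspection calls, the following is a minimal sketch of using them; the `Model` parameter and the constant-tensor name "bias0" are assumptions for illustration.

```csharp
using Barracuda;
using UnityEngine;

public static class ModelIntrospection
{
    public static void DumpMetadata(Model model, string inputName)
    {
        // Input shape metadata added in 0.1.5.
        var shape = model.GetShapeByName(inputName);
        Debug.Log($"{inputName} shape: {shape}");

        // Query a constant tensor embedded in the network ("bias0" is a hypothetical name).
        var constant = model.GetTensorByName("bias0");
        Debug.Log($"bias0 tensor: {constant}");
    }
}
```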
-## 0.1.4
-- Implemented fast Conv2DTrans. Useful for GAN type networks.
-- Fixed few ComputeBuffer handling issues.
-- Simplified way to pass texture via ``Tensor`` constructor.
-- Documentation improvements.
-- Added Unity Companion License as part of distribution.
-- Fixed boundary checks for Compute Copy/Concat operations.
-- Improved profiling experience, now each layer will be reported separately in Unity Profiler.
-- Fixed Broadcast layer support in ``ModelAnalyzer``.
-- Exp, Pow and other layers are now also implemented in Compute. Improves RL model inference performance on GPU.
-- Added platform specific BLAS plugin support. Out of the box Barracuda ships with Apple Accelerate framework support for iOS and macOS.
-- Added Burst BLAS plugin, which greatly improves performance in the Unity Editor where a native OS BLAS is not available. It's packaged as a separate package and requires Burst to be enabled.
-- Improved memory handling, now less GC allocations should be made per inference execution.
-
-## 0.1.3
-- Improved Barracuda support for Unity Profiler.
-- Cleaned up Barracuda APIs.
-- Added direct ``Texture`` input support. Look for ``TextureAsTensorData``. The following texture types are supported as input: ``Texture2D``, ``Texture2DArray``, ``Texture3D``, ``RenderTexture``.
-- Added ``Tensor`` to ``RenderTexture`` conversion. Look for ``TensorToRenderTexture`` (see the sketch after this list).
-- Autoencoder type networks can run completely on GPU now. Data roundtrip via CPU is not necessary anymore.
-- Vertical flip is applied when converting between ``Texture`` and ``Tensor`` to match conventions. To override this behavior, look for the ``TextureAsTensorData.Flip`` enum.
-- Removed direct reference to WebCamTexture, now Barracuda compiles for Console targets.
-- Fixed _Conv2DTranspose_ layer support. Now GANs using _Conv2DTranspose_ work properly.
-- Added automated test for pix2pix GAN.
-
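A minimal sketch of the GPU-only round trip that 0.1.3 enables (see the ``TextureAsTensorData`` and ``TensorToRenderTexture`` items above). The ``Tensor(Texture)`` constructor comes from the 0.1.4 notes, and the ``BarracudaTextureUtils`` utility-class name is an assumption based only on the type names mentioned here.

```csharp
using Barracuda;
using UnityEngine;

public class TextureRoundTripExample : MonoBehaviour
{
    public IWorker worker;              // assumed to be created elsewhere
    public Texture2D sourceTexture;
    public RenderTexture targetTexture;

    public void Run()
    {
        // Feed a Texture directly as a Tensor input -- no CPU readback needed.
        using (var input = new Tensor(sourceTexture))
        {
            worker.Execute(input);
            using (var output = worker.Fetch())
            {
                // Write the network output straight into a RenderTexture
                // (utility class name assumed).
                BarracudaTextureUtils.TensorToRenderTexture(output, targetTexture);
            }
        }
    }
}
```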
-## 0.1.2
-- Barracuda is now also available as a preview package. Look for ``com.unity.barracuda`` in the https://staging-packages.unity.com registry.
-- Conv2D layers are now *up to 30x faster* with ``CSharpFast`` backend (``ComputeFast`` remains best backend for convolutional networks).
-- Added profiler sample for ``Fetch()``.
-- Fixed compilation issues on Xbox One.
-- TexConv2D support was temporarily disabled.
-- Barracuda logging can now be configured via static fields of the ``Barracuda.D`` class. It allows you to disable specific logging levels or just disable stack trace collection (helps with performance when profiling).
-- The Compute Concat implementation now falls back to the C# implementation instead of throwing an exception when an unsupported configuration is encountered.
-- Fixed several ``ComputeBuffer`` release issues.
-- Added a constructor for ``Tensor`` that allows passing in a data array (see the sketch after this list).
-- Improved Flatten handling in TensorFlow models.
-- Added helper func ``ModelLoader.LoadFromStreamingAssets``.
-- Fixed .meta file packaging.
-- Small docs improvements.
-- Fixed unnecessary patching of Activation layers in ``ModelLoader``.
-- Added output trimming at run-time. See the Worker factory for extra parameters.
-
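The 0.1.2 items above mention a data-array ``Tensor`` constructor and ``ModelLoader.LoadFromStreamingAssets`` without showing usage, so here is a minimal sketch; the exact constructor overload (batch, channels, data) is an assumption.

```csharp
using Barracuda;

public static class LoadingHelpers
{
    // Wrap a raw float array as a 1 x N tensor (constructor overload assumed).
    public static Tensor MakeInput(float[] data)
    {
        return new Tensor(1, data.Length, data);
    }

    // Load a converted model that ships inside StreamingAssets.
    public static Model LoadBundled(string fileName)
    {
        return ModelLoader.LoadFromStreamingAssets(fileName);
    }
}
```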
-## 0.1.1
-- First internal release as a drop-in package.
-- Compatibility with ML Agents models: 3DBall, PushBlock, GridWorld, Soccer.
-
-## 0.1.0
-- First internal build. Due to some bugs encountered, it wasn't published.
-
-# Contributors
-- Renaldas (ReJ) Zioma
-- Mantas Puida
-- Vladimir Oster
-- Aurimas Petrovas
-- Martin Sternevald
-- Valdemar Bučilko
-- Kuba Cupisz
-- Povilas Kanapickas
-- Paulius Puodžiūnas
\ No newline at end of file
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/ReleaseNotes.md.meta b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/ReleaseNotes.md.meta
deleted file mode 100644
index 2d0ff280f0..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/ReleaseNotes.md.meta
+++ /dev/null
@@ -1,7 +0,0 @@
-fileFormatVersion: 2
-guid: a129912fffc9d4ab3b5ae110be67a669
-TextScriptImporter:
- externalObjects: {}
- userData:
- assetBundleName:
- assetBundleVariant:
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/package.json b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/package.json
deleted file mode 100644
index 4d09c393a7..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/package.json
+++ /dev/null
@@ -1,8 +0,0 @@
-{
- "name": "com.unity.barracuda",
- "displayName": "Barracuda",
- "version": "0.2.4-preview",
- "unity": "2017.4",
- "description": "Barracuda is lightweight and cross-platform Neural Net inference library. Barracuda supports inference both on GPU and CPU.",
- "dependencies": {}
-}
\ No newline at end of file
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/package.json.meta b/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/package.json.meta
deleted file mode 100644
index e4c32c936c..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/package.json.meta
+++ /dev/null
@@ -1,7 +0,0 @@
-fileFormatVersion: 2
-guid: 73ae2d877fd444b04b5b6ef591d3fa0e
-TextScriptImporter:
- externalObjects: {}
- userData:
- assetBundleName:
- assetBundleVariant:
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/ProtoBuffer/Google.Protobuf.dll b/UnitySDK/Assets/ML-Agents/Plugins/ProtoBuffer/Google.Protobuf.dll
deleted file mode 100755
index 6ea720de8d..0000000000
Binary files a/UnitySDK/Assets/ML-Agents/Plugins/ProtoBuffer/Google.Protobuf.dll and /dev/null differ
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/ProtoBuffer/Google.Protobuf.dll.meta b/UnitySDK/Assets/ML-Agents/Plugins/ProtoBuffer/Google.Protobuf.dll.meta
deleted file mode 100644
index e08504227e..0000000000
--- a/UnitySDK/Assets/ML-Agents/Plugins/ProtoBuffer/Google.Protobuf.dll.meta
+++ /dev/null
@@ -1,30 +0,0 @@
-fileFormatVersion: 2
-guid: 0836ffd04a4924861a2d58aa4b111937
-PluginImporter:
- externalObjects: {}
- serializedVersion: 2
- iconMap: {}
- executionOrder: {}
- isPreloaded: 0
- isOverridable: 0
- platformData:
- - first:
- Any:
- second:
- enabled: 1
- settings: {}
- - first:
- Editor: Editor
- second:
- enabled: 0
- settings:
- DefaultValueInitialized: true
- - first:
- Windows Store Apps: WindowsStoreApps
- second:
- enabled: 0
- settings:
- CPU: AnyCPU
- userData:
- assetBundleName:
- assetBundleVariant:
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/Academy.cs b/UnitySDK/Assets/ML-Agents/Scripts/Academy.cs
index 973cb94d3a..6d297f7406 100644
--- a/UnitySDK/Assets/ML-Agents/Scripts/Academy.cs
+++ b/UnitySDK/Assets/ML-Agents/Scripts/Academy.cs
@@ -92,23 +92,23 @@ public EnvironmentConfiguration(
"docs/Learning-Environment-Design-Academy.md")]
public abstract class Academy : MonoBehaviour
{
- private const string k_ApiVersion = "API-11";
+ const string k_ApiVersion = "API-12";
/// Temporary storage for global gravity value
/// Used to restore original value when deriving Academy modifies it
- private Vector3 m_OriginalGravity;
+ Vector3 m_OriginalGravity;
/// Temporary storage for global fixedDeltaTime value
/// Used to restore original value when deriving Academy modifies it
- private float m_OriginalFixedDeltaTime;
+ float m_OriginalFixedDeltaTime;
/// Temporary storage for global maximumDeltaTime value
/// Used to restore original value when deriving Academy modifies it
- private float m_OriginalMaximumDeltaTime;
+ float m_OriginalMaximumDeltaTime;
// Fields provided in the Inspector
- [FormerlySerializedAs("maxSteps")]
+ [FormerlySerializedAs("trainingConfiguration")]
[SerializeField]
[Tooltip("The engine-level settings which correspond to rendering " +
"quality and engine speed during Training.")]
@@ -153,7 +153,7 @@ public bool IsCommunicatorOn
/// If true, the Academy will use inference settings. This field is
/// initialized in depending on the presence
- /// or absence of a communicator. Furthermore, it can be modified during
+ /// or absence of a communicator. Furthermore, it can be modified during
/// training via .
bool m_IsInference = true;
@@ -178,19 +178,19 @@ public bool IsCommunicatorOn
/// Pointer to the communicator currently in use by the Academy.
public ICommunicator Communicator;
- private bool m_Initialized;
- private List<ModelRunner> m_ModelRunners = new List<ModelRunner>();
+ bool m_Initialized;
+ List<ModelRunner> m_ModelRunners = new List<ModelRunner>();
// Flag used to keep track of the first time the Academy is reset.
bool m_FirstAcademyReset;
- // The Academy uses a series of events to communicate with agents
+ // The Academy uses a series of events to communicate with agents
// to facilitate synchronization. More specifically, it ensures
// that all the agents perform their steps in a consistent order (i.e. no
// agent can act based on a decision before another agent has had a chance
// to request a decision).
- // Signals to all the Agents at each environment step so they can use
+ // Signals to all the Agents at each environment step so they can use
// their Policy to decide on their next action.
public event System.Action DecideAction;
@@ -240,7 +240,7 @@ public void LazyInitialization()
}
// Used to read Python-provided environment parameters
- private static int ReadArgs()
+ static int ReadArgs()
{
var args = System.Environment.GetCommandLineArgs();
var inputPort = "";
@@ -258,7 +258,7 @@ private static int ReadArgs()
///
/// Initializes the environment, configures it and initialized the Academy.
///
- private void InitializeEnvironment()
+ void InitializeEnvironment()
{
m_OriginalGravity = Physics.gravity;
m_OriginalFixedDeltaTime = Time.fixedDeltaTime;
@@ -344,7 +344,7 @@ static void OnQuitCommandReceived()
Application.Quit();
}
- private void OnResetCommand(EnvironmentResetParameters newResetParameters)
+ void OnResetCommand(EnvironmentResetParameters newResetParameters)
{
UpdateResetParameters(newResetParameters);
ForcedFullReset();
@@ -355,7 +355,7 @@ void OnRLInputReceived(UnityRLInputParameters inputParams)
m_IsInference = !inputParams.isTraining;
}
- private void UpdateResetParameters(EnvironmentResetParameters newResetParameters)
+ void UpdateResetParameters(EnvironmentResetParameters newResetParameters)
{
if (newResetParameters.resetParameters != null)
{
@@ -568,13 +568,13 @@ void FixedUpdate()
}
///
- /// Creates or retrieves an existing ModelRunner that uses the same
+ /// Creates or retrieves an existing ModelRunner that uses the same
/// NNModel and the InferenceDevice as provided.
///
/// The NNModel the ModelRunner must use
- /// The brainParameters used to create
+ /// The brainParameters used to create
/// the ModelRunner
- /// The inference device (CPU or GPU)
+ /// The inference device (CPU or GPU)
/// the ModelRunner will use
/// The ModelRunner compatible with the input settings
public ModelRunner GetOrCreateModelRunner(
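The doc comment above describes `GetOrCreateModelRunner` as a cache keyed on the model and inference device. A minimal sketch of a call site, with the component setup and the `InferenceDevice.CPU` argument assumed for illustration:

```csharp
using Barracuda;
using MLAgents;
using UnityEngine;

public class ModelRunnerExample : MonoBehaviour
{
    public NNModel model;                    // assigned in the Inspector
    public BrainParameters brainParameters;  // normally supplied by BehaviorParameters

    void Start()
    {
        var academy = FindObjectOfType<Academy>();
        // Agents that share the same NNModel and device get the same ModelRunner instance back.
        var runner = academy.GetOrCreateModelRunner(model, brainParameters, InferenceDevice.CPU);
        Debug.Log(runner != null ? "ModelRunner ready" : "No ModelRunner created");
    }
}
```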
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/ActionMasker.cs b/UnitySDK/Assets/ML-Agents/Scripts/ActionMasker.cs
index 4f0e2766f2..b38afb16e6 100644
--- a/UnitySDK/Assets/ML-Agents/Scripts/ActionMasker.cs
+++ b/UnitySDK/Assets/ML-Agents/Scripts/ActionMasker.cs
@@ -8,11 +8,11 @@ public class ActionMasker
{
/// When using discrete control, is the starting indices of the actions
/// when all the branches are concatenated with each other.
- private int[] m_StartingActionIndices;
+ int[] m_StartingActionIndices;
- private bool[] m_CurrentMask;
+ bool[] m_CurrentMask;
- private readonly BrainParameters m_BrainParameters;
+ readonly BrainParameters m_BrainParameters;
public ActionMasker(BrainParameters brainParameters)
{
@@ -79,7 +79,7 @@ public bool[] GetMask()
///
/// Makes sure that the current mask is usable.
///
- private void AssertMask()
+ void AssertMask()
{
// Action Masks can only be used in Discrete Control.
if (m_BrainParameters.vectorActionSpaceType != SpaceType.Discrete)
@@ -116,7 +116,7 @@ public void ResetMask()
///
/// The index of the branch to check
/// True if all the actions of the branch are masked
- private bool AreAllActionsMasked(int branch)
+ bool AreAllActionsMasked(int branch)
{
if (m_CurrentMask == null)
{
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/Agent.cs b/UnitySDK/Assets/ML-Agents/Scripts/Agent.cs
index f8abf6dfd6..5c588daafc 100644
--- a/UnitySDK/Assets/ML-Agents/Scripts/Agent.cs
+++ b/UnitySDK/Assets/ML-Agents/Scripts/Agent.cs
@@ -1,9 +1,9 @@
+using System;
using System.Collections.Generic;
using UnityEngine;
using Barracuda;
using MLAgents.Sensor;
-
-
+using UnityEngine.Serialization;
namespace MLAgents
{
@@ -14,51 +14,21 @@ namespace MLAgents
public struct AgentInfo
{
///
- /// Most recent agent vector (i.e. numeric) observation.
- ///
- public List<float> vectorObservation;
-
- ///
- /// The previous agent vector observations, stacked. The length of the
- /// history (i.e. number of vector observations to stack) is specified
- /// in the Brain parameters.
- ///
- public List<float> stackedVectorObservation;
-
- ///
- /// Most recent compressed observations.
- ///
- public List<CompressedObservation> compressedObservations;
-
- ///
- /// Most recent text observation.
+ /// Most recent observations.
///
- public string textObservation;
+ public List<Observation> observations;
///
/// Keeps track of the last vector action taken by the Brain.
///
public float[] storedVectorActions;
- ///
- /// Keeps track of the last text action taken by the Brain.
- ///
- public string storedTextActions;
-
///
/// For discrete control, specifies the actions that the agent cannot take. Is true if
/// the action is masked.
///
public bool[] actionMasks;
- ///
- /// Used by the Trainer to store information about the agent. This data
- /// structure is not consumed or modified by the agent directly, they are
- /// just the owners of their trainier's memory. Currently, however, the
- /// size of the memory is in the Brain properties.
- ///
- public List<float> memories;
-
///
/// Current agent reward.
///
@@ -79,13 +49,6 @@ public struct AgentInfo
/// to separate between different agents in the environment.
///
public int id;
-
- ///
- /// User-customizable object for sending structured output from Unity to Python in response
- /// to an action in addition to a scalar reward.
- /// TODO(cgoy): All references to protobuf objects should be removed.
- ///
- public CommunicatorObjects.CustomObservationProto customObservation;
}
///
@@ -95,11 +58,7 @@ public struct AgentInfo
public struct AgentAction
{
public float[] vectorActions;
- public string textActions;
- public List<float> memories;
public float value;
- /// TODO(cgoy): All references to protobuf objects should be removed.
- public CommunicatorObjects.CustomActionProto customAction;
}
///
@@ -107,7 +66,7 @@ public struct AgentAction
/// Editor. This excludes the Brain linked to the Agent since it can be
/// modified programmatically.
///
- [System.Serializable]
+ [Serializable]
public class AgentParameters
{
///
@@ -194,12 +153,12 @@ public class AgentParameters
///
[HelpURL("https://github.com/Unity-Technologies/ml-agents/blob/master/" +
"docs/Learning-Environment-Design-Agents.md")]
- [System.Serializable]
+ [Serializable]
[RequireComponent(typeof(BehaviorParameters))]
public abstract class Agent : MonoBehaviour
{
- private IPolicy m_Brain;
- private BehaviorParameters m_PolicyFactory;
+ IPolicy m_Brain;
+ BehaviorParameters m_PolicyFactory;
///
/// Agent parameters specified within the Editor via AgentEditor.
@@ -261,16 +220,33 @@ public AgentInfo Info
int m_Id;
/// Keeps track of the actions that are masked at each step.
- private ActionMasker m_ActionMasker;
+ ActionMasker m_ActionMasker;
///
/// Demonstration recorder.
///
- private DemonstrationRecorder m_Recorder;
+ DemonstrationRecorder m_Recorder;
- public List<ISensor> m_Sensors;
+ ///
+ /// List of sensors used to generate observations.
+ /// Currently generated from attached SensorComponents, and a legacy VectorSensor
+ ///
+ [FormerlySerializedAs("m_Sensors")]
+ public List<ISensor> sensors;
- /// Monobehavior function that is called when the attached GameObject
+ ///
+ /// VectorSensor which is written to by AddVectorObs
+ ///
+ public VectorSensor collectObservationsSensor;
+
+ ///
+ /// Internal buffer used for generating float observations.
+ ///
+ float[] m_VectorSensorBuffer;
+
+ WriteAdapter m_WriteAdapter = new WriteAdapter();
+
+ /// MonoBehaviour function that is called when the attached GameObject
/// becomes enabled or active.
void OnEnable()
{
@@ -288,7 +264,7 @@ void OnEnableHelper(Academy academy)
{
m_Info = new AgentInfo();
m_Action = new AgentAction();
- m_Sensors = new List<ISensor>();
+ sensors = new List<ISensor>();
if (academy == null)
{
@@ -336,7 +312,7 @@ void OnDisable()
/// will categorize the agent when training.
///
/// The model to use for inference.
- /// Define on what device the model
+ /// Define on what device the model
/// will be run.
public void GiveModel(
string behaviorName,
@@ -481,22 +457,7 @@ void ResetData()
}
}
- if (m_Info.textObservation == null)
- m_Info.textObservation = "";
- m_Action.textActions = "";
- m_Info.memories = new List<float>();
- m_Action.memories = new List<float>();
- m_Info.vectorObservation =
- new List<float>(param.vectorObservationSize);
- m_Info.stackedVectorObservation =
- new List<float>(param.vectorObservationSize
- * param.numStackedVectorObservations);
- m_Info.stackedVectorObservation.AddRange(
- new float[param.vectorObservationSize
- * param.numStackedVectorObservations]);
-
- m_Info.compressedObservations = new List<CompressedObservation>();
- m_Info.customObservation = null;
+ m_Info.observations = new List<Observation>();
}
///
@@ -534,23 +495,51 @@ public virtual float[] Heuristic()
///
public void InitializeSensors()
{
+ // Get all attached sensor components
var attachedSensorComponents = GetComponents<SensorComponent>();
- m_Sensors.Capacity += attachedSensorComponents.Length;
+ sensors.Capacity += attachedSensorComponents.Length;
foreach (var component in attachedSensorComponents)
{
- m_Sensors.Add(component.CreateSensor());
+ sensors.Add(component.CreateSensor());
}
- // Sort the sensors by name to ensure determinism
- m_Sensors.Sort((x, y) => x.GetName().CompareTo(y.GetName()));
+ // Support legacy CollectObservations
+ var param = m_PolicyFactory.brainParameters;
+ if (param.vectorObservationSize > 0)
+ {
+ collectObservationsSensor = new VectorSensor(param.vectorObservationSize);
+ if (param.numStackedVectorObservations > 1)
+ {
+ var stackingSensor = new StackingSensor(collectObservationsSensor, param.numStackedVectorObservations);
+ sensors.Add(stackingSensor);
+ }
+ else
+ {
+ sensors.Add(collectObservationsSensor);
+ }
+ }
+
+ // Sort the Sensors by name to ensure determinism
+ sensors.Sort((x, y) => x.GetName().CompareTo(y.GetName()));
#if DEBUG
// Make sure the names are actually unique
- for (var i = 0; i < m_Sensors.Count - 1; i++)
+ for (var i = 0; i < sensors.Count - 1; i++)
{
- Debug.Assert(!m_Sensors[i].GetName().Equals(m_Sensors[i + 1].GetName()), "Sensor names must be unique.");
+ Debug.Assert(!sensors[i].GetName().Equals(sensors[i + 1].GetName()), "Sensor names must be unique.");
}
#endif
+ // Create a buffer for writing vector sensor data to
+ int numFloatObservations = 0;
+ for (var i = 0; i < sensors.Count; i++)
+ {
+ if (sensors[i].GetCompressionType() == SensorCompressionType.None)
+ {
+ numFloatObservations += sensors[i].ObservationSize();
+ }
+ }
+
+ m_VectorSensorBuffer = new float[numFloatObservations];
}
///
@@ -563,33 +552,17 @@ void SendInfoToBrain()
return;
}
- m_Info.memories = m_Action.memories;
m_Info.storedVectorActions = m_Action.vectorActions;
- m_Info.storedTextActions = m_Action.textActions;
- m_Info.vectorObservation.Clear();
- m_Info.compressedObservations.Clear();
+ m_Info.observations.Clear();
m_ActionMasker.ResetMask();
+ UpdateSensors();
using (TimerStack.Instance.Scoped("CollectObservations"))
{
CollectObservations();
}
m_Info.actionMasks = m_ActionMasker.GetMask();
- var param = m_PolicyFactory.brainParameters;
- if (m_Info.vectorObservation.Count != param.vectorObservationSize)
- {
- throw new UnityAgentsException(string.Format(
- "Vector Observation size mismatch in continuous " +
- "agent {0}. " +
- "Was Expecting {1} but received {2}. ",
- gameObject.name,
- param.vectorObservationSize,
- m_Info.vectorObservation.Count));
- }
-
- Utilities.ShiftLeft(m_Info.stackedVectorObservation, param.vectorObservationSize);
- Utilities.ReplaceRange(m_Info.stackedVectorObservation, m_Info.vectorObservation,
- m_Info.stackedVectorObservation.Count - m_Info.vectorObservation.Count);
+ // var param = m_PolicyFactory.brainParameters; // look, no brain params!
m_Info.reward = m_Reward;
m_Info.done = m_Done;
@@ -600,9 +573,9 @@ void SendInfoToBrain()
if (m_Recorder != null && m_Recorder.record && Application.isEditor)
{
- // This is a bit of a hack - if we're in inference mode, compressed observations won't be generated
+ // This is a bit of a hack - if we're in inference mode, observations won't be generated
// But we need these to be generated for the recorder. So generate them here.
- if (m_Info.compressedObservations.Count == 0)
+ if (m_Info.observations.Count == 0)
{
GenerateSensorData();
}
@@ -610,34 +583,59 @@ void SendInfoToBrain()
m_Recorder.WriteExperience(m_Info);
}
- m_Info.textObservation = "";
+ }
+
+ void UpdateSensors()
+ {
+ for (var i = 0; i < sensors.Count; i++)
+ {
+ sensors[i].Update();
+ }
}
///
/// Generate data for each sensor and store it on the Agent's AgentInfo.
/// NOTE: At the moment, this is only called during training or when using a DemonstrationRecorder;
- /// during inference the sensors are used to write directly to the Tensor data. This will likely change in the
+ /// during inference the Sensors are used to write directly to the Tensor data. This will likely change in the
/// future to be controlled by the type of brain being used.
///
public void GenerateSensorData()
{
- // Generate data for all sensors
- // TODO add bool argument indicating when to compress? For now, we always will compress.
- for (var i = 0; i < m_Sensors.Count; i++)
+ int floatsWritten = 0;
+ // Generate data for all Sensors
+ for (var i = 0; i < sensors.Count; i++)
{
- var sensor = m_Sensors[i];
- var compressedObs = new CompressedObservation
+ var sensor = sensors[i];
+ if (sensor.GetCompressionType() == SensorCompressionType.None)
+ {
+ // only handles 1D
+ // TODO handle in communicator code instead
+ m_WriteAdapter.SetTarget(m_VectorSensorBuffer, floatsWritten);
+ var numFloats = sensor.Write(m_WriteAdapter);
+ var floatObs = new Observation
+ {
+ FloatData = new ArraySegment<float>(m_VectorSensorBuffer, floatsWritten, numFloats),
+ Shape = sensor.GetFloatObservationShape(),
+ CompressionType = sensor.GetCompressionType()
+ };
+ m_Info.observations.Add(floatObs);
+ floatsWritten += numFloats;
+ }
+ else
{
- Data = sensor.GetCompressedObservation(),
- Shape = sensor.GetFloatObservationShape(),
- CompressionType = sensor.GetCompressionType()
- };
- m_Info.compressedObservations.Add(compressedObs);
+ var compressedObs = new Observation
+ {
+ CompressedData = sensor.GetCompressedObservation(),
+ Shape = sensor.GetFloatObservationShape(),
+ CompressionType = sensor.GetCompressionType()
+ };
+ m_Info.observations.Add(compressedObs);
+ }
}
}
///
- /// Collects the (vector, visual, text) observations of the agent.
+ /// Collects the (vector, visual) observations of the agent.
/// The agent observation describes the current environment from the
/// perspective of the agent.
///
@@ -646,7 +644,7 @@ public void GenerateSensorData()
/// the Agent achieve its goal. For example, for a fighting Agent, its
/// observation could include distances to friends or enemies, or the
/// current level of ammunition at its disposal.
- /// Recall that an Agent may attach vector, visual or textual observations.
+ /// Recall that an Agent may attach vector or visual observations.
/// Vector observations are added by calling the provided helper methods:
/// -
/// -
@@ -667,8 +665,6 @@ public void GenerateSensorData()
/// needs to match the vectorObservationSize attribute of the linked Brain.
/// Visual observations are implicitly added from the cameras attached to
/// the Agent.
- /// Lastly, textual observations are added using
- /// .
///
public virtual void CollectObservations()
{
@@ -731,7 +727,7 @@ protected void SetActionMask(int branch, IEnumerable actionIndices)
/// Observation.
protected void AddVectorObs(float observation)
{
- m_Info.vectorObservation.Add(observation);
+ collectObservationsSensor.AddObservation(observation);
}
///
@@ -741,7 +737,7 @@ protected void AddVectorObs(float observation)
/// Observation.
protected void AddVectorObs(int observation)
{
- m_Info.vectorObservation.Add(observation);
+ collectObservationsSensor.AddObservation(observation);
}
///
@@ -751,9 +747,7 @@ protected void AddVectorObs(int observation)
/// Observation.
protected void AddVectorObs(Vector3 observation)
{
- m_Info.vectorObservation.Add(observation.x);
- m_Info.vectorObservation.Add(observation.y);
- m_Info.vectorObservation.Add(observation.z);
+ collectObservationsSensor.AddObservation(observation);
}
///
@@ -763,8 +757,7 @@ protected void AddVectorObs(Vector3 observation)
/// Observation.
protected void AddVectorObs(Vector2 observation)
{
- m_Info.vectorObservation.Add(observation.x);
- m_Info.vectorObservation.Add(observation.y);
+ collectObservationsSensor.AddObservation(observation);
}
///
@@ -774,7 +767,7 @@ protected void AddVectorObs(Vector2 observation)
/// Observation.
protected void AddVectorObs(IEnumerable<float> observation)
{
- m_Info.vectorObservation.AddRange(observation);
+ collectObservationsSensor.AddObservation(observation);
}
///
@@ -784,10 +777,7 @@ protected void AddVectorObs(IEnumerable observation)
/// Observation.
protected void AddVectorObs(Quaternion observation)
{
- m_Info.vectorObservation.Add(observation.x);
- m_Info.vectorObservation.Add(observation.y);
- m_Info.vectorObservation.Add(observation.z);
- m_Info.vectorObservation.Add(observation.w);
+ collectObservationsSensor.AddObservation(observation);
}
///
@@ -797,36 +787,12 @@ protected void AddVectorObs(Quaternion observation)
///
protected void AddVectorObs(bool observation)
{
- m_Info.vectorObservation.Add(observation ? 1f : 0f);
+ collectObservationsSensor.AddObservation(observation);
}
protected void AddVectorObs(int observation, int range)
{
- var oneHotVector = new float[range];
- oneHotVector[observation] = 1;
- m_Info.vectorObservation.AddRange(oneHotVector);
- }
-
- ///
- /// Sets the text observation.
- ///
- /// The text observation.
- public void SetTextObs(string textObservation)
- {
- m_Info.textObservation = textObservation;
- }
-
- ///
- /// Specifies the agent behavior at every step based on the provided
- /// action.
- ///
- ///
- /// Vector action. Note that for discrete actions, the provided array
- /// will be of length 1.
- ///
- /// Text action.
- public virtual void AgentAction(float[] vectorAction, string textAction)
- {
+ collectObservationsSensor.AddOneHotObservation(observation, range);
}
///
@@ -837,15 +803,8 @@ public virtual void AgentAction(float[] vectorAction, string textAction)
/// Vector action. Note that for discrete actions, the provided array
/// will be of length 1.
///
- /// Text action.
- ///
- /// A custom action, defined by the user as custom protobuf message. Useful if the action is hard to encode
- /// as either a flat vector or a single string.
- ///
- public virtual void AgentAction(float[] vectorAction, string textAction, CommunicatorObjects.CustomActionProto customAction)
+ public virtual void AgentAction(float[] vectorAction)
{
- // We fall back to not using the custom action if the subclassed Agent doesn't override this method.
- AgentAction(vectorAction, textAction);
}
///
@@ -902,25 +861,6 @@ public void UpdateVectorAction(float[] vectorActions)
m_Action.vectorActions = vectorActions;
}
- ///
- /// Updates the memories action.
- ///
- /// Memories.
- public void UpdateMemoriesAction(List<float> memories)
- {
- m_Action.memories = memories;
- }
-
- public void AppendMemoriesAction(List<float> memories)
- {
- m_Action.memories.AddRange(memories);
- }
-
- public List<float> GetMemoriesAction()
- {
- return m_Action.memories;
- }
-
///
/// Updates the value of the agent.
///
@@ -1029,7 +969,7 @@ void AgentStep()
if ((m_RequestAction) && (m_Brain != null))
{
m_RequestAction = false;
- AgentAction(m_Action.vectorActions, m_Action.textActions, m_Action.customAction);
+ AgentAction(m_Action.vectorActions);
}
if ((m_StepCount >= agentParameters.maxStep)
@@ -1065,14 +1005,5 @@ void DecideAction()
{
m_Brain?.DecideAction();
}
-
- ///
- /// Sets the custom observation for the agent for this episode.
- ///
- /// New value of the agent's custom observation.
- public void SetCustomObservation(CommunicatorObjects.CustomObservationProto customObservation)
- {
- m_Info.customObservation = customObservation;
- }
}
}
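Taken together, the Agent.cs changes above replace the vector/text/custom observation plumbing with sensors and drop the extra `AgentAction` overloads. A minimal sketch of an agent written against the refactored API; the agent itself and its reward shaping are illustrative, not part of this patch:

```csharp
using MLAgents;
using UnityEngine;

public class RollerAgent : Agent
{
    Rigidbody m_Body;

    void Start()
    {
        m_Body = GetComponent<Rigidbody>();
    }

    public override void CollectObservations()
    {
        // Each AddVectorObs call now writes into collectObservationsSensor (a VectorSensor).
        AddVectorObs(transform.position);   // Vector3 -> 3 floats
        AddVectorObs(m_Body.velocity.x);
        AddVectorObs(m_Body.velocity.z);
    }

    // The textAction/customAction parameters are gone; only the float[] overload remains.
    public override void AgentAction(float[] vectorAction)
    {
        var force = new Vector3(vectorAction[0], 0f, vectorAction[1]);
        m_Body.AddForce(force);
        AddReward(-0.001f);                 // small per-step penalty
    }
}
```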
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/DemonstrationRecorder.cs b/UnitySDK/Assets/ML-Agents/Scripts/DemonstrationRecorder.cs
index 01c0d77d9d..3ccbde5326 100644
--- a/UnitySDK/Assets/ML-Agents/Scripts/DemonstrationRecorder.cs
+++ b/UnitySDK/Assets/ML-Agents/Scripts/DemonstrationRecorder.cs
@@ -1,5 +1,6 @@
-using UnityEngine;
+using System.IO.Abstractions;
using System.Text.RegularExpressions;
+using UnityEngine;
namespace MLAgents
{
@@ -11,12 +12,12 @@ public class DemonstrationRecorder : MonoBehaviour
{
public bool record;
public string demonstrationName;
- private Agent m_RecordingAgent;
- private string m_FilePath;
- private DemonstrationStore m_DemoStore;
+ Agent m_RecordingAgent;
+ string m_FilePath;
+ DemonstrationStore m_DemoStore;
public const int MaxNameLength = 16;
- private void Start()
+ void Start()
{
if (Application.isEditor && record)
{
@@ -24,7 +25,7 @@ private void Start()
}
}
- private void Update()
+ void Update()
{
if (Application.isEditor && record && m_DemoStore == null)
{
@@ -35,15 +36,16 @@ private void Update()
///
/// Creates demonstration store for use in recording.
///
- private void InitializeDemoStore()
+ public void InitializeDemoStore(IFileSystem fileSystem = null)
{
m_RecordingAgent = GetComponent<Agent>();
- m_DemoStore = new DemonstrationStore();
+ m_DemoStore = new DemonstrationStore(fileSystem);
+ var behaviorParams = GetComponent<BehaviorParameters>();
demonstrationName = SanitizeName(demonstrationName, MaxNameLength);
m_DemoStore.Initialize(
demonstrationName,
- GetComponent<BehaviorParameters>().brainParameters,
- GetComponent<BehaviorParameters>().behaviorName);
+ behaviorParams.brainParameters,
+ behaviorParams.behaviorName);
Monitor.Log("Recording Demonstration of Agent: ", m_RecordingAgent.name);
}
@@ -71,14 +73,23 @@ public void WriteExperience(AgentInfo info)
m_DemoStore.Record(info);
}
+ public void Close()
+ {
+ if (m_DemoStore != null)
+ {
+ m_DemoStore.Close();
+ m_DemoStore = null;
+ }
+ }
+
///
/// Closes Demonstration store.
///
- private void OnApplicationQuit()
+ void OnApplicationQuit()
{
- if (Application.isEditor && record && m_DemoStore != null)
+ if (Application.isEditor && record)
{
- m_DemoStore.Close();
+ Close();
}
}
}
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/DemonstrationStore.cs b/UnitySDK/Assets/ML-Agents/Scripts/DemonstrationStore.cs
index 9a9664768b..2ae5360a25 100644
--- a/UnitySDK/Assets/ML-Agents/Scripts/DemonstrationStore.cs
+++ b/UnitySDK/Assets/ML-Agents/Scripts/DemonstrationStore.cs
@@ -1,6 +1,7 @@
using System.IO;
using System.IO.Abstractions;
using Google.Protobuf;
+using UnityEngine;
namespace MLAgents
{
@@ -10,23 +11,25 @@ namespace MLAgents
public class DemonstrationStore
{
public const int MetaDataBytes = 32; // Number of bytes allocated to metadata in demo file.
- private readonly IFileSystem m_FileSystem;
- private const string k_DemoDirecory = "Assets/Demonstrations/";
- private const string k_ExtensionType = ".demo";
+ readonly IFileSystem m_FileSystem;
+ const string k_DemoDirecory = "Assets/Demonstrations/";
+ const string k_ExtensionType = ".demo";
- private string m_FilePath;
- private DemonstrationMetaData m_MetaData;
- private Stream m_Writer;
- private float m_CumulativeReward;
+ string m_FilePath;
+ DemonstrationMetaData m_MetaData;
+ Stream m_Writer;
+ float m_CumulativeReward;
public DemonstrationStore(IFileSystem fileSystem)
{
- m_FileSystem = fileSystem;
- }
-
- public DemonstrationStore()
- {
- m_FileSystem = new FileSystem();
+ if (fileSystem != null)
+ {
+ m_FileSystem = fileSystem;
+ }
+ else
+ {
+ m_FileSystem = new FileSystem();
+ }
}
///
@@ -44,7 +47,7 @@ public void Initialize(
/// Checks for the existence of the Demonstrations directory
/// and creates it if it does not exist.
///
- private void CreateDirectory()
+ void CreateDirectory()
{
if (!m_FileSystem.Directory.Exists(k_DemoDirecory))
{
@@ -55,7 +58,7 @@ private void CreateDirectory()
///
/// Creates demonstration file.
///
- private void CreateDemonstrationFile(string demonstrationName)
+ void CreateDemonstrationFile(string demonstrationName)
{
// Creates demonstration file.
var literalName = demonstrationName;
@@ -69,7 +72,7 @@ private void CreateDemonstrationFile(string demonstrationName)
}
m_Writer = m_FileSystem.File.Create(m_FilePath);
- m_MetaData = new DemonstrationMetaData {demonstrationName = demonstrationName};
+ m_MetaData = new DemonstrationMetaData { demonstrationName = demonstrationName };
var metaProto = m_MetaData.ToProto();
metaProto.WriteDelimitedTo(m_Writer);
}
@@ -77,7 +80,7 @@ private void CreateDemonstrationFile(string demonstrationName)
///
/// Writes brain parameters to file.
///
- private void WriteBrainParameters(string brainName, BrainParameters brainParameters)
+ void WriteBrainParameters(string brainName, BrainParameters brainParameters)
{
// Writes BrainParameters to file.
m_Writer.Seek(MetaDataBytes + 1, 0);
@@ -99,7 +102,7 @@ public void Record(AgentInfo info)
}
// Write AgentInfo to file.
- var agentProto = info.ToProto();
+ var agentProto = info.ToInfoActionPairProto();
agentProto.WriteDelimitedTo(m_Writer);
}
@@ -117,7 +120,7 @@ public void Close()
///
/// Performs necessary episode-completion steps.
///
- private void EndEpisode()
+ void EndEpisode()
{
m_MetaData.numberEpisodes += 1;
}
@@ -125,7 +128,7 @@ private void EndEpisode()
///
/// Writes meta-data.
///
- private void WriteMetadata()
+ void WriteMetadata()
{
var metaProto = m_MetaData.ToProto();
var metaProtoBytes = metaProto.ToByteArray();
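Making `InitializeDemoStore` public with an injectable `IFileSystem`, and having `DemonstrationStore` fall back to the real `FileSystem` only when none is supplied, looks aimed at testability. A minimal sketch of a store-level test using an in-memory file system; `MockFileSystem` comes from System.IO.Abstractions.TestingHelpers, and the default-constructed `BrainParameters` stands in for fully populated parameters:

```csharp
using System.IO.Abstractions.TestingHelpers;
using MLAgents;
using NUnit.Framework;

public class DemonstrationStoreTests
{
    [Test]
    public void InitializeWritesDemoFileToInjectedFileSystem()
    {
        // Nothing touches the real disk: all writes go to the in-memory MockFileSystem.
        var fileSystem = new MockFileSystem();
        var store = new DemonstrationStore(fileSystem);

        store.Initialize("UnitTestDemo", new BrainParameters(), "TestBehavior");

        Assert.IsTrue(fileSystem.File.Exists("Assets/Demonstrations/UnitTestDemo.demo"));
        store.Close();
    }
}
```

The same `fileSystem` instance can be handed to `DemonstrationRecorder.InitializeDemoStore(fileSystem)` when the recorder itself is under test.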
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/AgentAction.cs b/UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/AgentAction.cs
index 33a893aa9d..84dd912776 100644
--- a/UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/AgentAction.cs
+++ b/UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/AgentAction.cs
@@ -25,17 +25,14 @@ static AgentActionReflection() {
byte[] descriptorData = global::System.Convert.FromBase64String(
string.Concat(
"CjVtbGFnZW50cy9lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2FnZW50X2Fj",
- "dGlvbi5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMaNm1sYWdlbnRzL2Vu",
- "dnMvY29tbXVuaWNhdG9yX29iamVjdHMvY3VzdG9tX2FjdGlvbi5wcm90byKh",
- "AQoQQWdlbnRBY3Rpb25Qcm90bxIWCg52ZWN0b3JfYWN0aW9ucxgBIAMoAhIU",
- "Cgx0ZXh0X2FjdGlvbnMYAiABKAkSEAoIbWVtb3JpZXMYAyADKAISDQoFdmFs",
- "dWUYBCABKAISPgoNY3VzdG9tX2FjdGlvbhgFIAEoCzInLmNvbW11bmljYXRv",
- "cl9vYmplY3RzLkN1c3RvbUFjdGlvblByb3RvQh+qAhxNTEFnZW50cy5Db21t",
- "dW5pY2F0b3JPYmplY3RzYgZwcm90bzM="));
+ "dGlvbi5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMiSwoQQWdlbnRBY3Rp",
+ "b25Qcm90bxIWCg52ZWN0b3JfYWN0aW9ucxgBIAMoAhINCgV2YWx1ZRgEIAEo",
+ "AkoECAIQA0oECAMQBEoECAUQBkIfqgIcTUxBZ2VudHMuQ29tbXVuaWNhdG9y",
+ "T2JqZWN0c2IGcHJvdG8z"));
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
- new pbr::FileDescriptor[] { global::MLAgents.CommunicatorObjects.CustomActionReflection.Descriptor, },
+ new pbr::FileDescriptor[] { },
new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
- new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.AgentActionProto), global::MLAgents.CommunicatorObjects.AgentActionProto.Parser, new[]{ "VectorActions", "TextActions", "Memories", "Value", "CustomAction" }, null, null, null)
+ new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.AgentActionProto), global::MLAgents.CommunicatorObjects.AgentActionProto.Parser, new[]{ "VectorActions", "Value" }, null, null, null)
}));
}
#endregion
@@ -68,10 +65,7 @@ public AgentActionProto() {
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public AgentActionProto(AgentActionProto other) : this() {
vectorActions_ = other.vectorActions_.Clone();
- textActions_ = other.textActions_;
- memories_ = other.memories_.Clone();
value_ = other.value_;
- CustomAction = other.customAction_ != null ? other.CustomAction.Clone() : null;
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}
@@ -90,27 +84,6 @@ public AgentActionProto Clone() {
get { return vectorActions_; }
}
- /// Field number for the "text_actions" field.
- public const int TextActionsFieldNumber = 2;
- private string textActions_ = "";
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public string TextActions {
- get { return textActions_; }
- set {
- textActions_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
- }
- }
-
- /// Field number for the "memories" field.
- public const int MemoriesFieldNumber = 3;
- private static readonly pb::FieldCodec<float> _repeated_memories_codec
- = pb::FieldCodec.ForFloat(26);
- private readonly pbc::RepeatedField<float> memories_ = new pbc::RepeatedField<float>();
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public pbc::RepeatedField<float> Memories {
- get { return memories_; }
- }
-
/// Field number for the "value" field.
public const int ValueFieldNumber = 4;
private float value_;
@@ -122,17 +95,6 @@ public float Value {
}
}
- /// Field number for the "custom_action" field.
- public const int CustomActionFieldNumber = 5;
- private global::MLAgents.CommunicatorObjects.CustomActionProto customAction_;
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public global::MLAgents.CommunicatorObjects.CustomActionProto CustomAction {
- get { return customAction_; }
- set {
- customAction_ = value;
- }
- }
-
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as AgentActionProto);
@@ -147,10 +109,7 @@ public bool Equals(AgentActionProto other) {
return true;
}
if(!vectorActions_.Equals(other.vectorActions_)) return false;
- if (TextActions != other.TextActions) return false;
- if(!memories_.Equals(other.memories_)) return false;
if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(Value, other.Value)) return false;
- if (!object.Equals(CustomAction, other.CustomAction)) return false;
return Equals(_unknownFields, other._unknownFields);
}
@@ -158,10 +117,7 @@ public bool Equals(AgentActionProto other) {
public override int GetHashCode() {
int hash = 1;
hash ^= vectorActions_.GetHashCode();
- if (TextActions.Length != 0) hash ^= TextActions.GetHashCode();
- hash ^= memories_.GetHashCode();
if (Value != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(Value);
- if (customAction_ != null) hash ^= CustomAction.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}
@@ -176,19 +132,10 @@ public override string ToString() {
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void WriteTo(pb::CodedOutputStream output) {
vectorActions_.WriteTo(output, _repeated_vectorActions_codec);
- if (TextActions.Length != 0) {
- output.WriteRawTag(18);
- output.WriteString(TextActions);
- }
- memories_.WriteTo(output, _repeated_memories_codec);
if (Value != 0F) {
output.WriteRawTag(37);
output.WriteFloat(Value);
}
- if (customAction_ != null) {
- output.WriteRawTag(42);
- output.WriteMessage(CustomAction);
- }
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
@@ -198,16 +145,9 @@ public void WriteTo(pb::CodedOutputStream output) {
public int CalculateSize() {
int size = 0;
size += vectorActions_.CalculateSize(_repeated_vectorActions_codec);
- if (TextActions.Length != 0) {
- size += 1 + pb::CodedOutputStream.ComputeStringSize(TextActions);
- }
- size += memories_.CalculateSize(_repeated_memories_codec);
if (Value != 0F) {
size += 1 + 4;
}
- if (customAction_ != null) {
- size += 1 + pb::CodedOutputStream.ComputeMessageSize(CustomAction);
- }
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}
@@ -220,19 +160,9 @@ public void MergeFrom(AgentActionProto other) {
return;
}
vectorActions_.Add(other.vectorActions_);
- if (other.TextActions.Length != 0) {
- TextActions = other.TextActions;
- }
- memories_.Add(other.memories_);
if (other.Value != 0F) {
Value = other.Value;
}
- if (other.customAction_ != null) {
- if (customAction_ == null) {
- customAction_ = new global::MLAgents.CommunicatorObjects.CustomActionProto();
- }
- CustomAction.MergeFrom(other.CustomAction);
- }
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}
@@ -249,26 +179,10 @@ public void MergeFrom(pb::CodedInputStream input) {
vectorActions_.AddEntriesFrom(input, _repeated_vectorActions_codec);
break;
}
- case 18: {
- TextActions = input.ReadString();
- break;
- }
- case 26:
- case 29: {
- memories_.AddEntriesFrom(input, _repeated_memories_codec);
- break;
- }
case 37: {
Value = input.ReadFloat();
break;
}
- case 42: {
- if (customAction_ == null) {
- customAction_ = new global::MLAgents.CommunicatorObjects.CustomActionProto();
- }
- input.ReadMessage(customAction_);
- break;
- }
}
}
}
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/AgentInfo.cs b/UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/AgentInfo.cs
index dfe9158fd9..41f75e95df 100644
--- a/UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/AgentInfo.cs
+++ b/UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/AgentInfo.cs
@@ -25,24 +25,18 @@ static AgentInfoReflection() {
byte[] descriptorData = global::System.Convert.FromBase64String(
string.Concat(
"CjNtbGFnZW50cy9lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2FnZW50X2lu",
- "Zm8ucHJvdG8SFGNvbW11bmljYXRvcl9vYmplY3RzGj9tbGFnZW50cy9lbnZz",
- "L2NvbW11bmljYXRvcl9vYmplY3RzL2NvbXByZXNzZWRfb2JzZXJ2YXRpb24u",
- "cHJvdG8aO21sYWdlbnRzL2VudnMvY29tbXVuaWNhdG9yX29iamVjdHMvY3Vz",
- "dG9tX29ic2VydmF0aW9uLnByb3RvIpgDCg5BZ2VudEluZm9Qcm90bxIiChpz",
- "dGFja2VkX3ZlY3Rvcl9vYnNlcnZhdGlvbhgBIAMoAhIYChB0ZXh0X29ic2Vy",
- "dmF0aW9uGAMgASgJEh0KFXN0b3JlZF92ZWN0b3JfYWN0aW9ucxgEIAMoAhIb",
- "ChNzdG9yZWRfdGV4dF9hY3Rpb25zGAUgASgJEhAKCG1lbW9yaWVzGAYgAygC",
- "Eg4KBnJld2FyZBgHIAEoAhIMCgRkb25lGAggASgIEhgKEG1heF9zdGVwX3Jl",
- "YWNoZWQYCSABKAgSCgoCaWQYCiABKAUSEwoLYWN0aW9uX21hc2sYCyADKAgS",
- "SAoSY3VzdG9tX29ic2VydmF0aW9uGAwgASgLMiwuY29tbXVuaWNhdG9yX29i",
- "amVjdHMuQ3VzdG9tT2JzZXJ2YXRpb25Qcm90bxJRChdjb21wcmVzc2VkX29i",
- "c2VydmF0aW9ucxgNIAMoCzIwLmNvbW11bmljYXRvcl9vYmplY3RzLkNvbXBy",
- "ZXNzZWRPYnNlcnZhdGlvblByb3RvSgQIAhADQh+qAhxNTEFnZW50cy5Db21t",
- "dW5pY2F0b3JPYmplY3RzYgZwcm90bzM="));
+ "Zm8ucHJvdG8SFGNvbW11bmljYXRvcl9vYmplY3RzGjRtbGFnZW50cy9lbnZz",
+ "L2NvbW11bmljYXRvcl9vYmplY3RzL29ic2VydmF0aW9uLnByb3RvItEBCg5B",
+ "Z2VudEluZm9Qcm90bxIOCgZyZXdhcmQYByABKAISDAoEZG9uZRgIIAEoCBIY",
+ "ChBtYXhfc3RlcF9yZWFjaGVkGAkgASgIEgoKAmlkGAogASgFEhMKC2FjdGlv",
+ "bl9tYXNrGAsgAygIEjwKDG9ic2VydmF0aW9ucxgNIAMoCzImLmNvbW11bmlj",
+ "YXRvcl9vYmplY3RzLk9ic2VydmF0aW9uUHJvdG9KBAgBEAJKBAgCEANKBAgD",
+ "EARKBAgEEAVKBAgFEAZKBAgGEAdKBAgMEA1CH6oCHE1MQWdlbnRzLkNvbW11",
+ "bmljYXRvck9iamVjdHNiBnByb3RvMw=="));
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
- new pbr::FileDescriptor[] { global::MLAgents.CommunicatorObjects.CompressedObservationReflection.Descriptor, global::MLAgents.CommunicatorObjects.CustomObservationReflection.Descriptor, },
+ new pbr::FileDescriptor[] { global::MLAgents.CommunicatorObjects.ObservationReflection.Descriptor, },
new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
- new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.AgentInfoProto), global::MLAgents.CommunicatorObjects.AgentInfoProto.Parser, new[]{ "StackedVectorObservation", "TextObservation", "StoredVectorActions", "StoredTextActions", "Memories", "Reward", "Done", "MaxStepReached", "Id", "ActionMask", "CustomObservation", "CompressedObservations" }, null, null, null)
+ new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.AgentInfoProto), global::MLAgents.CommunicatorObjects.AgentInfoProto.Parser, new[]{ "Reward", "Done", "MaxStepReached", "Id", "ActionMask", "Observations" }, null, null, null)
}));
}
#endregion
@@ -74,18 +68,12 @@ public AgentInfoProto() {
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public AgentInfoProto(AgentInfoProto other) : this() {
- stackedVectorObservation_ = other.stackedVectorObservation_.Clone();
- textObservation_ = other.textObservation_;
- storedVectorActions_ = other.storedVectorActions_.Clone();
- storedTextActions_ = other.storedTextActions_;
- memories_ = other.memories_.Clone();
reward_ = other.reward_;
done_ = other.done_;
maxStepReached_ = other.maxStepReached_;
id_ = other.id_;
actionMask_ = other.actionMask_.Clone();
- CustomObservation = other.customObservation_ != null ? other.CustomObservation.Clone() : null;
- compressedObservations_ = other.compressedObservations_.Clone();
+ observations_ = other.observations_.Clone();
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}
@@ -94,58 +82,6 @@ public AgentInfoProto Clone() {
return new AgentInfoProto(this);
}
- /// Field number for the "stacked_vector_observation" field.
- public const int StackedVectorObservationFieldNumber = 1;
- private static readonly pb::FieldCodec<float> _repeated_stackedVectorObservation_codec
- = pb::FieldCodec.ForFloat(10);
- private readonly pbc::RepeatedField<float> stackedVectorObservation_ = new pbc::RepeatedField<float>();
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public pbc::RepeatedField<float> StackedVectorObservation {
- get { return stackedVectorObservation_; }
- }
-
- /// Field number for the "text_observation" field.
- public const int TextObservationFieldNumber = 3;
- private string textObservation_ = "";
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public string TextObservation {
- get { return textObservation_; }
- set {
- textObservation_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
- }
- }
-
- /// Field number for the "stored_vector_actions" field.
- public const int StoredVectorActionsFieldNumber = 4;
- private static readonly pb::FieldCodec<float> _repeated_storedVectorActions_codec
- = pb::FieldCodec.ForFloat(34);
- private readonly pbc::RepeatedField<float> storedVectorActions_ = new pbc::RepeatedField<float>();
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public pbc::RepeatedField<float> StoredVectorActions {
- get { return storedVectorActions_; }
- }
-
- /// Field number for the "stored_text_actions" field.
- public const int StoredTextActionsFieldNumber = 5;
- private string storedTextActions_ = "";
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public string StoredTextActions {
- get { return storedTextActions_; }
- set {
- storedTextActions_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
- }
- }
-
- /// Field number for the "memories" field.
- public const int MemoriesFieldNumber = 6;
- private static readonly pb::FieldCodec<float> _repeated_memories_codec
- = pb::FieldCodec.ForFloat(50);
- private readonly pbc::RepeatedField<float> memories_ = new pbc::RepeatedField<float>();
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public pbc::RepeatedField<float> Memories {
- get { return memories_; }
- }
-
/// Field number for the "reward" field.
public const int RewardFieldNumber = 7;
private float reward_;
@@ -200,25 +136,14 @@ public int Id {
get { return actionMask_; }
}
- /// Field number for the "custom_observation" field.
- public const int CustomObservationFieldNumber = 12;
- private global::MLAgents.CommunicatorObjects.CustomObservationProto customObservation_;
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public global::MLAgents.CommunicatorObjects.CustomObservationProto CustomObservation {
- get { return customObservation_; }
- set {
- customObservation_ = value;
- }
- }
-
- /// Field number for the "compressed_observations" field.
- public const int CompressedObservationsFieldNumber = 13;
- private static readonly pb::FieldCodec<global::MLAgents.CommunicatorObjects.CompressedObservationProto> _repeated_compressedObservations_codec
- = pb::FieldCodec.ForMessage(106, global::MLAgents.CommunicatorObjects.CompressedObservationProto.Parser);
- private readonly pbc::RepeatedField<global::MLAgents.CommunicatorObjects.CompressedObservationProto> compressedObservations_ = new pbc::RepeatedField<global::MLAgents.CommunicatorObjects.CompressedObservationProto>();
+ /// Field number for the "observations" field.
+ public const int ObservationsFieldNumber = 13;
+ private static readonly pb::FieldCodec<global::MLAgents.CommunicatorObjects.ObservationProto> _repeated_observations_codec
+ = pb::FieldCodec.ForMessage(106, global::MLAgents.CommunicatorObjects.ObservationProto.Parser);
+ private readonly pbc::RepeatedField<global::MLAgents.CommunicatorObjects.ObservationProto> observations_ = new pbc::RepeatedField<global::MLAgents.CommunicatorObjects.ObservationProto>();
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public pbc::RepeatedField<global::MLAgents.CommunicatorObjects.CompressedObservationProto> CompressedObservations {
- get { return compressedObservations_; }
+ public pbc::RepeatedField<global::MLAgents.CommunicatorObjects.ObservationProto> Observations {
+ get { return observations_; }
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
@@ -234,36 +159,24 @@ public bool Equals(AgentInfoProto other) {
if (ReferenceEquals(other, this)) {
return true;
}
- if(!stackedVectorObservation_.Equals(other.stackedVectorObservation_)) return false;
- if (TextObservation != other.TextObservation) return false;
- if(!storedVectorActions_.Equals(other.storedVectorActions_)) return false;
- if (StoredTextActions != other.StoredTextActions) return false;
- if(!memories_.Equals(other.memories_)) return false;
if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(Reward, other.Reward)) return false;
if (Done != other.Done) return false;
if (MaxStepReached != other.MaxStepReached) return false;
if (Id != other.Id) return false;
if(!actionMask_.Equals(other.actionMask_)) return false;
- if (!object.Equals(CustomObservation, other.CustomObservation)) return false;
- if(!compressedObservations_.Equals(other.compressedObservations_)) return false;
+ if(!observations_.Equals(other.observations_)) return false;
return Equals(_unknownFields, other._unknownFields);
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override int GetHashCode() {
int hash = 1;
- hash ^= stackedVectorObservation_.GetHashCode();
- if (TextObservation.Length != 0) hash ^= TextObservation.GetHashCode();
- hash ^= storedVectorActions_.GetHashCode();
- if (StoredTextActions.Length != 0) hash ^= StoredTextActions.GetHashCode();
- hash ^= memories_.GetHashCode();
if (Reward != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(Reward);
if (Done != false) hash ^= Done.GetHashCode();
if (MaxStepReached != false) hash ^= MaxStepReached.GetHashCode();
if (Id != 0) hash ^= Id.GetHashCode();
hash ^= actionMask_.GetHashCode();
- if (customObservation_ != null) hash ^= CustomObservation.GetHashCode();
- hash ^= compressedObservations_.GetHashCode();
+ hash ^= observations_.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}
@@ -277,17 +190,6 @@ public override string ToString() {
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void WriteTo(pb::CodedOutputStream output) {
- stackedVectorObservation_.WriteTo(output, _repeated_stackedVectorObservation_codec);
- if (TextObservation.Length != 0) {
- output.WriteRawTag(26);
- output.WriteString(TextObservation);
- }
- storedVectorActions_.WriteTo(output, _repeated_storedVectorActions_codec);
- if (StoredTextActions.Length != 0) {
- output.WriteRawTag(42);
- output.WriteString(StoredTextActions);
- }
- memories_.WriteTo(output, _repeated_memories_codec);
if (Reward != 0F) {
output.WriteRawTag(61);
output.WriteFloat(Reward);
@@ -305,11 +207,7 @@ public void WriteTo(pb::CodedOutputStream output) {
output.WriteInt32(Id);
}
actionMask_.WriteTo(output, _repeated_actionMask_codec);
- if (customObservation_ != null) {
- output.WriteRawTag(98);
- output.WriteMessage(CustomObservation);
- }
- compressedObservations_.WriteTo(output, _repeated_compressedObservations_codec);
+ observations_.WriteTo(output, _repeated_observations_codec);
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
@@ -318,15 +216,6 @@ public void WriteTo(pb::CodedOutputStream output) {
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public int CalculateSize() {
int size = 0;
- size += stackedVectorObservation_.CalculateSize(_repeated_stackedVectorObservation_codec);
- if (TextObservation.Length != 0) {
- size += 1 + pb::CodedOutputStream.ComputeStringSize(TextObservation);
- }
- size += storedVectorActions_.CalculateSize(_repeated_storedVectorActions_codec);
- if (StoredTextActions.Length != 0) {
- size += 1 + pb::CodedOutputStream.ComputeStringSize(StoredTextActions);
- }
- size += memories_.CalculateSize(_repeated_memories_codec);
if (Reward != 0F) {
size += 1 + 4;
}
@@ -340,10 +229,7 @@ public int CalculateSize() {
size += 1 + pb::CodedOutputStream.ComputeInt32Size(Id);
}
size += actionMask_.CalculateSize(_repeated_actionMask_codec);
- if (customObservation_ != null) {
- size += 1 + pb::CodedOutputStream.ComputeMessageSize(CustomObservation);
- }
- size += compressedObservations_.CalculateSize(_repeated_compressedObservations_codec);
+ size += observations_.CalculateSize(_repeated_observations_codec);
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}
@@ -355,15 +241,6 @@ public void MergeFrom(AgentInfoProto other) {
if (other == null) {
return;
}
- stackedVectorObservation_.Add(other.stackedVectorObservation_);
- if (other.TextObservation.Length != 0) {
- TextObservation = other.TextObservation;
- }
- storedVectorActions_.Add(other.storedVectorActions_);
- if (other.StoredTextActions.Length != 0) {
- StoredTextActions = other.StoredTextActions;
- }
- memories_.Add(other.memories_);
if (other.Reward != 0F) {
Reward = other.Reward;
}
@@ -377,13 +254,7 @@ public void MergeFrom(AgentInfoProto other) {
Id = other.Id;
}
actionMask_.Add(other.actionMask_);
- if (other.customObservation_ != null) {
- if (customObservation_ == null) {
- customObservation_ = new global::MLAgents.CommunicatorObjects.CustomObservationProto();
- }
- CustomObservation.MergeFrom(other.CustomObservation);
- }
- compressedObservations_.Add(other.compressedObservations_);
+ observations_.Add(other.observations_);
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}
@@ -395,29 +266,6 @@ public void MergeFrom(pb::CodedInputStream input) {
default:
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
break;
- case 10:
- case 13: {
- stackedVectorObservation_.AddEntriesFrom(input, _repeated_stackedVectorObservation_codec);
- break;
- }
- case 26: {
- TextObservation = input.ReadString();
- break;
- }
- case 34:
- case 37: {
- storedVectorActions_.AddEntriesFrom(input, _repeated_storedVectorActions_codec);
- break;
- }
- case 42: {
- StoredTextActions = input.ReadString();
- break;
- }
- case 50:
- case 53: {
- memories_.AddEntriesFrom(input, _repeated_memories_codec);
- break;
- }
case 61: {
Reward = input.ReadFloat();
break;
@@ -439,15 +287,8 @@ public void MergeFrom(pb::CodedInputStream input) {
actionMask_.AddEntriesFrom(input, _repeated_actionMask_codec);
break;
}
- case 98: {
- if (customObservation_ == null) {
- customObservation_ = new global::MLAgents.CommunicatorObjects.CustomObservationProto();
- }
- input.ReadMessage(customObservation_);
- break;
- }
case 106: {
- compressedObservations_.AddEntriesFrom(input, _repeated_compressedObservations_codec);
+ observations_.AddEntriesFrom(input, _repeated_observations_codec);
break;
}
}
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/AgentInfoActionPair.cs b/UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/AgentInfoActionPair.cs
new file mode 100644
index 0000000000..d9b1d4d79c
--- /dev/null
+++ b/UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/AgentInfoActionPair.cs
@@ -0,0 +1,219 @@
+//
+// Generated by the protocol buffer compiler. DO NOT EDIT!
+// source: mlagents/envs/communicator_objects/agent_info_action_pair.proto
+//
+#pragma warning disable 1591, 0612, 3021
+#region Designer generated code
+
+using pb = global::Google.Protobuf;
+using pbc = global::Google.Protobuf.Collections;
+using pbr = global::Google.Protobuf.Reflection;
+using scg = global::System.Collections.Generic;
+namespace MLAgents.CommunicatorObjects {
+
+ /// Holder for reflection information generated from mlagents/envs/communicator_objects/agent_info_action_pair.proto
+ public static partial class AgentInfoActionPairReflection {
+
+ #region Descriptor
+ /// File descriptor for mlagents/envs/communicator_objects/agent_info_action_pair.proto
+ public static pbr::FileDescriptor Descriptor {
+ get { return descriptor; }
+ }
+ private static pbr::FileDescriptor descriptor;
+
+ static AgentInfoActionPairReflection() {
+ byte[] descriptorData = global::System.Convert.FromBase64String(
+ string.Concat(
+ "Cj9tbGFnZW50cy9lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2FnZW50X2lu",
+ "Zm9fYWN0aW9uX3BhaXIucHJvdG8SFGNvbW11bmljYXRvcl9vYmplY3RzGjNt",
+ "bGFnZW50cy9lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2FnZW50X2luZm8u",
+ "cHJvdG8aNW1sYWdlbnRzL2VudnMvY29tbXVuaWNhdG9yX29iamVjdHMvYWdl",
+ "bnRfYWN0aW9uLnByb3RvIpEBChhBZ2VudEluZm9BY3Rpb25QYWlyUHJvdG8S",
+ "OAoKYWdlbnRfaW5mbxgBIAEoCzIkLmNvbW11bmljYXRvcl9vYmplY3RzLkFn",
+ "ZW50SW5mb1Byb3RvEjsKC2FjdGlvbl9pbmZvGAIgASgLMiYuY29tbXVuaWNh",
+ "dG9yX29iamVjdHMuQWdlbnRBY3Rpb25Qcm90b0IfqgIcTUxBZ2VudHMuQ29t",
+ "bXVuaWNhdG9yT2JqZWN0c2IGcHJvdG8z"));
+ descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
+ new pbr::FileDescriptor[] { global::MLAgents.CommunicatorObjects.AgentInfoReflection.Descriptor, global::MLAgents.CommunicatorObjects.AgentActionReflection.Descriptor, },
+ new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
+ new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.AgentInfoActionPairProto), global::MLAgents.CommunicatorObjects.AgentInfoActionPairProto.Parser, new[]{ "AgentInfo", "ActionInfo" }, null, null, null)
+ }));
+ }
+ #endregion
+
+ }
+ #region Messages
+ public sealed partial class AgentInfoActionPairProto : pb::IMessage<AgentInfoActionPairProto> {
+ private static readonly pb::MessageParser<AgentInfoActionPairProto> _parser = new pb::MessageParser<AgentInfoActionPairProto>(() => new AgentInfoActionPairProto());
+ private pb::UnknownFieldSet _unknownFields;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pb::MessageParser<AgentInfoActionPairProto> Parser { get { return _parser; } }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pbr::MessageDescriptor Descriptor {
+ get { return global::MLAgents.CommunicatorObjects.AgentInfoActionPairReflection.Descriptor.MessageTypes[0]; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ pbr::MessageDescriptor pb::IMessage.Descriptor {
+ get { return Descriptor; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public AgentInfoActionPairProto() {
+ OnConstruction();
+ }
+
+ partial void OnConstruction();
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public AgentInfoActionPairProto(AgentInfoActionPairProto other) : this() {
+ AgentInfo = other.agentInfo_ != null ? other.AgentInfo.Clone() : null;
+ ActionInfo = other.actionInfo_ != null ? other.ActionInfo.Clone() : null;
+ _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public AgentInfoActionPairProto Clone() {
+ return new AgentInfoActionPairProto(this);
+ }
+
+ /// Field number for the "agent_info" field.
+ public const int AgentInfoFieldNumber = 1;
+ private global::MLAgents.CommunicatorObjects.AgentInfoProto agentInfo_;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public global::MLAgents.CommunicatorObjects.AgentInfoProto AgentInfo {
+ get { return agentInfo_; }
+ set {
+ agentInfo_ = value;
+ }
+ }
+
+ /// Field number for the "action_info" field.
+ public const int ActionInfoFieldNumber = 2;
+ private global::MLAgents.CommunicatorObjects.AgentActionProto actionInfo_;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public global::MLAgents.CommunicatorObjects.AgentActionProto ActionInfo {
+ get { return actionInfo_; }
+ set {
+ actionInfo_ = value;
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override bool Equals(object other) {
+ return Equals(other as AgentInfoActionPairProto);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool Equals(AgentInfoActionPairProto other) {
+ if (ReferenceEquals(other, null)) {
+ return false;
+ }
+ if (ReferenceEquals(other, this)) {
+ return true;
+ }
+ if (!object.Equals(AgentInfo, other.AgentInfo)) return false;
+ if (!object.Equals(ActionInfo, other.ActionInfo)) return false;
+ return Equals(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override int GetHashCode() {
+ int hash = 1;
+ if (agentInfo_ != null) hash ^= AgentInfo.GetHashCode();
+ if (actionInfo_ != null) hash ^= ActionInfo.GetHashCode();
+ if (_unknownFields != null) {
+ hash ^= _unknownFields.GetHashCode();
+ }
+ return hash;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override string ToString() {
+ return pb::JsonFormatter.ToDiagnosticString(this);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void WriteTo(pb::CodedOutputStream output) {
+ if (agentInfo_ != null) {
+ output.WriteRawTag(10);
+ output.WriteMessage(AgentInfo);
+ }
+ if (actionInfo_ != null) {
+ output.WriteRawTag(18);
+ output.WriteMessage(ActionInfo);
+ }
+ if (_unknownFields != null) {
+ _unknownFields.WriteTo(output);
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int CalculateSize() {
+ int size = 0;
+ if (agentInfo_ != null) {
+ size += 1 + pb::CodedOutputStream.ComputeMessageSize(AgentInfo);
+ }
+ if (actionInfo_ != null) {
+ size += 1 + pb::CodedOutputStream.ComputeMessageSize(ActionInfo);
+ }
+ if (_unknownFields != null) {
+ size += _unknownFields.CalculateSize();
+ }
+ return size;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(AgentInfoActionPairProto other) {
+ if (other == null) {
+ return;
+ }
+ if (other.agentInfo_ != null) {
+ if (agentInfo_ == null) {
+ agentInfo_ = new global::MLAgents.CommunicatorObjects.AgentInfoProto();
+ }
+ AgentInfo.MergeFrom(other.AgentInfo);
+ }
+ if (other.actionInfo_ != null) {
+ if (actionInfo_ == null) {
+ actionInfo_ = new global::MLAgents.CommunicatorObjects.AgentActionProto();
+ }
+ ActionInfo.MergeFrom(other.ActionInfo);
+ }
+ _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(pb::CodedInputStream input) {
+ uint tag;
+ while ((tag = input.ReadTag()) != 0) {
+ switch(tag) {
+ default:
+ _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
+ break;
+ case 10: {
+ if (agentInfo_ == null) {
+ agentInfo_ = new global::MLAgents.CommunicatorObjects.AgentInfoProto();
+ }
+ input.ReadMessage(agentInfo_);
+ break;
+ }
+ case 18: {
+ if (actionInfo_ == null) {
+ actionInfo_ = new global::MLAgents.CommunicatorObjects.AgentActionProto();
+ }
+ input.ReadMessage(actionInfo_);
+ break;
+ }
+ }
+ }
+ }
+
+ }
+
+ #endregion
+
+}
+
+#endregion Designer generated code
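
For orientation, a minimal sketch of how the new pair message can be populated, using only the generated accessors defined above; the helper class and method are illustrative, not part of this change (the actual pairing is done by GrpcExtensions.ToInfoActionPairProto further down in this diff):

using MLAgents.CommunicatorObjects;

static class PairingSketch
{
    // Pairs the info an agent reported with the action it took.
    // AgentInfo goes into field 1 (tag 10) and ActionInfo into field 2
    // (tag 18), matching the WriteTo implementation above.
    public static AgentInfoActionPairProto Pair(AgentInfoProto info, AgentActionProto action)
    {
        return new AgentInfoActionPairProto
        {
            AgentInfo = info,
            ActionInfo = action
        };
    }
}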
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/MacBLAS.cs.meta b/UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/AgentInfoActionPair.cs.meta
similarity index 83%
rename from UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/MacBLAS.cs.meta
rename to UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/AgentInfoActionPair.cs.meta
index b90d4acb89..7474dcae69 100644
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/OSX/MacBLAS.cs.meta
+++ b/UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/AgentInfoActionPair.cs.meta
@@ -1,5 +1,5 @@
fileFormatVersion: 2
-guid: 680f04373f71f48a89408105d3f58a08
+guid: 29577366657494c678558b0643abcb30
MonoImporter:
externalObjects: {}
serializedVersion: 2
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/BrainParameters.cs b/UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/BrainParameters.cs
index 4d92bcb92c..5948d577e1 100644
--- a/UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/BrainParameters.cs
+++ b/UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/BrainParameters.cs
@@ -27,18 +27,16 @@ static BrainParametersReflection() {
"CjltbGFnZW50cy9lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2JyYWluX3Bh",
"cmFtZXRlcnMucHJvdG8SFGNvbW11bmljYXRvcl9vYmplY3RzGjNtbGFnZW50",
"cy9lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL3NwYWNlX3R5cGUucHJvdG8i",
- "lwIKFEJyYWluUGFyYW1ldGVyc1Byb3RvEh8KF3ZlY3Rvcl9vYnNlcnZhdGlv",
- "bl9zaXplGAEgASgFEicKH251bV9zdGFja2VkX3ZlY3Rvcl9vYnNlcnZhdGlv",
- "bnMYAiABKAUSGgoSdmVjdG9yX2FjdGlvbl9zaXplGAMgAygFEiIKGnZlY3Rv",
- "cl9hY3Rpb25fZGVzY3JpcHRpb25zGAUgAygJEkYKGHZlY3Rvcl9hY3Rpb25f",
- "c3BhY2VfdHlwZRgGIAEoDjIkLmNvbW11bmljYXRvcl9vYmplY3RzLlNwYWNl",
- "VHlwZVByb3RvEhIKCmJyYWluX25hbWUYByABKAkSEwoLaXNfdHJhaW5pbmcY",
- "CCABKAhKBAgEEAVCH6oCHE1MQWdlbnRzLkNvbW11bmljYXRvck9iamVjdHNi",
- "BnByb3RvMw=="));
+ "2QEKFEJyYWluUGFyYW1ldGVyc1Byb3RvEhoKEnZlY3Rvcl9hY3Rpb25fc2l6",
+ "ZRgDIAMoBRIiChp2ZWN0b3JfYWN0aW9uX2Rlc2NyaXB0aW9ucxgFIAMoCRJG",
+ "Chh2ZWN0b3JfYWN0aW9uX3NwYWNlX3R5cGUYBiABKA4yJC5jb21tdW5pY2F0",
+ "b3Jfb2JqZWN0cy5TcGFjZVR5cGVQcm90bxISCgpicmFpbl9uYW1lGAcgASgJ",
+ "EhMKC2lzX3RyYWluaW5nGAggASgISgQIARACSgQIAhADSgQIBBAFQh+qAhxN",
+ "TEFnZW50cy5Db21tdW5pY2F0b3JPYmplY3RzYgZwcm90bzM="));
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
new pbr::FileDescriptor[] { global::MLAgents.CommunicatorObjects.SpaceTypeReflection.Descriptor, },
new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
- new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.BrainParametersProto), global::MLAgents.CommunicatorObjects.BrainParametersProto.Parser, new[]{ "VectorObservationSize", "NumStackedVectorObservations", "VectorActionSize", "VectorActionDescriptions", "VectorActionSpaceType", "BrainName", "IsTraining" }, null, null, null)
+ new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.BrainParametersProto), global::MLAgents.CommunicatorObjects.BrainParametersProto.Parser, new[]{ "VectorActionSize", "VectorActionDescriptions", "VectorActionSpaceType", "BrainName", "IsTraining" }, null, null, null)
}));
}
#endregion
@@ -70,8 +68,6 @@ public BrainParametersProto() {
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public BrainParametersProto(BrainParametersProto other) : this() {
- vectorObservationSize_ = other.vectorObservationSize_;
- numStackedVectorObservations_ = other.numStackedVectorObservations_;
vectorActionSize_ = other.vectorActionSize_.Clone();
vectorActionDescriptions_ = other.vectorActionDescriptions_.Clone();
vectorActionSpaceType_ = other.vectorActionSpaceType_;
@@ -85,28 +81,6 @@ public BrainParametersProto Clone() {
return new BrainParametersProto(this);
}
- /// Field number for the "vector_observation_size" field.
- public const int VectorObservationSizeFieldNumber = 1;
- private int vectorObservationSize_;
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public int VectorObservationSize {
- get { return vectorObservationSize_; }
- set {
- vectorObservationSize_ = value;
- }
- }
-
- /// Field number for the "num_stacked_vector_observations" field.
- public const int NumStackedVectorObservationsFieldNumber = 2;
- private int numStackedVectorObservations_;
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public int NumStackedVectorObservations {
- get { return numStackedVectorObservations_; }
- set {
- numStackedVectorObservations_ = value;
- }
- }
-
/// Field number for the "vector_action_size" field.
public const int VectorActionSizeFieldNumber = 3;
private static readonly pb::FieldCodec _repeated_vectorActionSize_codec
@@ -173,8 +147,6 @@ public bool Equals(BrainParametersProto other) {
if (ReferenceEquals(other, this)) {
return true;
}
- if (VectorObservationSize != other.VectorObservationSize) return false;
- if (NumStackedVectorObservations != other.NumStackedVectorObservations) return false;
if(!vectorActionSize_.Equals(other.vectorActionSize_)) return false;
if(!vectorActionDescriptions_.Equals(other.vectorActionDescriptions_)) return false;
if (VectorActionSpaceType != other.VectorActionSpaceType) return false;
@@ -186,8 +158,6 @@ public bool Equals(BrainParametersProto other) {
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override int GetHashCode() {
int hash = 1;
- if (VectorObservationSize != 0) hash ^= VectorObservationSize.GetHashCode();
- if (NumStackedVectorObservations != 0) hash ^= NumStackedVectorObservations.GetHashCode();
hash ^= vectorActionSize_.GetHashCode();
hash ^= vectorActionDescriptions_.GetHashCode();
if (VectorActionSpaceType != 0) hash ^= VectorActionSpaceType.GetHashCode();
@@ -206,14 +176,6 @@ public override string ToString() {
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void WriteTo(pb::CodedOutputStream output) {
- if (VectorObservationSize != 0) {
- output.WriteRawTag(8);
- output.WriteInt32(VectorObservationSize);
- }
- if (NumStackedVectorObservations != 0) {
- output.WriteRawTag(16);
- output.WriteInt32(NumStackedVectorObservations);
- }
vectorActionSize_.WriteTo(output, _repeated_vectorActionSize_codec);
vectorActionDescriptions_.WriteTo(output, _repeated_vectorActionDescriptions_codec);
if (VectorActionSpaceType != 0) {
@@ -236,12 +198,6 @@ public void WriteTo(pb::CodedOutputStream output) {
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public int CalculateSize() {
int size = 0;
- if (VectorObservationSize != 0) {
- size += 1 + pb::CodedOutputStream.ComputeInt32Size(VectorObservationSize);
- }
- if (NumStackedVectorObservations != 0) {
- size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumStackedVectorObservations);
- }
size += vectorActionSize_.CalculateSize(_repeated_vectorActionSize_codec);
size += vectorActionDescriptions_.CalculateSize(_repeated_vectorActionDescriptions_codec);
if (VectorActionSpaceType != 0) {
@@ -264,12 +220,6 @@ public void MergeFrom(BrainParametersProto other) {
if (other == null) {
return;
}
- if (other.VectorObservationSize != 0) {
- VectorObservationSize = other.VectorObservationSize;
- }
- if (other.NumStackedVectorObservations != 0) {
- NumStackedVectorObservations = other.NumStackedVectorObservations;
- }
vectorActionSize_.Add(other.vectorActionSize_);
vectorActionDescriptions_.Add(other.vectorActionDescriptions_);
if (other.VectorActionSpaceType != 0) {
@@ -292,14 +242,6 @@ public void MergeFrom(pb::CodedInputStream input) {
default:
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
break;
- case 8: {
- VectorObservationSize = input.ReadInt32();
- break;
- }
- case 16: {
- NumStackedVectorObservations = input.ReadInt32();
- break;
- }
case 26:
case 24: {
vectorActionSize_.AddEntriesFrom(input, _repeated_vectorActionSize_codec);
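
With vector_observation_size and num_stacked_vector_observations dropped from BrainParametersProto, observation dimensionality is presumably carried per observation instead, via the shape field of the ObservationProto message added later in this diff. A hedged sketch of how a consumer could recover the total flattened observation size from an AgentInfoProto; the helper name is illustrative only:

using System.Linq;
using MLAgents.CommunicatorObjects;

static class ObservationSizeSketch
{
    // Sums the flattened length of every observation attached to an agent's
    // info, e.g. a [3] vector observation plus an [84, 84, 3] visual
    // observation yields 3 + 21168 = 21171 values.
    public static int TotalObservationSize(AgentInfoProto info)
    {
        return info.Observations.Sum(obs => obs.Shape.Aggregate(1, (a, b) => a * b));
    }
}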
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/CompressedObservation.cs b/UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/CompressedObservation.cs
deleted file mode 100644
index 1b7a0f9296..0000000000
--- a/UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/CompressedObservation.cs
+++ /dev/null
@@ -1,234 +0,0 @@
-//
-// Generated by the protocol buffer compiler. DO NOT EDIT!
-// source: mlagents/envs/communicator_objects/compressed_observation.proto
-//
-#pragma warning disable 1591, 0612, 3021
-#region Designer generated code
-
-using pb = global::Google.Protobuf;
-using pbc = global::Google.Protobuf.Collections;
-using pbr = global::Google.Protobuf.Reflection;
-using scg = global::System.Collections.Generic;
-namespace MLAgents.CommunicatorObjects {
-
- /// Holder for reflection information generated from mlagents/envs/communicator_objects/compressed_observation.proto
- public static partial class CompressedObservationReflection {
-
- #region Descriptor
- /// File descriptor for mlagents/envs/communicator_objects/compressed_observation.proto
- public static pbr::FileDescriptor Descriptor {
- get { return descriptor; }
- }
- private static pbr::FileDescriptor descriptor;
-
- static CompressedObservationReflection() {
- byte[] descriptorData = global::System.Convert.FromBase64String(
- string.Concat(
- "Cj9tbGFnZW50cy9lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2NvbXByZXNz",
- "ZWRfb2JzZXJ2YXRpb24ucHJvdG8SFGNvbW11bmljYXRvcl9vYmplY3RzIn8K",
- "GkNvbXByZXNzZWRPYnNlcnZhdGlvblByb3RvEg0KBXNoYXBlGAEgAygFEkQK",
- "EGNvbXByZXNzaW9uX3R5cGUYAiABKA4yKi5jb21tdW5pY2F0b3Jfb2JqZWN0",
- "cy5Db21wcmVzc2lvblR5cGVQcm90bxIMCgRkYXRhGAMgASgMKikKFENvbXBy",
- "ZXNzaW9uVHlwZVByb3RvEggKBE5PTkUQABIHCgNQTkcQAUIfqgIcTUxBZ2Vu",
- "dHMuQ29tbXVuaWNhdG9yT2JqZWN0c2IGcHJvdG8z"));
- descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
- new pbr::FileDescriptor[] { },
- new pbr::GeneratedClrTypeInfo(new[] {typeof(global::MLAgents.CommunicatorObjects.CompressionTypeProto), }, new pbr::GeneratedClrTypeInfo[] {
- new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.CompressedObservationProto), global::MLAgents.CommunicatorObjects.CompressedObservationProto.Parser, new[]{ "Shape", "CompressionType", "Data" }, null, null, null)
- }));
- }
- #endregion
-
- }
- #region Enums
- public enum CompressionTypeProto {
- [pbr::OriginalName("NONE")] None = 0,
- [pbr::OriginalName("PNG")] Png = 1,
- }
-
- #endregion
-
- #region Messages
- public sealed partial class CompressedObservationProto : pb::IMessage {
- private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new CompressedObservationProto());
- private pb::UnknownFieldSet _unknownFields;
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public static pb::MessageParser Parser { get { return _parser; } }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public static pbr::MessageDescriptor Descriptor {
- get { return global::MLAgents.CommunicatorObjects.CompressedObservationReflection.Descriptor.MessageTypes[0]; }
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- pbr::MessageDescriptor pb::IMessage.Descriptor {
- get { return Descriptor; }
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public CompressedObservationProto() {
- OnConstruction();
- }
-
- partial void OnConstruction();
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public CompressedObservationProto(CompressedObservationProto other) : this() {
- shape_ = other.shape_.Clone();
- compressionType_ = other.compressionType_;
- data_ = other.data_;
- _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public CompressedObservationProto Clone() {
- return new CompressedObservationProto(this);
- }
-
- /// Field number for the "shape" field.
- public const int ShapeFieldNumber = 1;
- private static readonly pb::FieldCodec _repeated_shape_codec
- = pb::FieldCodec.ForInt32(10);
- private readonly pbc::RepeatedField shape_ = new pbc::RepeatedField();
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public pbc::RepeatedField Shape {
- get { return shape_; }
- }
-
- /// Field number for the "compression_type" field.
- public const int CompressionTypeFieldNumber = 2;
- private global::MLAgents.CommunicatorObjects.CompressionTypeProto compressionType_ = 0;
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public global::MLAgents.CommunicatorObjects.CompressionTypeProto CompressionType {
- get { return compressionType_; }
- set {
- compressionType_ = value;
- }
- }
-
- /// Field number for the "data" field.
- public const int DataFieldNumber = 3;
- private pb::ByteString data_ = pb::ByteString.Empty;
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public pb::ByteString Data {
- get { return data_; }
- set {
- data_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
- }
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public override bool Equals(object other) {
- return Equals(other as CompressedObservationProto);
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public bool Equals(CompressedObservationProto other) {
- if (ReferenceEquals(other, null)) {
- return false;
- }
- if (ReferenceEquals(other, this)) {
- return true;
- }
- if(!shape_.Equals(other.shape_)) return false;
- if (CompressionType != other.CompressionType) return false;
- if (Data != other.Data) return false;
- return Equals(_unknownFields, other._unknownFields);
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public override int GetHashCode() {
- int hash = 1;
- hash ^= shape_.GetHashCode();
- if (CompressionType != 0) hash ^= CompressionType.GetHashCode();
- if (Data.Length != 0) hash ^= Data.GetHashCode();
- if (_unknownFields != null) {
- hash ^= _unknownFields.GetHashCode();
- }
- return hash;
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public override string ToString() {
- return pb::JsonFormatter.ToDiagnosticString(this);
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public void WriteTo(pb::CodedOutputStream output) {
- shape_.WriteTo(output, _repeated_shape_codec);
- if (CompressionType != 0) {
- output.WriteRawTag(16);
- output.WriteEnum((int) CompressionType);
- }
- if (Data.Length != 0) {
- output.WriteRawTag(26);
- output.WriteBytes(Data);
- }
- if (_unknownFields != null) {
- _unknownFields.WriteTo(output);
- }
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public int CalculateSize() {
- int size = 0;
- size += shape_.CalculateSize(_repeated_shape_codec);
- if (CompressionType != 0) {
- size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) CompressionType);
- }
- if (Data.Length != 0) {
- size += 1 + pb::CodedOutputStream.ComputeBytesSize(Data);
- }
- if (_unknownFields != null) {
- size += _unknownFields.CalculateSize();
- }
- return size;
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public void MergeFrom(CompressedObservationProto other) {
- if (other == null) {
- return;
- }
- shape_.Add(other.shape_);
- if (other.CompressionType != 0) {
- CompressionType = other.CompressionType;
- }
- if (other.Data.Length != 0) {
- Data = other.Data;
- }
- _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public void MergeFrom(pb::CodedInputStream input) {
- uint tag;
- while ((tag = input.ReadTag()) != 0) {
- switch(tag) {
- default:
- _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
- break;
- case 10:
- case 8: {
- shape_.AddEntriesFrom(input, _repeated_shape_codec);
- break;
- }
- case 16: {
- compressionType_ = (global::MLAgents.CommunicatorObjects.CompressionTypeProto) input.ReadEnum();
- break;
- }
- case 26: {
- Data = input.ReadBytes();
- break;
- }
- }
- }
- }
-
- }
-
- #endregion
-
-}
-
-#endregion Designer generated code
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/CompressedObservation.cs.meta b/UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/CompressedObservation.cs.meta
deleted file mode 100644
index 8bb01e7651..0000000000
--- a/UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/CompressedObservation.cs.meta
+++ /dev/null
@@ -1,11 +0,0 @@
-fileFormatVersion: 2
-guid: 55ac40ee8d5b74b9e80d3def9d4ef6e0
-MonoImporter:
- externalObjects: {}
- serializedVersion: 2
- defaultReferences: []
- executionOrder: 0
- icon: {instanceID: 0}
- userData:
- assetBundleName:
- assetBundleVariant:
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/CustomAction.cs b/UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/CustomAction.cs
deleted file mode 100644
index fe98b8d171..0000000000
--- a/UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/CustomAction.cs
+++ /dev/null
@@ -1,146 +0,0 @@
-//
-// Generated by the protocol buffer compiler. DO NOT EDIT!
-// source: mlagents/envs/communicator_objects/custom_action.proto
-//
-#pragma warning disable 1591, 0612, 3021
-#region Designer generated code
-
-using pb = global::Google.Protobuf;
-using pbc = global::Google.Protobuf.Collections;
-using pbr = global::Google.Protobuf.Reflection;
-using scg = global::System.Collections.Generic;
-namespace MLAgents.CommunicatorObjects {
-
- /// Holder for reflection information generated from mlagents/envs/communicator_objects/custom_action.proto
- public static partial class CustomActionReflection {
-
- #region Descriptor
- /// File descriptor for mlagents/envs/communicator_objects/custom_action.proto
- public static pbr::FileDescriptor Descriptor {
- get { return descriptor; }
- }
- private static pbr::FileDescriptor descriptor;
-
- static CustomActionReflection() {
- byte[] descriptorData = global::System.Convert.FromBase64String(
- string.Concat(
- "CjZtbGFnZW50cy9lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2N1c3RvbV9h",
- "Y3Rpb24ucHJvdG8SFGNvbW11bmljYXRvcl9vYmplY3RzIhMKEUN1c3RvbUFj",
- "dGlvblByb3RvQh+qAhxNTEFnZW50cy5Db21tdW5pY2F0b3JPYmplY3RzYgZw",
- "cm90bzM="));
- descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
- new pbr::FileDescriptor[] { },
- new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
- new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.CustomActionProto), global::MLAgents.CommunicatorObjects.CustomActionProto.Parser, null, null, null, null)
- }));
- }
- #endregion
-
- }
- #region Messages
- public sealed partial class CustomActionProto : pb::IMessage {
- private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new CustomActionProto());
- private pb::UnknownFieldSet _unknownFields;
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public static pb::MessageParser Parser { get { return _parser; } }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public static pbr::MessageDescriptor Descriptor {
- get { return global::MLAgents.CommunicatorObjects.CustomActionReflection.Descriptor.MessageTypes[0]; }
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- pbr::MessageDescriptor pb::IMessage.Descriptor {
- get { return Descriptor; }
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public CustomActionProto() {
- OnConstruction();
- }
-
- partial void OnConstruction();
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public CustomActionProto(CustomActionProto other) : this() {
- _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public CustomActionProto Clone() {
- return new CustomActionProto(this);
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public override bool Equals(object other) {
- return Equals(other as CustomActionProto);
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public bool Equals(CustomActionProto other) {
- if (ReferenceEquals(other, null)) {
- return false;
- }
- if (ReferenceEquals(other, this)) {
- return true;
- }
- return Equals(_unknownFields, other._unknownFields);
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public override int GetHashCode() {
- int hash = 1;
- if (_unknownFields != null) {
- hash ^= _unknownFields.GetHashCode();
- }
- return hash;
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public override string ToString() {
- return pb::JsonFormatter.ToDiagnosticString(this);
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public void WriteTo(pb::CodedOutputStream output) {
- if (_unknownFields != null) {
- _unknownFields.WriteTo(output);
- }
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public int CalculateSize() {
- int size = 0;
- if (_unknownFields != null) {
- size += _unknownFields.CalculateSize();
- }
- return size;
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public void MergeFrom(CustomActionProto other) {
- if (other == null) {
- return;
- }
- _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public void MergeFrom(pb::CodedInputStream input) {
- uint tag;
- while ((tag = input.ReadTag()) != 0) {
- switch(tag) {
- default:
- _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
- break;
- }
- }
- }
-
- }
-
- #endregion
-
-}
-
-#endregion Designer generated code
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/CustomAction.cs.meta b/UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/CustomAction.cs.meta
deleted file mode 100644
index 3c1bc85d1d..0000000000
--- a/UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/CustomAction.cs.meta
+++ /dev/null
@@ -1,11 +0,0 @@
-fileFormatVersion: 2
-guid: cc39771cc6e944eaaafb44e2da960a65
-MonoImporter:
- externalObjects: {}
- serializedVersion: 2
- defaultReferences: []
- executionOrder: 0
- icon: {instanceID: 0}
- userData:
- assetBundleName:
- assetBundleVariant:
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/CustomObservation.cs b/UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/CustomObservation.cs
deleted file mode 100644
index 05770841b6..0000000000
--- a/UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/CustomObservation.cs
+++ /dev/null
@@ -1,146 +0,0 @@
-//
-// Generated by the protocol buffer compiler. DO NOT EDIT!
-// source: mlagents/envs/communicator_objects/custom_observation.proto
-//
-#pragma warning disable 1591, 0612, 3021
-#region Designer generated code
-
-using pb = global::Google.Protobuf;
-using pbc = global::Google.Protobuf.Collections;
-using pbr = global::Google.Protobuf.Reflection;
-using scg = global::System.Collections.Generic;
-namespace MLAgents.CommunicatorObjects {
-
- /// Holder for reflection information generated from mlagents/envs/communicator_objects/custom_observation.proto
- public static partial class CustomObservationReflection {
-
- #region Descriptor
- /// File descriptor for mlagents/envs/communicator_objects/custom_observation.proto
- public static pbr::FileDescriptor Descriptor {
- get { return descriptor; }
- }
- private static pbr::FileDescriptor descriptor;
-
- static CustomObservationReflection() {
- byte[] descriptorData = global::System.Convert.FromBase64String(
- string.Concat(
- "CjttbGFnZW50cy9lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2N1c3RvbV9v",
- "YnNlcnZhdGlvbi5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMiGAoWQ3Vz",
- "dG9tT2JzZXJ2YXRpb25Qcm90b0IfqgIcTUxBZ2VudHMuQ29tbXVuaWNhdG9y",
- "T2JqZWN0c2IGcHJvdG8z"));
- descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
- new pbr::FileDescriptor[] { },
- new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
- new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.CustomObservationProto), global::MLAgents.CommunicatorObjects.CustomObservationProto.Parser, null, null, null, null)
- }));
- }
- #endregion
-
- }
- #region Messages
- public sealed partial class CustomObservationProto : pb::IMessage {
- private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new CustomObservationProto());
- private pb::UnknownFieldSet _unknownFields;
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public static pb::MessageParser Parser { get { return _parser; } }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public static pbr::MessageDescriptor Descriptor {
- get { return global::MLAgents.CommunicatorObjects.CustomObservationReflection.Descriptor.MessageTypes[0]; }
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- pbr::MessageDescriptor pb::IMessage.Descriptor {
- get { return Descriptor; }
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public CustomObservationProto() {
- OnConstruction();
- }
-
- partial void OnConstruction();
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public CustomObservationProto(CustomObservationProto other) : this() {
- _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public CustomObservationProto Clone() {
- return new CustomObservationProto(this);
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public override bool Equals(object other) {
- return Equals(other as CustomObservationProto);
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public bool Equals(CustomObservationProto other) {
- if (ReferenceEquals(other, null)) {
- return false;
- }
- if (ReferenceEquals(other, this)) {
- return true;
- }
- return Equals(_unknownFields, other._unknownFields);
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public override int GetHashCode() {
- int hash = 1;
- if (_unknownFields != null) {
- hash ^= _unknownFields.GetHashCode();
- }
- return hash;
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public override string ToString() {
- return pb::JsonFormatter.ToDiagnosticString(this);
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public void WriteTo(pb::CodedOutputStream output) {
- if (_unknownFields != null) {
- _unknownFields.WriteTo(output);
- }
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public int CalculateSize() {
- int size = 0;
- if (_unknownFields != null) {
- size += _unknownFields.CalculateSize();
- }
- return size;
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public void MergeFrom(CustomObservationProto other) {
- if (other == null) {
- return;
- }
- _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
- }
-
- [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
- public void MergeFrom(pb::CodedInputStream input) {
- uint tag;
- while ((tag = input.ReadTag()) != 0) {
- switch(tag) {
- default:
- _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
- break;
- }
- }
- }
-
- }
-
- #endregion
-
-}
-
-#endregion Designer generated code
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/CustomObservation.cs.meta b/UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/CustomObservation.cs.meta
deleted file mode 100644
index d0dc127a0a..0000000000
--- a/UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/CustomObservation.cs.meta
+++ /dev/null
@@ -1,11 +0,0 @@
-fileFormatVersion: 2
-guid: 186aa820efd71454db6e4cb7b883dce5
-MonoImporter:
- externalObjects: {}
- serializedVersion: 2
- defaultReferences: []
- executionOrder: 0
- icon: {instanceID: 0}
- userData:
- assetBundleName:
- assetBundleVariant:
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/Observation.cs b/UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/Observation.cs
new file mode 100644
index 0000000000..97ad351400
--- /dev/null
+++ b/UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/Observation.cs
@@ -0,0 +1,433 @@
+//
+// Generated by the protocol buffer compiler. DO NOT EDIT!
+// source: mlagents/envs/communicator_objects/observation.proto
+//
+#pragma warning disable 1591, 0612, 3021
+#region Designer generated code
+
+using pb = global::Google.Protobuf;
+using pbc = global::Google.Protobuf.Collections;
+using pbr = global::Google.Protobuf.Reflection;
+using scg = global::System.Collections.Generic;
+namespace MLAgents.CommunicatorObjects {
+
+ /// Holder for reflection information generated from mlagents/envs/communicator_objects/observation.proto
+ public static partial class ObservationReflection {
+
+ #region Descriptor
+ /// File descriptor for mlagents/envs/communicator_objects/observation.proto
+ public static pbr::FileDescriptor Descriptor {
+ get { return descriptor; }
+ }
+ private static pbr::FileDescriptor descriptor;
+
+ static ObservationReflection() {
+ byte[] descriptorData = global::System.Convert.FromBase64String(
+ string.Concat(
+ "CjRtbGFnZW50cy9lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL29ic2VydmF0",
+ "aW9uLnByb3RvEhRjb21tdW5pY2F0b3Jfb2JqZWN0cyL5AQoQT2JzZXJ2YXRp",
+ "b25Qcm90bxINCgVzaGFwZRgBIAMoBRJEChBjb21wcmVzc2lvbl90eXBlGAIg",
+ "ASgOMiouY29tbXVuaWNhdG9yX29iamVjdHMuQ29tcHJlc3Npb25UeXBlUHJv",
+ "dG8SGQoPY29tcHJlc3NlZF9kYXRhGAMgASgMSAASRgoKZmxvYXRfZGF0YRgE",
+ "IAEoCzIwLmNvbW11bmljYXRvcl9vYmplY3RzLk9ic2VydmF0aW9uUHJvdG8u",
+ "RmxvYXREYXRhSAAaGQoJRmxvYXREYXRhEgwKBGRhdGEYASADKAJCEgoQb2Jz",
+ "ZXJ2YXRpb25fZGF0YSopChRDb21wcmVzc2lvblR5cGVQcm90bxIICgROT05F",
+ "EAASBwoDUE5HEAFCH6oCHE1MQWdlbnRzLkNvbW11bmljYXRvck9iamVjdHNi",
+ "BnByb3RvMw=="));
+ descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
+ new pbr::FileDescriptor[] { },
+ new pbr::GeneratedClrTypeInfo(new[] {typeof(global::MLAgents.CommunicatorObjects.CompressionTypeProto), }, new pbr::GeneratedClrTypeInfo[] {
+ new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.ObservationProto), global::MLAgents.CommunicatorObjects.ObservationProto.Parser, new[]{ "Shape", "CompressionType", "CompressedData", "FloatData" }, new[]{ "ObservationData" }, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.ObservationProto.Types.FloatData), global::MLAgents.CommunicatorObjects.ObservationProto.Types.FloatData.Parser, new[]{ "Data" }, null, null, null)})
+ }));
+ }
+ #endregion
+
+ }
+ #region Enums
+ public enum CompressionTypeProto {
+ [pbr::OriginalName("NONE")] None = 0,
+ [pbr::OriginalName("PNG")] Png = 1,
+ }
+
+ #endregion
+
+ #region Messages
+ public sealed partial class ObservationProto : pb::IMessage<ObservationProto> {
+ private static readonly pb::MessageParser<ObservationProto> _parser = new pb::MessageParser<ObservationProto>(() => new ObservationProto());
+ private pb::UnknownFieldSet _unknownFields;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pb::MessageParser<ObservationProto> Parser { get { return _parser; } }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pbr::MessageDescriptor Descriptor {
+ get { return global::MLAgents.CommunicatorObjects.ObservationReflection.Descriptor.MessageTypes[0]; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ pbr::MessageDescriptor pb::IMessage.Descriptor {
+ get { return Descriptor; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public ObservationProto() {
+ OnConstruction();
+ }
+
+ partial void OnConstruction();
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public ObservationProto(ObservationProto other) : this() {
+ shape_ = other.shape_.Clone();
+ compressionType_ = other.compressionType_;
+ switch (other.ObservationDataCase) {
+ case ObservationDataOneofCase.CompressedData:
+ CompressedData = other.CompressedData;
+ break;
+ case ObservationDataOneofCase.FloatData:
+ FloatData = other.FloatData.Clone();
+ break;
+ }
+
+ _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public ObservationProto Clone() {
+ return new ObservationProto(this);
+ }
+
+ /// Field number for the "shape" field.
+ public const int ShapeFieldNumber = 1;
+ private static readonly pb::FieldCodec<int> _repeated_shape_codec
+ = pb::FieldCodec.ForInt32(10);
+ private readonly pbc::RepeatedField<int> shape_ = new pbc::RepeatedField<int>();
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public pbc::RepeatedField<int> Shape {
+ get { return shape_; }
+ }
+
+ /// Field number for the "compression_type" field.
+ public const int CompressionTypeFieldNumber = 2;
+ private global::MLAgents.CommunicatorObjects.CompressionTypeProto compressionType_ = 0;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public global::MLAgents.CommunicatorObjects.CompressionTypeProto CompressionType {
+ get { return compressionType_; }
+ set {
+ compressionType_ = value;
+ }
+ }
+
+ /// Field number for the "compressed_data" field.
+ public const int CompressedDataFieldNumber = 3;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public pb::ByteString CompressedData {
+ get { return observationDataCase_ == ObservationDataOneofCase.CompressedData ? (pb::ByteString) observationData_ : pb::ByteString.Empty; }
+ set {
+ observationData_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
+ observationDataCase_ = ObservationDataOneofCase.CompressedData;
+ }
+ }
+
+ /// Field number for the "float_data" field.
+ public const int FloatDataFieldNumber = 4;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public global::MLAgents.CommunicatorObjects.ObservationProto.Types.FloatData FloatData {
+ get { return observationDataCase_ == ObservationDataOneofCase.FloatData ? (global::MLAgents.CommunicatorObjects.ObservationProto.Types.FloatData) observationData_ : null; }
+ set {
+ observationData_ = value;
+ observationDataCase_ = value == null ? ObservationDataOneofCase.None : ObservationDataOneofCase.FloatData;
+ }
+ }
+
+ private object observationData_;
+ /// Enum of possible cases for the "observation_data" oneof.
+ public enum ObservationDataOneofCase {
+ None = 0,
+ CompressedData = 3,
+ FloatData = 4,
+ }
+ private ObservationDataOneofCase observationDataCase_ = ObservationDataOneofCase.None;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public ObservationDataOneofCase ObservationDataCase {
+ get { return observationDataCase_; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void ClearObservationData() {
+ observationDataCase_ = ObservationDataOneofCase.None;
+ observationData_ = null;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override bool Equals(object other) {
+ return Equals(other as ObservationProto);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool Equals(ObservationProto other) {
+ if (ReferenceEquals(other, null)) {
+ return false;
+ }
+ if (ReferenceEquals(other, this)) {
+ return true;
+ }
+ if(!shape_.Equals(other.shape_)) return false;
+ if (CompressionType != other.CompressionType) return false;
+ if (CompressedData != other.CompressedData) return false;
+ if (!object.Equals(FloatData, other.FloatData)) return false;
+ if (ObservationDataCase != other.ObservationDataCase) return false;
+ return Equals(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override int GetHashCode() {
+ int hash = 1;
+ hash ^= shape_.GetHashCode();
+ if (CompressionType != 0) hash ^= CompressionType.GetHashCode();
+ if (observationDataCase_ == ObservationDataOneofCase.CompressedData) hash ^= CompressedData.GetHashCode();
+ if (observationDataCase_ == ObservationDataOneofCase.FloatData) hash ^= FloatData.GetHashCode();
+ hash ^= (int) observationDataCase_;
+ if (_unknownFields != null) {
+ hash ^= _unknownFields.GetHashCode();
+ }
+ return hash;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override string ToString() {
+ return pb::JsonFormatter.ToDiagnosticString(this);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void WriteTo(pb::CodedOutputStream output) {
+ shape_.WriteTo(output, _repeated_shape_codec);
+ if (CompressionType != 0) {
+ output.WriteRawTag(16);
+ output.WriteEnum((int) CompressionType);
+ }
+ if (observationDataCase_ == ObservationDataOneofCase.CompressedData) {
+ output.WriteRawTag(26);
+ output.WriteBytes(CompressedData);
+ }
+ if (observationDataCase_ == ObservationDataOneofCase.FloatData) {
+ output.WriteRawTag(34);
+ output.WriteMessage(FloatData);
+ }
+ if (_unknownFields != null) {
+ _unknownFields.WriteTo(output);
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int CalculateSize() {
+ int size = 0;
+ size += shape_.CalculateSize(_repeated_shape_codec);
+ if (CompressionType != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) CompressionType);
+ }
+ if (observationDataCase_ == ObservationDataOneofCase.CompressedData) {
+ size += 1 + pb::CodedOutputStream.ComputeBytesSize(CompressedData);
+ }
+ if (observationDataCase_ == ObservationDataOneofCase.FloatData) {
+ size += 1 + pb::CodedOutputStream.ComputeMessageSize(FloatData);
+ }
+ if (_unknownFields != null) {
+ size += _unknownFields.CalculateSize();
+ }
+ return size;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(ObservationProto other) {
+ if (other == null) {
+ return;
+ }
+ shape_.Add(other.shape_);
+ if (other.CompressionType != 0) {
+ CompressionType = other.CompressionType;
+ }
+ switch (other.ObservationDataCase) {
+ case ObservationDataOneofCase.CompressedData:
+ CompressedData = other.CompressedData;
+ break;
+ case ObservationDataOneofCase.FloatData:
+ if (FloatData == null) {
+ FloatData = new global::MLAgents.CommunicatorObjects.ObservationProto.Types.FloatData();
+ }
+ FloatData.MergeFrom(other.FloatData);
+ break;
+ }
+
+ _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(pb::CodedInputStream input) {
+ uint tag;
+ while ((tag = input.ReadTag()) != 0) {
+ switch(tag) {
+ default:
+ _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
+ break;
+ case 10:
+ case 8: {
+ shape_.AddEntriesFrom(input, _repeated_shape_codec);
+ break;
+ }
+ case 16: {
+ compressionType_ = (global::MLAgents.CommunicatorObjects.CompressionTypeProto) input.ReadEnum();
+ break;
+ }
+ case 26: {
+ CompressedData = input.ReadBytes();
+ break;
+ }
+ case 34: {
+ global::MLAgents.CommunicatorObjects.ObservationProto.Types.FloatData subBuilder = new global::MLAgents.CommunicatorObjects.ObservationProto.Types.FloatData();
+ if (observationDataCase_ == ObservationDataOneofCase.FloatData) {
+ subBuilder.MergeFrom(FloatData);
+ }
+ input.ReadMessage(subBuilder);
+ FloatData = subBuilder;
+ break;
+ }
+ }
+ }
+ }
+
+ #region Nested types
+ /// Container for nested types declared in the ObservationProto message type.
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static partial class Types {
+ public sealed partial class FloatData : pb::IMessage<FloatData> {
+ private static readonly pb::MessageParser<FloatData> _parser = new pb::MessageParser<FloatData>(() => new FloatData());
+ private pb::UnknownFieldSet _unknownFields;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pb::MessageParser<FloatData> Parser { get { return _parser; } }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pbr::MessageDescriptor Descriptor {
+ get { return global::MLAgents.CommunicatorObjects.ObservationProto.Descriptor.NestedTypes[0]; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ pbr::MessageDescriptor pb::IMessage.Descriptor {
+ get { return Descriptor; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public FloatData() {
+ OnConstruction();
+ }
+
+ partial void OnConstruction();
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public FloatData(FloatData other) : this() {
+ data_ = other.data_.Clone();
+ _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public FloatData Clone() {
+ return new FloatData(this);
+ }
+
+ /// Field number for the "data" field.
+ public const int DataFieldNumber = 1;
+ private static readonly pb::FieldCodec<float> _repeated_data_codec
+ = pb::FieldCodec.ForFloat(10);
+ private readonly pbc::RepeatedField<float> data_ = new pbc::RepeatedField<float>();
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public pbc::RepeatedField<float> Data {
+ get { return data_; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override bool Equals(object other) {
+ return Equals(other as FloatData);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool Equals(FloatData other) {
+ if (ReferenceEquals(other, null)) {
+ return false;
+ }
+ if (ReferenceEquals(other, this)) {
+ return true;
+ }
+ if(!data_.Equals(other.data_)) return false;
+ return Equals(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override int GetHashCode() {
+ int hash = 1;
+ hash ^= data_.GetHashCode();
+ if (_unknownFields != null) {
+ hash ^= _unknownFields.GetHashCode();
+ }
+ return hash;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override string ToString() {
+ return pb::JsonFormatter.ToDiagnosticString(this);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void WriteTo(pb::CodedOutputStream output) {
+ data_.WriteTo(output, _repeated_data_codec);
+ if (_unknownFields != null) {
+ _unknownFields.WriteTo(output);
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int CalculateSize() {
+ int size = 0;
+ size += data_.CalculateSize(_repeated_data_codec);
+ if (_unknownFields != null) {
+ size += _unknownFields.CalculateSize();
+ }
+ return size;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(FloatData other) {
+ if (other == null) {
+ return;
+ }
+ data_.Add(other.data_);
+ _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(pb::CodedInputStream input) {
+ uint tag;
+ while ((tag = input.ReadTag()) != 0) {
+ switch(tag) {
+ default:
+ _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
+ break;
+ case 10:
+ case 13: {
+ data_.AddEntriesFrom(input, _repeated_data_codec);
+ break;
+ }
+ }
+ }
+ }
+
+ }
+
+ }
+ #endregion
+
+ }
+
+ #endregion
+
+}
+
+#endregion Designer generated code
diff --git a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/iOS/iOSBLAS.cs.meta b/UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/Observation.cs.meta
similarity index 83%
rename from UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/iOS/iOSBLAS.cs.meta
rename to UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/Observation.cs.meta
index 9304817b25..971fead69c 100644
--- a/UnitySDK/Assets/ML-Agents/Plugins/Barracuda.Core/Barracuda/Plugins/iOS/iOSBLAS.cs.meta
+++ b/UnitySDK/Assets/ML-Agents/Scripts/Grpc/CommunicatorObjects/Observation.cs.meta
@@ -1,5 +1,5 @@
fileFormatVersion: 2
-guid: 75424b0c6afc14ea7a1debef68240d9e
+guid: 9fbba5f80821d4f02b4239a8e16eebfa
MonoImporter:
externalObjects: {}
serializedVersion: 2
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/Grpc/GrpcExtensions.cs b/UnitySDK/Assets/ML-Agents/Scripts/Grpc/GrpcExtensions.cs
index cc6cf9534c..046da04657 100644
--- a/UnitySDK/Assets/ML-Agents/Scripts/Grpc/GrpcExtensions.cs
+++ b/UnitySDK/Assets/ML-Agents/Scripts/Grpc/GrpcExtensions.cs
@@ -11,39 +11,51 @@ namespace MLAgents
{
public static class GrpcExtensions
{
+
+ /// <summary>
+ /// Converts an AgentInfo to a protobuf-generated AgentInfoActionPairProto.
+ /// </summary>
+ /// <returns>The protobuf version of the AgentInfoActionPairProto.</returns>
+ public static AgentInfoActionPairProto ToInfoActionPairProto(this AgentInfo ai)
+ {
+ var agentInfoProto = ai.ToAgentInfoProto();
+
+ var agentActionProto = new AgentActionProto
+ {
+ VectorActions = { ai.storedVectorActions }
+ };
+
+ return new AgentInfoActionPairProto
+ {
+ AgentInfo = agentInfoProto,
+ ActionInfo = agentActionProto
+ };
+ }
+
/// <summary>
/// Converts an AgentInfo to a protobuf-generated AgentInfoProto.
/// </summary>
/// <returns>The protobuf version of the AgentInfo.</returns>
- public static AgentInfoProto ToProto(this AgentInfo ai)
+ public static AgentInfoProto ToAgentInfoProto(this AgentInfo ai)
{
var agentInfoProto = new AgentInfoProto
{
- StackedVectorObservation = { ai.stackedVectorObservation },
- StoredVectorActions = { ai.storedVectorActions },
- StoredTextActions = ai.storedTextActions,
- TextObservation = ai.textObservation,
Reward = ai.reward,
MaxStepReached = ai.maxStepReached,
Done = ai.done,
Id = ai.id,
- CustomObservation = ai.customObservation
};
- if (ai.memories != null)
- {
- agentInfoProto.Memories.Add(ai.memories);
- }
if (ai.actionMasks != null)
{
agentInfoProto.ActionMask.AddRange(ai.actionMasks);
}
- if (ai.compressedObservations != null)
+ if (ai.observations != null)
{
- foreach (var obs in ai.compressedObservations)
+ foreach (var obs in ai.observations)
{
- agentInfoProto.CompressedObservations.Add(obs.ToProto());
+ agentInfoProto.Observations.Add(obs.ToProto());
}
}
@@ -61,8 +73,6 @@ public static BrainParametersProto ToProto(this BrainParameters bp, string name,
{
var brainParametersProto = new BrainParametersProto
{
- VectorObservationSize = bp.vectorObservationSize,
- NumStackedVectorObservations = bp.numStackedVectorObservations,
VectorActionSize = { bp.vectorActionSize },
VectorActionSpaceType =
(SpaceTypeProto)bp.vectorActionSpaceType,
@@ -117,8 +127,6 @@ public static BrainParameters ToBrainParameters(this BrainParametersProto bpp)
{
var bp = new BrainParameters
{
- vectorObservationSize = bpp.VectorObservationSize,
- numStackedVectorObservations = bpp.NumStackedVectorObservations,
vectorActionSize = bpp.VectorActionSize.ToArray(),
vectorActionDescriptions = bpp.VectorActionDescriptions.ToArray(),
vectorActionSpaceType = (SpaceType)bpp.VectorActionSpaceType
@@ -163,10 +171,7 @@ public static AgentAction ToAgentAction(this AgentActionProto aap)
return new AgentAction
{
vectorActions = aap.VectorActions.ToArray(),
- textActions = aap.TextActions,
- memories = aap.Memories.ToList(),
value = aap.Value,
- customAction = aap.CustomAction
};
}
@@ -180,13 +185,38 @@ public static List ToAgentActionList(this UnityRLInputProto.Types.L
return agentActions;
}
- public static CompressedObservationProto ToProto(this CompressedObservation obs)
+ public static ObservationProto ToProto(this Observation obs)
{
- var obsProto = new CompressedObservationProto
+ ObservationProto obsProto = null;
+
+ if (obs.CompressedData != null)
{
- Data = ByteString.CopyFrom(obs.Data),
- CompressionType = (CompressionTypeProto) obs.CompressionType,
- };
+ // Make sure that uncompressed data is empty
+ if (obs.FloatData.Count != 0)
+ {
+ Debug.LogWarning("Observation has both compressed and uncompressed data set. Using compressed.");
+ }
+
+ obsProto = new ObservationProto
+ {
+ CompressedData = ByteString.CopyFrom(obs.CompressedData),
+ CompressionType = (CompressionTypeProto)obs.CompressionType,
+ };
+ }
+ else
+ {
+ var floatDataProto = new ObservationProto.Types.FloatData
+ {
+ Data = { obs.FloatData },
+ };
+
+ obsProto = new ObservationProto
+ {
+ FloatData = floatDataProto,
+ CompressionType = (CompressionTypeProto)obs.CompressionType,
+ };
+ }
+
obsProto.Shape.AddRange(obs.Shape);
return obsProto;
}
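
Because ObservationProto keeps its payload in a protobuf oneof (observation_data), ToProto above sets exactly one of CompressedData or FloatData. A minimal sketch, assuming nothing beyond the generated accessors in Observation.cs, of how a receiver might branch on the active case; the helper is illustrative only:

using MLAgents.CommunicatorObjects;
using UnityEngine;

static class ObservationReaderSketch
{
    // Logs which branch of the observation_data oneof is populated.
    public static void Describe(ObservationProto obs)
    {
        switch (obs.ObservationDataCase)
        {
            case ObservationProto.ObservationDataOneofCase.CompressedData:
                // e.g. a PNG-encoded visual observation
                Debug.Log($"compressed ({obs.CompressionType}), {obs.CompressedData.Length} bytes");
                break;
            case ObservationProto.ObservationDataOneofCase.FloatData:
                Debug.Log($"uncompressed, {obs.FloatData.Data.Count} floats");
                break;
            default:
                Debug.Log("no observation data set");
                break;
        }
    }
}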
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/Grpc/RpcCommunicator.cs b/UnitySDK/Assets/ML-Agents/Scripts/Grpc/RpcCommunicator.cs
index 0beddc8356..be89c885e9 100644
--- a/UnitySDK/Assets/ML-Agents/Scripts/Grpc/RpcCommunicator.cs
+++ b/UnitySDK/Assets/ML-Agents/Scripts/Grpc/RpcCommunicator.cs
@@ -23,7 +23,7 @@ public class RpcCommunicator : ICommunicator
bool m_IsOpen;
/// The default number of agents in the scene
- private const int k_NumAgents = 32;
+ const int k_NumAgents = 32;
/// Keeps track of the agents of each brain on the current step
Dictionary<string, List<Agent>> m_CurrentAgents =
@@ -37,8 +37,8 @@ public class RpcCommunicator : ICommunicator
new Dictionary<string, List<Agent>>();
// Brains that we have sent over the communicator with agents.
- HashSet<string> m_sentBrainKeys = new HashSet<string>();
- Dictionary<string, BrainParameters> m_unsentBrainKeys = new Dictionary<string, BrainParameters>();
+ HashSet<string> m_SentBrainKeys = new HashSet<string>();
+ Dictionary<string, BrainParameters> m_UnsentBrainKeys = new Dictionary<string, BrainParameters>();
# if UNITY_EDITOR || UNITY_STANDALONE_WIN || UNITY_STANDALONE_OSX || UNITY_STANDALONE_LINUX
@@ -138,7 +138,7 @@ void UpdateEnvironmentWithInput(UnityRLInputProto rlInput)
SendCommandEvent(rlInput.Command, rlInput.EnvironmentParameters);
}
- private UnityInputProto Initialize(UnityOutputProto unityOutput,
+ UnityInputProto Initialize(UnityOutputProto unityOutput,
out UnityInputProto unityInput)
{
# if UNITY_EDITOR || UNITY_STANDALONE_WIN || UNITY_STANDALONE_OSX || UNITY_STANDALONE_LINUX
@@ -197,7 +197,8 @@ public void Dispose()
#endregion
#region Sending Events
- private void SendCommandEvent(CommandProto command, EnvironmentParametersProto environmentParametersProto)
+
+ void SendCommandEvent(CommandProto command, EnvironmentParametersProto environmentParametersProto)
{
switch (command)
{
@@ -218,7 +219,7 @@ private void SendCommandEvent(CommandProto command, EnvironmentParametersProto e
}
}
- private void SendRLInputReceivedEvent(bool isTraining)
+ void SendRLInputReceivedEvent(bool isTraining)
{
RLInputReceived?.Invoke(new UnityRLInputParameters { isTraining = isTraining });
}
@@ -243,7 +244,7 @@ public void DecideBatch()
{
// Update the sensor data on the AgentInfo
agent.GenerateSensorData();
- var agentInfoProto = agent.Info.ToProto();
+ var agentInfoProto = agent.Info.ToAgentInfoProto();
m_CurrentUnityRlOutput.AgentInfos[brainKey].Value.Add(agentInfoProto);
}
@@ -260,8 +261,8 @@ public void DecideBatch()
///
/// Sends the observations of one Agent.
///
- /// Batch Key.
- /// Agent info.
+ /// Batch Key.
+ /// Agent info.
public void PutObservations(string brainKey, Agent agent)
{
m_CurrentAgents[brainKey].Add(agent);
@@ -337,7 +338,7 @@ public Dictionary GetActions(string key)
///
/// The next UnityInput.
/// The UnityOutput to be sent.
- private UnityInputProto Exchange(UnityOutputProto unityOutput)
+ UnityInputProto Exchange(UnityOutputProto unityOutput)
{
# if UNITY_EDITOR || UNITY_STANDALONE_WIN || UNITY_STANDALONE_OSX || UNITY_STANDALONE_LINUX
if (!m_IsOpen)
@@ -377,7 +378,7 @@ private UnityInputProto Exchange(UnityOutputProto unityOutput)
/// The UnityMessage corresponding.
/// The UnityOutput to be wrapped.
/// The status of the message.
- private static UnityMessageProto WrapMessage(UnityOutputProto content, int status)
+ static UnityMessageProto WrapMessage(UnityOutputProto content, int status)
{
return new UnityMessageProto
{
@@ -386,21 +387,21 @@ private static UnityMessageProto WrapMessage(UnityOutputProto content, int statu
};
}
- private void CacheBrainParameters(string brainKey, BrainParameters brainParameters)
+ void CacheBrainParameters(string brainKey, BrainParameters brainParameters)
{
- if (m_sentBrainKeys.Contains(brainKey))
+ if (m_SentBrainKeys.Contains(brainKey))
{
return;
}
// TODO We should check that if m_unsentBrainKeys has brainKey, it equals brainParameters
- m_unsentBrainKeys[brainKey] = brainParameters;
+ m_UnsentBrainKeys[brainKey] = brainParameters;
}
- private UnityRLInitializationOutputProto GetTempUnityRlInitializationOutput()
+ UnityRLInitializationOutputProto GetTempUnityRlInitializationOutput()
{
UnityRLInitializationOutputProto output = null;
- foreach (var brainKey in m_unsentBrainKeys.Keys)
+ foreach (var brainKey in m_UnsentBrainKeys.Keys)
{
if (m_CurrentUnityRlOutput.AgentInfos.ContainsKey(brainKey))
{
@@ -409,7 +410,7 @@ private UnityRLInitializationOutputProto GetTempUnityRlInitializationOutput()
output = new UnityRLInitializationOutputProto();
}
- var brainParameters = m_unsentBrainKeys[brainKey];
+ var brainParameters = m_UnsentBrainKeys[brainKey];
output.BrainParameters.Add(brainParameters.ToProto(brainKey, true));
}
}
@@ -417,7 +418,7 @@ private UnityRLInitializationOutputProto GetTempUnityRlInitializationOutput()
return output;
}
- private void UpdateSentBrainParameters(UnityRLInitializationOutputProto output)
+ void UpdateSentBrainParameters(UnityRLInitializationOutputProto output)
{
if (output == null)
{
@@ -426,8 +427,8 @@ private void UpdateSentBrainParameters(UnityRLInitializationOutputProto output)
foreach (var brainProto in output.BrainParameters)
{
- m_sentBrainKeys.Add(brainProto.BrainName);
- m_unsentBrainKeys.Remove(brainProto.BrainName);
+ m_SentBrainKeys.Add(brainProto.BrainName);
+ m_UnsentBrainKeys.Remove(brainProto.BrainName);
}
}
@@ -439,7 +440,7 @@ private void UpdateSentBrainParameters(UnityRLInitializationOutputProto output)
/// When the editor exits, the communicator must be closed
///
/// State.
- private void HandleOnPlayModeChanged(PlayModeStateChange state)
+ void HandleOnPlayModeChanged(PlayModeStateChange state)
{
// This method is run whenever the playmode state is changed.
if (state == PlayModeStateChange.ExitingPlayMode)
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/ICommunicator.cs b/UnitySDK/Assets/ML-Agents/Scripts/ICommunicator.cs
index 936a29394b..ef85028440 100644
--- a/UnitySDK/Assets/ML-Agents/Scripts/ICommunicator.cs
+++ b/UnitySDK/Assets/ML-Agents/Scripts/ICommunicator.cs
@@ -108,7 +108,7 @@ Since the messages are sent back and forth with exchange and simultaneously when
UnityOutput and UnityInput can be extended to provide functionalities beyond RL
UnityRLOutput and UnityRLInput can be extended to provide new RL functionalities
*/
- public interface ICommunicator : IBatchedDecisionMaker
+ public interface ICommunicator
{
///
/// Quit was received by the communicator.
@@ -141,6 +141,20 @@ public interface ICommunicator : IBatchedDecisionMaker
/// The Parameters for the Brain being registered
void SubscribeBrain(string name, BrainParameters brainParameters);
+ ///
+ /// Sends the observations of one Agent.
+ ///
+ /// Batch Key.
+ /// Agent info.
+ void PutObservations(string brainKey, Agent agent);
+
+ ///
+ /// Signals the ICommunicator that the Agents are now ready to receive their action
+ /// and that if the communicator has not yet received an action for one of the Agents
+ /// it needs to get one at this point.
+ ///
+ void DecideBatch();
+
///
/// Gets the AgentActions based on the batching key.
///
@@ -148,10 +162,4 @@ public interface ICommunicator : IBatchedDecisionMaker
///
Dictionary<Agent, AgentAction> GetActions(string key);
}
-
- public interface IBatchedDecisionMaker : IDisposable
- {
- void PutObservations(string key, Agent agent);
- void DecideBatch();
- }
}
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/ApplierImpl.cs b/UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/ApplierImpl.cs
index ecdb434e20..c3eb38bd37 100644
--- a/UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/ApplierImpl.cs
+++ b/UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/ApplierImpl.cs
@@ -36,9 +36,9 @@ public void Apply(TensorProxy tensorProxy, IEnumerable agents)
///
public class DiscreteActionOutputApplier : TensorApplier.IApplier
{
- private readonly int[] m_ActionSize;
- private readonly Multinomial m_Multinomial;
- private readonly ITensorAllocator m_Allocator;
+ readonly int[] m_ActionSize;
+ readonly Multinomial m_Multinomial;
+ readonly ITensorAllocator m_Allocator;
public DiscreteActionOutputApplier(int[] actionSize, int seed, ITensorAllocator allocator)
{
@@ -60,7 +60,7 @@ public void Apply(TensorProxy tensorProxy, IEnumerable agents)
var actionProbs = new TensorProxy()
{
valueType = TensorProxy.TensorType.FloatingPoint,
- shape = new long[] {batchSize, nBranchAction},
+ shape = new long[] { batchSize, nBranchAction },
data = m_Allocator.Alloc(new TensorShape(batchSize, nBranchAction))
};
@@ -78,7 +78,7 @@ public void Apply(TensorProxy tensorProxy, IEnumerable agents)
var outputTensor = new TensorProxy()
{
valueType = TensorProxy.TensorType.FloatingPoint,
- shape = new long[] {batchSize, 1},
+ shape = new long[] { batchSize, 1 },
data = m_Allocator.Alloc(new TensorShape(batchSize, 1))
};
@@ -169,68 +169,83 @@ public static void Eval(TensorProxy src, TensorProxy dst, Multinomial multinomia
}
}
- public class BarracudaMemoryOutputApplier : TensorApplier.IApplier
+ ///
+ /// The Applier for the Memory output tensor. Tensor is assumed to contain the new
+ /// memory data of the agents in the batch.
+ ///
+ public class MemoryOutputApplier : TensorApplier.IApplier
{
- private readonly int m_MemoriesCount;
- private readonly int m_MemoryIndex;
+ Dictionary<int, List<float>> m_Memories;
- public BarracudaMemoryOutputApplier(int memoriesCount, int memoryIndex)
+ public MemoryOutputApplier(
+ Dictionary<int, List<float>> memories)
{
- m_MemoriesCount = memoriesCount;
- m_MemoryIndex = memoryIndex;
+ m_Memories = memories;
}
-
public void Apply(TensorProxy tensorProxy, IEnumerable<Agent> agents)
{
var agentIndex = 0;
var memorySize = (int)tensorProxy.shape[tensorProxy.shape.Length - 1];
-
foreach (var agent in agents)
{
- var memory = agent.GetMemoriesAction();
-
- if (memory == null || memory.Count < memorySize * m_MemoriesCount)
+ List<float> memory = null;
+ if (!m_Memories.TryGetValue(agent.Info.id, out memory)
+ || memory.Count < memorySize)
{
memory = new List<float>();
- memory.AddRange(Enumerable.Repeat(0f, memorySize * m_MemoriesCount));
- }
-
- for (var j = 0; j < memorySize; j++)
- {
- memory[memorySize * m_MemoryIndex + j] = tensorProxy.data[agentIndex, j];
+ memory.AddRange(Enumerable.Repeat(0f, memorySize));
}
- agent.UpdateMemoriesAction(memory);
-
+ m_Memories[agent.Info.id] = memory;
agentIndex++;
}
}
}
- ///
- /// The Applier for the Memory output tensor. Tensor is assumed to contain the new
- /// memory data of the agents in the batch.
- ///
- public class MemoryOutputApplier : TensorApplier.IApplier
+ public class BarracudaMemoryOutputApplier : TensorApplier.IApplier
{
+ readonly int m_MemoriesCount;
+ readonly int m_MemoryIndex;
+
+ Dictionary<int, List<float>> m_Memories;
+
+ public BarracudaMemoryOutputApplier(
+ int memoriesCount,
+ int memoryIndex,
+ Dictionary<int, List<float>> memories)
+ {
+ m_MemoriesCount = memoriesCount;
+ m_MemoryIndex = memoryIndex;
+ m_Memories = memories;
+ }
+
+ public void Apply(TensorProxy tensorProxy, IEnumerable<Agent> agents)
{
var agentIndex = 0;
- var memorySize = tensorProxy.shape[tensorProxy.shape.Length - 1];
+ var memorySize = (int)tensorProxy.shape[tensorProxy.shape.Length - 1];
+
foreach (var agent in agents)
{
- var memory = new List<float>();
+ List<float> memory = null;
+ if (!m_Memories.TryGetValue(agent.Info.id, out memory)
+ || memory.Count < memorySize * m_MemoriesCount)
+ {
+ memory = new List<float>();
+ memory.AddRange(Enumerable.Repeat(0f, memorySize * m_MemoriesCount));
+ }
+
for (var j = 0; j < memorySize; j++)
{
- memory.Add(tensorProxy.data[agentIndex, j]);
+ memory[memorySize * m_MemoryIndex + j] = tensorProxy.data[agentIndex, j];
}
- agent.UpdateMemoriesAction(memory);
+ m_Memories[agent.Info.id] = memory;
agentIndex++;
}
}
}
+
///
/// The Applier for the Value Estimate output tensor. Tensor is assumed to contain the
/// value estimates of the agents in the batch.
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/BarracudaModelParamLoader.cs b/UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/BarracudaModelParamLoader.cs
index ccd8a698df..f4f10e36c2 100644
--- a/UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/BarracudaModelParamLoader.cs
+++ b/UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/BarracudaModelParamLoader.cs
@@ -2,6 +2,8 @@
using System.Collections.Generic;
using System.Linq;
using Barracuda;
+using MLAgents.Sensor;
+using UnityEngine;
namespace MLAgents.InferenceBrain
{
@@ -11,13 +13,14 @@ namespace MLAgents.InferenceBrain
///
public class BarracudaModelParamLoader
{
- private enum ModelActionType
+ enum ModelActionType
{
Unknown,
Discrete,
Continuous
}
- private const long k_ApiVersion = 2;
+
+ const long k_ApiVersion = 2;
///
/// Generates the Tensor inputs that are expected to be present in the Model.
@@ -60,6 +63,26 @@ public static IReadOnlyList GetInputTensors(Model model)
return tensors;
}
+ public static int GetNumVisualInputs(Model model)
+ {
+ var count = 0;
+ if (model == null)
+ return count;
+
+ foreach (var input in model.inputs)
+ {
+ if (input.shape.Length == 4)
+ {
+ if (input.name.StartsWith(TensorNames.VisualObservationPlaceholderPrefix))
+ {
+ count++;
+ }
+ }
+ }
+
+ return count;
+ }
+
///
/// Generates the Tensor outputs that are expected to be present in the Model.
///
@@ -102,8 +125,9 @@ public static string[] GetOutputNames(Model model)
///
/// The BrainParameters that are used verify the compatibility with the InferenceEngine
///
+ /// Attached sensor components
/// The list the error messages of the checks that failed
- public static IEnumerable CheckModel(Model model, BrainParameters brainParameters)
- public static IEnumerable<string> CheckModel(Model model, BrainParameters brainParameters)
+ public static IEnumerable<string> CheckModel(Model model, BrainParameters brainParameters, SensorComponent[] sensorComponents)
List failedModelChecks = new List();
if (model == null)
@@ -143,13 +167,13 @@ public static IEnumerable CheckModel(Model model, BrainParameters brainP
})
);
failedModelChecks.AddRange(
- CheckInputTensorPresence(model, brainParameters, memorySize, isContinuous)
+ CheckInputTensorPresence(model, brainParameters, memorySize, isContinuous, sensorComponents)
);
failedModelChecks.AddRange(
CheckOutputTensorPresence(model, memorySize))
;
failedModelChecks.AddRange(
- CheckInputTensorShape(model, brainParameters)
+ CheckInputTensorShape(model, brainParameters, sensorComponents)
);
failedModelChecks.AddRange(
CheckOutputTensorShape(model, brainParameters, isContinuous, actionSize)
@@ -165,7 +189,7 @@ public static IEnumerable CheckModel(Model model, BrainParameters brainP
/// The integer value in the model indicating the type of control
///
/// The equivalent ModelActionType
- private static ModelActionType GetActionType(int isContinuousInt)
+ static ModelActionType GetActionType(int isContinuousInt)
{
ModelActionType isContinuous;
switch (isContinuousInt)
@@ -189,7 +213,7 @@ private static ModelActionType GetActionType(int isContinuousInt)
///
/// Mapping from node names to int values
/// The list the error messages of the checks that failed
- private static IEnumerable<string> CheckIntScalarPresenceHelper(
+ static IEnumerable<string> CheckIntScalarPresenceHelper(
Dictionary<string, int> requiredScalarFields)
{
var failedModelChecks = new List<string>();
@@ -219,14 +243,17 @@ private static IEnumerable CheckIntScalarPresenceHelper(
///
/// Whether the model is expecting continuous or discrete control.
///
+ /// Array of attached sensor components
///
/// A IEnumerable of string corresponding to the failed input presence checks.
///
- private static IEnumerable<string> CheckInputTensorPresence(
+ static IEnumerable<string> CheckInputTensorPresence(
Model model,
BrainParameters brainParameters,
int memory,
- ModelActionType isContinuous)
+ ModelActionType isContinuous,
+ SensorComponent[] sensorComponents
+ )
{
var failedModelChecks = new List<string>();
var tensorsNames = GetInputTensors(model).Select(x => x.name).ToList();
@@ -240,7 +267,36 @@ private static IEnumerable CheckInputTensorPresence(
"You must set the Vector Observation Space Size to 0.");
}
- // TODO reenable checks there are enough Visual Observation Placeholder in the model.
+ // If there are not enough Visual Observation Input compared to what the
+ // sensors expect.
+ var visObsIndex = 0;
+ for (var sensorIndex = 0; sensorIndex < sensorComponents.Length; sensorIndex++)
+ {
+ var sensor = sensorComponents[sensorIndex];
+ if (!sensor.IsVisual())
+ {
+ continue;
+ }
+ if (!tensorsNames.Contains(
+ TensorNames.VisualObservationPlaceholderPrefix + visObsIndex))
+ {
+ failedModelChecks.Add(
+ "The model does not contain a Visual Observation Placeholder Input " +
+ $"for sensor component {visObsIndex} ({sensor.GetType().Name}).");
+ }
+
+ visObsIndex++;
+ }
+
+ var expectedVisualObs = GetNumVisualInputs(model);
+ // Check if there's not enough visual sensors (too many would be handled above)
+ if (expectedVisualObs > visObsIndex)
+ {
+ failedModelChecks.Add(
+ $"The model expects {expectedVisualObs} visual inputs," +
+ $" but only found {visObsIndex} visual sensors."
+ );
+ }
// If the model has a non-negative memory size but requires a recurrent input
if (memory > 0)
@@ -276,7 +332,7 @@ private static IEnumerable CheckInputTensorPresence(
///
/// A IEnumerable of string corresponding to the failed output presence checks.
///
- private static IEnumerable<string> CheckOutputTensorPresence(Model model, int memory)
+ static IEnumerable<string> CheckOutputTensorPresence(Model model, int memory)
{
var failedModelChecks = new List<string>();
// If there is no Action Output.
@@ -300,6 +356,34 @@ private static IEnumerable CheckOutputTensorPresence(Model model, int me
return failedModelChecks;
}
+ ///
+ /// Checks that the shape of the visual observation input placeholder is the same as the corresponding sensor.
+ ///
+ /// The tensor that is expected by the model
+ /// The sensor that produces the visual observation.
+ ///
+ /// If the Check failed, returns a string containing information about why the
+ /// check failed. If the check passed, returns null.
+ ///
+ static string CheckVisualObsShape(
+ TensorProxy tensorProxy, SensorComponent sensorComponent)
+ {
+ var shape = sensorComponent.GetObservationShape();
+ var heightBp = shape[0];
+ var widthBp = shape[1];
+ var pixelBp = shape[2];
+ var heightT = tensorProxy.shape[1];
+ var widthT = tensorProxy.shape[2];
+ var pixelT = tensorProxy.shape[3];
+ if ((widthBp != widthT) || (heightBp != heightT) || (pixelBp != pixelT))
+ {
+ return $"The visual Observation of the model does not match. " +
+ $"Received TensorProxy of shape [?x{widthBp}x{heightBp}x{pixelBp}] but " +
+ $"was expecting [?x{widthT}x{heightT}x{pixelT}].";
+ }
+ return null;
+ }
+
///
/// Generates failed checks that correspond to inputs shapes incompatibilities between
/// the model and the BrainParameters.
@@ -310,28 +394,40 @@ private static IEnumerable CheckOutputTensorPresence(Model model, int me
///
/// The BrainParameters that are used verify the compatibility with the InferenceEngine
///
+ /// Attached sensors
/// The list the error messages of the checks that failed
- private static IEnumerable<string> CheckInputTensorShape(
- Model model, BrainParameters brainParameters)
+ static IEnumerable<string> CheckInputTensorShape(
+ Model model, BrainParameters brainParameters, SensorComponent[] sensorComponents)
{
var failedModelChecks = new List<string>();
var tensorTester =
- new Dictionary<string, Func<BrainParameters, TensorProxy, string>>()
+ new Dictionary<string, Func<BrainParameters, TensorProxy, SensorComponent[], string>>()
{
{TensorNames.VectorObservationPlacholder, CheckVectorObsShape},
{TensorNames.PreviousActionPlaceholder, CheckPreviousActionShape},
- {TensorNames.RandomNormalEpsilonPlaceholder, ((bp, tensor) => null)},
- {TensorNames.ActionMaskPlaceholder, ((bp, tensor) => null)},
- {TensorNames.SequenceLengthPlaceholder, ((bp, tensor) => null)},
- {TensorNames.RecurrentInPlaceholder, ((bp, tensor) => null)},
+ {TensorNames.RandomNormalEpsilonPlaceholder, ((bp, tensor, scs) => null)},
+ {TensorNames.ActionMaskPlaceholder, ((bp, tensor, scs) => null)},
+ {TensorNames.SequenceLengthPlaceholder, ((bp, tensor, scs) => null)},
+ {TensorNames.RecurrentInPlaceholder, ((bp, tensor, scs) => null)},
};
foreach (var mem in model.memories)
{
- tensorTester[mem.input] = ((bp, tensor) => null);
+ tensorTester[mem.input] = ((bp, tensor, scs) => null);
}
- // TODO reenable checks on visual observation shapes.
+ var visObsIndex = 0;
+ for (var sensorIndex = 0; sensorIndex < sensorComponents.Length; sensorIndex++)
+ {
+ var sensorComponent = sensorComponents[sensorIndex];
+ if (!sensorComponent.IsVisual())
+ {
+ continue;
+ }
+ tensorTester[TensorNames.VisualObservationPlaceholderPrefix + visObsIndex] =
+ (bp, tensor, scs) => CheckVisualObsShape(tensor, sensorComponent);
+ visObsIndex++;
+ }
// If the model expects an input but it is not in this list
foreach (var tensor in GetInputTensors(model))
@@ -347,7 +443,7 @@ private static IEnumerable CheckInputTensorShape(
else
{
var tester = tensorTester[tensor.name];
- var error = tester.Invoke(brainParameters, tensor);
+ var error = tester.Invoke(brainParameters, tensor, sensorComponents);
if (error != null)
{
failedModelChecks.Add(error);
@@ -365,20 +461,50 @@ private static IEnumerable CheckInputTensorShape(
/// The BrainParameters that are used verify the compatibility with the InferenceEngine
///
/// The tensor that is expected by the model
+ /// Array of attached sensor components
///
/// If the Check failed, returns a string containing information about why the
/// check failed. If the check passed, returns null.
///
- private static string CheckVectorObsShape(
- BrainParameters brainParameters, TensorProxy tensorProxy)
+ static string CheckVectorObsShape(
+ BrainParameters brainParameters, TensorProxy tensorProxy, SensorComponent[] sensorComponents)
{
var vecObsSizeBp = brainParameters.vectorObservationSize;
var numStackedVector = brainParameters.numStackedVectorObservations;
var totalVecObsSizeT = tensorProxy.shape[tensorProxy.shape.Length - 1];
- if (vecObsSizeBp * numStackedVector != totalVecObsSizeT)
+
+ var totalVectorSensorSize = 0;
+ foreach (var sensorComp in sensorComponents)
{
- return "Vector Observation Size of the model does not match. Received " +
- $"{vecObsSizeBp} x {numStackedVector} but was expecting {totalVecObsSizeT}.";
+ if (sensorComp.IsVector())
+ {
+ totalVectorSensorSize += sensorComp.GetObservationShape()[0];
+ }
+ }
+
+ if (vecObsSizeBp * numStackedVector + totalVectorSensorSize != totalVecObsSizeT)
+ {
+ var sensorSizes = "";
+ foreach (var sensorComp in sensorComponents)
+ {
+ if (sensorComp.IsVector())
+ {
+ var vecSize = sensorComp.GetObservationShape()[0];
+ if (sensorSizes.Length == 0)
+ {
+ sensorSizes = $"[{vecSize}";
+ }
+ else
+ {
+ sensorSizes += $", {vecSize}";
+ }
+ }
+ }
+
+ sensorSizes += "]";
+ return $"Vector Observation Size of the model does not match. Was expecting {totalVecObsSizeT} " +
+ $"but received {vecObsSizeBp} x {numStackedVector} vector observations and " +
+ $"SensorComponent sizes: {sensorSizes}.";
}
return null;
}
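(Worked example for the check above, with made-up numbers: vectorObservationSize = 8, numStackedVectorObservations = 3, and a single vector SensorComponent of size 4 give an expected placeholder width of 8 * 3 + 4 = 28; a model tensor of any other width produces the message constructed above, with sensorSizes rendered as "[4]".)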
@@ -391,10 +517,11 @@ private static string CheckVectorObsShape(
/// The BrainParameters that are used verify the compatibility with the InferenceEngine
///
/// The tensor that is expected by the model
+ /// Array of attached sensor components
/// If the Check failed, returns a string containing information about why the
/// check failed. If the check passed, returns null.
- private static string CheckPreviousActionShape(
- BrainParameters brainParameters, TensorProxy tensorProxy)
+ static string CheckPreviousActionShape(
+ BrainParameters brainParameters, TensorProxy tensorProxy, SensorComponent[] sensorComponents)
{
var numberActionsBp = brainParameters.vectorActionSize.Length;
var numberActionsT = tensorProxy.shape[tensorProxy.shape.Length - 1];
@@ -426,7 +553,7 @@ private static string CheckPreviousActionShape(
/// A IEnumerable of string corresponding to the incompatible shapes between model
/// and BrainParameters.
///
- private static IEnumerable<string> CheckOutputTensorShape(
+ static IEnumerable<string> CheckOutputTensorShape(
Model model,
BrainParameters brainParameters,
ModelActionType isContinuous,
@@ -494,7 +621,7 @@ private static IEnumerable CheckOutputTensorShape(
/// If the Check failed, returns a string containing information about why the
/// check failed. If the check passed, returns null.
///
- private static string CheckDiscreteActionOutputShape(
+ static string CheckDiscreteActionOutputShape(
BrainParameters brainParameters, TensorShape shape, int modelActionSize)
{
var bpActionSize = brainParameters.vectorActionSize.Sum();
@@ -519,7 +646,7 @@ private static string CheckDiscreteActionOutputShape(
///
/// If the Check failed, returns a string containing information about why the
/// check failed. If the check passed, returns null.
- private static string CheckContinuousActionOutputShape(
+ static string CheckContinuousActionOutputShape(
BrainParameters brainParameters, TensorShape shape, int modelActionSize)
{
var bpActionSize = brainParameters.vectorActionSize[0];
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/GeneratorImpl.cs b/UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/GeneratorImpl.cs
index c35461356b..250e37f4bb 100644
--- a/UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/GeneratorImpl.cs
+++ b/UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/GeneratorImpl.cs
@@ -1,8 +1,9 @@
using System.Collections.Generic;
using System;
-using System.Linq;
using Barracuda;
using MLAgents.InferenceBrain.Utils;
+using MLAgents.Sensor;
+using UnityEngine;
namespace MLAgents.InferenceBrain
{
@@ -13,7 +14,7 @@ namespace MLAgents.InferenceBrain
///
public class BiDimensionalOutputGenerator : TensorGenerator.IGenerator
{
- private readonly ITensorAllocator m_Allocator;
+ readonly ITensorAllocator m_Allocator;
public BiDimensionalOutputGenerator(ITensorAllocator allocator)
{
@@ -32,7 +33,7 @@ public void Generate(TensorProxy tensorProxy, int batchSize, IEnumerable
///
public class BatchSizeGenerator : TensorGenerator.IGenerator
{
- private readonly ITensorAllocator m_Allocator;
+ readonly ITensorAllocator m_Allocator;
public BatchSizeGenerator(ITensorAllocator allocator)
{
@@ -55,7 +56,7 @@ public void Generate(TensorProxy tensorProxy, int batchSize, IEnumerable
///
public class SequenceLengthGenerator : TensorGenerator.IGenerator
{
- private readonly ITensorAllocator m_Allocator;
+ readonly ITensorAllocator m_Allocator;
public SequenceLengthGenerator(ITensorAllocator allocator)
{
@@ -79,26 +80,42 @@ public void Generate(TensorProxy tensorProxy, int batchSize, IEnumerable
///
public class VectorObservationGenerator : TensorGenerator.IGenerator
{
- private readonly ITensorAllocator m_Allocator;
+ readonly ITensorAllocator m_Allocator;
+ List<int> m_SensorIndices = new List<int>();
+ WriteAdapter m_WriteAdapter = new WriteAdapter();
+
public VectorObservationGenerator(ITensorAllocator allocator)
{
m_Allocator = allocator;
}
- public void Generate(
- TensorProxy tensorProxy, int batchSize, IEnumerable<Agent> agents)
+ public void AddSensorIndex(int sensorIndex)
+ {
+ m_SensorIndices.Add(sensorIndex);
+ }
+
+ public void Generate(TensorProxy tensorProxy, int batchSize, IEnumerable<Agent> agents)
{
TensorUtils.ResizeTensor(tensorProxy, batchSize, m_Allocator);
var vecObsSizeT = tensorProxy.shape[tensorProxy.shape.Length - 1];
var agentIndex = 0;
foreach (var agent in agents)
{
- var info = agent.Info;
- var vectorObs = info.stackedVectorObservation;
- for (var j = 0; j < vecObsSizeT; j++)
+ var tensorOffset = 0;
+ // Write each sensor consecutively to the tensor
+ foreach (var sensorIndex in m_SensorIndices)
{
- tensorProxy.data[agentIndex, j] = vectorObs[j];
+ m_WriteAdapter.SetTarget(tensorProxy, agentIndex, tensorOffset);
+ var sensor = agent.sensors[sensorIndex];
+ var numWritten = sensor.Write(m_WriteAdapter);
+ tensorOffset += numWritten;
}
+ Debug.AssertFormat(
+ tensorOffset == vecObsSizeT,
+ "mismatch between vector observation size ({0}) and number of observations written ({1})",
+ vecObsSizeT, tensorOffset
+ );
+
agentIndex++;
}
}
@@ -113,10 +130,14 @@ public void Generate(
public class RecurrentInputGenerator : TensorGenerator.IGenerator
{
private readonly ITensorAllocator m_Allocator;
+ Dictionary<int, List<float>> m_Memories;
- public RecurrentInputGenerator(ITensorAllocator allocator)
+ public RecurrentInputGenerator(
+ ITensorAllocator allocator,
+ Dictionary<int, List<float>> memories)
{
m_Allocator = allocator;
+ m_Memories = memories;
}
public void Generate(
@@ -129,9 +150,18 @@ public void Generate(
foreach (var agent in agents)
{
var info = agent.Info;
- var memory = info.memories;
- if (memory == null)
+ List<float> memory;
+
+ if (agent.Info.done)
+ {
+ m_Memories.Remove(agent.Info.id);
+ }
+ if (!m_Memories.TryGetValue(agent.Info.id, out memory))
{
+ for (var j = 0; j < memorySize; j++)
+ {
+ tensorProxy.data[agentIndex, j] = 0;
+ }
agentIndex++;
continue;
}
@@ -150,18 +180,24 @@ public void Generate(
public class BarracudaRecurrentInputGenerator : TensorGenerator.IGenerator
{
- private int m_MemoriesCount;
- private readonly int m_MemoryIndex;
- private readonly ITensorAllocator m_Allocator;
+ int m_MemoriesCount;
+ readonly int m_MemoryIndex;
+ readonly ITensorAllocator m_Allocator;
- public BarracudaRecurrentInputGenerator(int memoryIndex, ITensorAllocator allocator)
+ Dictionary<int, List<float>> m_Memories;
+
+ public BarracudaRecurrentInputGenerator(
+ int memoryIndex,
+ ITensorAllocator allocator,
+ Dictionary<int, List<float>> memories)
{
m_MemoryIndex = memoryIndex;
m_Allocator = allocator;
+ m_Memories = memories;
+
}
- public void Generate(
- TensorProxy tensorProxy, int batchSize, IEnumerable<Agent> agents)
+ public void Generate(TensorProxy tensorProxy, int batchSize, IEnumerable<Agent> agents)
{
TensorUtils.ResizeTensor(tensorProxy, batchSize, m_Allocator);
@@ -169,13 +205,19 @@ public void Generate(
var agentIndex = 0;
foreach (var agent in agents)
{
- var agentInfo = agent.Info;
- var memory = agentInfo.memories;
-
var offset = memorySize * m_MemoryIndex;
-
- if (memory == null)
+ List<float> memory;
+ if (agent.Info.done)
+ {
+ m_Memories.Remove(agent.Info.id);
+ }
+ if (!m_Memories.TryGetValue(agent.Info.id, out memory))
{
+
+ for (var j = 0; j < memorySize; j++)
+ {
+ tensorProxy.data[agentIndex, j] = 0;
+ }
agentIndex++;
continue;
}
@@ -185,6 +227,7 @@ public void Generate(
{
break;
}
+
tensorProxy.data[agentIndex, j] = memory[j + offset];
}
agentIndex++;
@@ -200,15 +243,14 @@ public void Generate(
///
public class PreviousActionInputGenerator : TensorGenerator.IGenerator
{
- private readonly ITensorAllocator m_Allocator;
+ readonly ITensorAllocator m_Allocator;
public PreviousActionInputGenerator(ITensorAllocator allocator)
{
m_Allocator = allocator;
}
- public void Generate(
- TensorProxy tensorProxy, int batchSize, IEnumerable<Agent> agents)
+ public void Generate(TensorProxy tensorProxy, int batchSize, IEnumerable<Agent> agents)
{
TensorUtils.ResizeTensor(tensorProxy, batchSize, m_Allocator);
@@ -236,15 +278,14 @@ public void Generate(
///
public class ActionMaskInputGenerator : TensorGenerator.IGenerator
{
- private readonly ITensorAllocator m_Allocator;
+ readonly ITensorAllocator m_Allocator;
public ActionMaskInputGenerator(ITensorAllocator allocator)
{
m_Allocator = allocator;
}
- public void Generate(
- TensorProxy tensorProxy, int batchSize, IEnumerable<Agent> agents)
+ public void Generate(TensorProxy tensorProxy, int batchSize, IEnumerable<Agent> agents)
{
TensorUtils.ResizeTensor(tensorProxy, batchSize, m_Allocator);
@@ -271,8 +312,8 @@ public void Generate(
///
public class RandomNormalInputGenerator : TensorGenerator.IGenerator
{
- private readonly RandomNormal m_RandomNormal;
- private readonly ITensorAllocator m_Allocator;
+ readonly RandomNormal m_RandomNormal;
+ readonly ITensorAllocator m_Allocator;
public RandomNormalInputGenerator(int seed, ITensorAllocator allocator)
{
@@ -280,8 +321,7 @@ public RandomNormalInputGenerator(int seed, ITensorAllocator allocator)
m_Allocator = allocator;
}
- public void Generate(
- TensorProxy tensorProxy, int batchSize, IEnumerable<Agent> agents)
+ public void Generate(TensorProxy tensorProxy, int batchSize, IEnumerable<Agent> agents)
{
TensorUtils.ResizeTensor(tensorProxy, batchSize, m_Allocator);
TensorUtils.FillTensorWithRandomNormal(tensorProxy, m_RandomNormal);
@@ -296,27 +336,25 @@ public void Generate(
///
public class VisualObservationInputGenerator : TensorGenerator.IGenerator
{
- private readonly int m_Index;
- private readonly bool m_GrayScale;
- private readonly ITensorAllocator m_Allocator;
+ readonly int m_SensorIndex;
+ readonly ITensorAllocator m_Allocator;
+ WriteAdapter m_WriteAdapter = new WriteAdapter();
public VisualObservationInputGenerator(
- int index, ITensorAllocator allocator)
+ int sensorIndex, ITensorAllocator allocator)
{
- m_Index = index;
+ m_SensorIndex = sensorIndex;
m_Allocator = allocator;
}
- public void Generate(
- TensorProxy tensorProxy, int batchSize, IEnumerable<Agent> agents)
+ public void Generate(TensorProxy tensorProxy, int batchSize, IEnumerable<Agent> agents)
{
TensorUtils.ResizeTensor(tensorProxy, batchSize, m_Allocator);
var agentIndex = 0;
foreach (var agent in agents)
{
- // TODO direct access to sensors list here - should we do it differently?
- // TODO m_Index here is the visual observation index. Will work for now but not if we add more sensor types.
- agent.m_Sensors[m_Index].WriteToTensor(tensorProxy, agentIndex);
+ m_WriteAdapter.SetTarget(tensorProxy, agentIndex, 0);
+ agent.sensors[m_SensorIndex].Write(m_WriteAdapter);
agentIndex++;
}
}
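Sketch of how the new per-agent memory plumbing is wired together (illustrative only; constructor signatures as in the hunks above, and allocator stands in for an ITensorAllocator already in scope):

    var memories = new Dictionary<int, List<float>>();
    // The generator reads each agent's recurrent state keyed by agent.Info.id ...
    var recurrentIn = new RecurrentInputGenerator(allocator, memories);
    // ... and the applier writes the updated state back under the same key.
    var recurrentOut = new MemoryOutputApplier(memories);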
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/ModelRunner.cs b/UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/ModelRunner.cs
index 85301e9313..0ed6c353dd 100644
--- a/UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/ModelRunner.cs
+++ b/UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/ModelRunner.cs
@@ -4,22 +4,23 @@
namespace MLAgents.InferenceBrain
{
- public class ModelRunner : IBatchedDecisionMaker
+ public class ModelRunner
{
- private List<Agent> m_Agents = new List<Agent>();
- private ITensorAllocator m_TensorAllocator;
- private TensorGenerator m_TensorGenerator;
- private TensorApplier m_TensorApplier;
-
- private NNModel m_Model;
- private InferenceDevice m_InferenceDevice;
- private IWorker m_Engine;
- private bool m_Verbose = false;
- private string[] m_OutputNames;
- private IReadOnlyList<TensorProxy> m_InferenceInputs;
- private IReadOnlyList<TensorProxy> m_InferenceOutputs;
-
- private bool m_visualObservationsInitialized = false;
+ List<Agent> m_Agents = new List<Agent>();
+ ITensorAllocator m_TensorAllocator;
+ TensorGenerator m_TensorGenerator;
+ TensorApplier m_TensorApplier;
+
+ NNModel m_Model;
+ InferenceDevice m_InferenceDevice;
+ IWorker m_Engine;
+ bool m_Verbose = false;
+ string[] m_OutputNames;
+ IReadOnlyList<TensorProxy> m_InferenceInputs;
+ IReadOnlyList<TensorProxy> m_InferenceOutputs;
+ Dictionary<int, List<float>> m_Memories = new Dictionary<int, List<float>>();
+
+ bool m_VisualObservationsInitialized;
///
/// Initializes the Brain with the Model that it will use when selecting actions for
@@ -66,11 +67,13 @@ public ModelRunner(
m_InferenceInputs = BarracudaModelParamLoader.GetInputTensors(barracudaModel);
m_OutputNames = BarracudaModelParamLoader.GetOutputNames(barracudaModel);
- m_TensorGenerator = new TensorGenerator(brainParameters, seed, m_TensorAllocator, barracudaModel);
- m_TensorApplier = new TensorApplier(brainParameters, seed, m_TensorAllocator, barracudaModel);
+ m_TensorGenerator = new TensorGenerator(
+ seed, m_TensorAllocator, m_Memories, barracudaModel);
+ m_TensorApplier = new TensorApplier(
+ brainParameters, seed, m_TensorAllocator, m_Memories, barracudaModel);
}
- private static Dictionary<string, Tensor> PrepareBarracudaInputs(IEnumerable<TensorProxy> infInputs)
+ static Dictionary<string, Tensor> PrepareBarracudaInputs(IEnumerable<TensorProxy> infInputs)
{
var inputs = new Dictionary<string, Tensor>();
foreach (var inp in infInputs)
@@ -88,7 +91,7 @@ public void Dispose()
m_TensorAllocator?.Reset(false);
}
- private List<TensorProxy> FetchBarracudaOutputs(string[] names)
+ List<TensorProxy> FetchBarracudaOutputs(string[] names)
{
var outputs = new List<TensorProxy>();
foreach (var n in names)
@@ -100,7 +103,7 @@ private List FetchBarracudaOutputs(string[] names)
return outputs;
}
- public void PutObservations(string key, Agent agent)
+ public void PutObservations(Agent agent)
{
m_Agents.Add(agent);
}
@@ -112,13 +115,13 @@ public void DecideBatch()
return;
}
- if (!m_visualObservationsInitialized)
+ if (!m_VisualObservationsInitialized)
{
// Just grab the first agent in the collection (any will suffice, really).
// We check for an empty Collection above, so this will always return successfully.
var firstAgent = m_Agents[0];
- m_TensorGenerator.InitializeVisualObservations(firstAgent, m_TensorAllocator);
- m_visualObservationsInitialized = true;
+ m_TensorGenerator.InitializeObservations(firstAgent, m_TensorAllocator);
+ m_VisualObservationsInitialized = true;
}
Profiler.BeginSample("LearningBrain.DecideAction");
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TensorApplier.cs b/UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TensorApplier.cs
index 610abb6cf9..878e5ad701 100644
--- a/UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TensorApplier.cs
+++ b/UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TensorApplier.cs
@@ -33,7 +33,7 @@ public interface IApplier
void Apply(TensorProxy tensorProxy, IEnumerable<Agent> agents);
}
- private readonly Dictionary<string, IApplier> m_Dict = new Dictionary<string, IApplier>();
+ readonly Dictionary<string, IApplier> m_Dict = new Dictionary<string, IApplier>();
///
/// Returns a new TensorAppliers object.
@@ -42,9 +42,14 @@ public interface IApplier
/// used
/// The seed the Appliers will be initialized with.
/// Tensor allocator
+ /// Dictionary of AgentInfo.id to memory used to pass to the inference model.
///
public TensorApplier(
- BrainParameters bp, int seed, ITensorAllocator allocator, object barracudaModel = null)
+ BrainParameters bp,
+ int seed,
+ ITensorAllocator allocator,
+ Dictionary<int, List<float>> memories,
+ object barracudaModel = null)
{
m_Dict[TensorNames.ValueEstimateOutput] = new ValueEstimateApplier();
if (bp.vectorActionSpaceType == SpaceType.Continuous)
@@ -56,16 +61,16 @@ public TensorApplier(
m_Dict[TensorNames.ActionOutput] =
new DiscreteActionOutputApplier(bp.vectorActionSize, seed, allocator);
}
- m_Dict[TensorNames.RecurrentOutput] = new MemoryOutputApplier();
+ m_Dict[TensorNames.RecurrentOutput] = new MemoryOutputApplier(memories);
if (barracudaModel != null)
{
var model = (Model)barracudaModel;
- for (var i = 0; i < model?.memories.Length; i++)
+ for (var i = 0; i < model?.memories.Count; i++)
{
m_Dict[model.memories[i].output] =
- new BarracudaMemoryOutputApplier(model.memories.Length, i);
+ new BarracudaMemoryOutputApplier(model.memories.Count, i, memories);
}
}
}
@@ -78,7 +83,7 @@ public TensorApplier(
/// One of the tensor does not have an
/// associated applier.
public void ApplyTensors(
- IEnumerable tensors, IEnumerable agents)
+ IEnumerable tensors, IEnumerable agents)
{
foreach (var tensor in tensors)
{
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TensorGenerator.cs b/UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TensorGenerator.cs
index 452f15f92b..db9f0d9e92 100644
--- a/UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TensorGenerator.cs
+++ b/UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TensorGenerator.cs
@@ -31,36 +31,36 @@ void Generate(
TensorProxy tensorProxy, int batchSize, IEnumerable<Agent> agents);
}
- private readonly Dictionary<string, IGenerator> m_Dict = new Dictionary<string, IGenerator>();
+ readonly Dictionary<string, IGenerator> m_Dict = new Dictionary<string, IGenerator>();
///
/// Returns a new TensorGenerators object.
///
- /// The BrainParameters used to determine what Generators will be
- /// used
/// The seed the Generators will be initialized with.
/// Tensor allocator
+ /// Dictionary of AgentInfo.id to memory for use in the inference model.
///
public TensorGenerator(
- BrainParameters bp, int seed, ITensorAllocator allocator, object barracudaModel = null)
+ int seed,
+ ITensorAllocator allocator,
+ Dictionary<int, List<float>> memories,
+ object barracudaModel = null)
{
// Generator for Inputs
m_Dict[TensorNames.BatchSizePlaceholder] =
new BatchSizeGenerator(allocator);
m_Dict[TensorNames.SequenceLengthPlaceholder] =
new SequenceLengthGenerator(allocator);
- m_Dict[TensorNames.VectorObservationPlacholder] =
- new VectorObservationGenerator(allocator);
m_Dict[TensorNames.RecurrentInPlaceholder] =
- new RecurrentInputGenerator(allocator);
+ new RecurrentInputGenerator(allocator, memories);
if (barracudaModel != null)
{
var model = (Model)barracudaModel;
- for (var i = 0; i < model?.memories.Length; i++)
+ for (var i = 0; i < model.memories.Count; i++)
{
m_Dict[model.memories[i].input] =
- new BarracudaRecurrentInputGenerator(i, allocator);
+ new BarracudaRecurrentInputGenerator(i, allocator, memories);
}
}
@@ -78,13 +78,39 @@ public TensorGenerator(
m_Dict[TensorNames.ValueEstimateOutput] = new BiDimensionalOutputGenerator(allocator);
}
- public void InitializeVisualObservations(Agent agent, ITensorAllocator allocator)
+ public void InitializeObservations(Agent agent, ITensorAllocator allocator)
{
- for (var visIndex = 0; visIndex < agent.m_Sensors.Count; visIndex++)
+ // Loop through the sensors on a representative agent.
+ // For vector observations, add the index to the (single) VectorObservationGenerator
+ // For visual observations, make a VisualObservationInputGenerator
+ var visIndex = 0;
+ VectorObservationGenerator vecObsGen = null;
+ for (var sensorIndex = 0; sensorIndex < agent.sensors.Count; sensorIndex++)
{
- // TODO handle non-visual sensors too - need to index better
- m_Dict[TensorNames.VisualObservationPlaceholderPrefix + visIndex] =
- new VisualObservationInputGenerator(visIndex, allocator);
+ var sensor = agent.sensors[sensorIndex];
+ var shape = sensor.GetFloatObservationShape();
+ // TODO generalize - we currently only have vector or visual, but can't handle "2D" observations
+ var isVectorSensor = (shape.Length == 1);
+ if (isVectorSensor)
+ {
+ if (vecObsGen == null)
+ {
+ vecObsGen = new VectorObservationGenerator(allocator);
+ }
+
+ vecObsGen.AddSensorIndex(sensorIndex);
+ }
+ else
+ {
+ m_Dict[TensorNames.VisualObservationPlaceholderPrefix + visIndex] =
+ new VisualObservationInputGenerator(sensorIndex, allocator);
+ visIndex++;
+ }
+ }
+
+ if (vecObsGen != null)
+ {
+ m_Dict[TensorNames.VectorObservationPlacholder] = vecObsGen;
}
}
@@ -100,9 +126,7 @@ public void InitializeVisualObservations(Agent agent, ITensorAllocator allocator
/// One of the tensor does not have an
/// associated generator.
public void GenerateTensors(
- IEnumerable<TensorProxy> tensors,
- int currentBatchSize,
- IEnumerable<Agent> agents)
+ IEnumerable<TensorProxy> tensors, int currentBatchSize, IEnumerable<Agent> agents)
{
foreach (var tensor in tensors)
{
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TensorProxy.cs b/UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TensorProxy.cs
index 7c1d3dacf5..6cb9fd985c 100644
--- a/UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TensorProxy.cs
+++ b/UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TensorProxy.cs
@@ -21,7 +21,7 @@ public enum TensorType
FloatingPoint
};
- private static readonly Dictionary<TensorType, Type> k_TypeMap =
+ static readonly Dictionary<TensorType, Type> k_TypeMap =
new Dictionary<TensorType, Type>()
{
{TensorType.FloatingPoint, typeof(float)},
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/Utils/Multinomial.cs b/UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/Utils/Multinomial.cs
index 24910f6f22..a543056f81 100644
--- a/UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/Utils/Multinomial.cs
+++ b/UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/Utils/Multinomial.cs
@@ -10,7 +10,7 @@ namespace MLAgents.InferenceBrain.Utils
///
public class Multinomial
{
- private readonly System.Random m_Random;
+ readonly System.Random m_Random;
///
/// Constructor.
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/Utils/RandomNormal.cs b/UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/Utils/RandomNormal.cs
index 5f2793d0e5..78a5b4dff4 100644
--- a/UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/Utils/RandomNormal.cs
+++ b/UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/Utils/RandomNormal.cs
@@ -10,9 +10,9 @@ namespace MLAgents.InferenceBrain.Utils
///
public class RandomNormal
{
- private readonly double m_Mean;
- private readonly double m_Stddev;
- private readonly Random m_Random;
+ readonly double m_Mean;
+ readonly double m_Stddev;
+ readonly Random m_Random;
public RandomNormal(int seed, float mean = 0.0f, float stddev = 1.0f)
{
@@ -22,8 +22,8 @@ public RandomNormal(int seed, float mean = 0.0f, float stddev = 1.0f)
}
// Each iteration produces two numbers. Hold one here for next call
- private bool m_HasSpare;
- private double m_SpareUnscaled;
+ bool m_HasSpare;
+ double m_SpareUnscaled;
///
/// Return the next random double number
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/Policy/BarracudaPolicy.cs b/UnitySDK/Assets/ML-Agents/Scripts/Policy/BarracudaPolicy.cs
index 65d66f4239..30f5a54140 100644
--- a/UnitySDK/Assets/ML-Agents/Scripts/Policy/BarracudaPolicy.cs
+++ b/UnitySDK/Assets/ML-Agents/Scripts/Policy/BarracudaPolicy.cs
@@ -1,6 +1,7 @@
using UnityEngine;
using Barracuda;
using System.Collections.Generic;
+using MLAgents.InferenceBrain;
namespace MLAgents
{
@@ -18,10 +19,10 @@ public enum InferenceDevice
public class BarracudaPolicy : IPolicy
{
- protected IBatchedDecisionMaker m_BatchedDecisionMaker;
+ protected ModelRunner m_ModelRunner;
///
- /// Sensor shapes for the associated Agents. All Agents must have the same shapes for their sensors.
+ /// Sensor shapes for the associated Agents. All Agents must have the same shapes for their Sensors.
///
List<int[]> m_SensorShapes;
@@ -31,10 +32,10 @@ public BarracudaPolicy(
NNModel model,
InferenceDevice inferenceDevice)
{
- var aca = GameObject.FindObjectOfType<Academy>();
+ var aca = Object.FindObjectOfType<Academy>();
aca.LazyInitialization();
var modelRunner = aca.GetOrCreateModelRunner(model, brainParameters, inferenceDevice);
- m_BatchedDecisionMaker = modelRunner;
+ m_ModelRunner = modelRunner;
}
///
@@ -43,40 +44,40 @@ public void RequestDecision(Agent agent)
#if DEBUG
ValidateAgentSensorShapes(agent);
#endif
- m_BatchedDecisionMaker?.PutObservations(null, agent);
+ m_ModelRunner?.PutObservations(agent);
}
///
public void DecideAction()
{
- m_BatchedDecisionMaker?.DecideBatch();
+ m_ModelRunner?.DecideBatch();
}
///
- /// Check that the Agent sensors are the same shape as the the other Agents using the same Brain.
+ /// Check that the Agent Sensors are the same shape as the the other Agents using the same Brain.
/// If this is the first Agent being checked, its Sensor sizes will be saved.
///
/// The Agent to check
- private void ValidateAgentSensorShapes(Agent agent)
+ void ValidateAgentSensorShapes(Agent agent)
{
if (m_SensorShapes == null)
{
- m_SensorShapes = new List<int[]>(agent.m_Sensors.Count);
+ m_SensorShapes = new List<int[]>(agent.sensors.Count);
// First agent, save the sensor sizes
- foreach (var sensor in agent.m_Sensors)
+ foreach (var sensor in agent.sensors)
{
m_SensorShapes.Add(sensor.GetFloatObservationShape());
}
}
else
{
- // Check for compatibility with the other Agents' sensors
+ // Check for compatibility with the other Agents' Sensors
// TODO make sure this only checks once per agent
- Debug.Assert(m_SensorShapes.Count == agent.m_Sensors.Count, $"Number of sensors must match. {m_SensorShapes.Count} != {agent.m_Sensors.Count}");
+ Debug.Assert(m_SensorShapes.Count == agent.sensors.Count, $"Number of Sensors must match. {m_SensorShapes.Count} != {agent.sensors.Count}");
for (var i = 0; i < m_SensorShapes.Count; i++)
{
var cachedShape = m_SensorShapes[i];
- var sensorShape = agent.m_Sensors[i].GetFloatObservationShape();
+ var sensorShape = agent.sensors[i].GetFloatObservationShape();
Debug.Assert(cachedShape.Length == sensorShape.Length, "Sensor dimensions must match.");
for (var j = 0; j < cachedShape.Length; j++)
{
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/Policy/BehaviorParameters.cs b/UnitySDK/Assets/ML-Agents/Scripts/Policy/BehaviorParameters.cs
index d29ad4b887..be794cc735 100644
--- a/UnitySDK/Assets/ML-Agents/Scripts/Policy/BehaviorParameters.cs
+++ b/UnitySDK/Assets/ML-Agents/Scripts/Policy/BehaviorParameters.cs
@@ -11,21 +11,35 @@ namespace MLAgents
public class BehaviorParameters : MonoBehaviour
{
+ [Serializable]
+ private enum BehaviorType
+ {
+ Default,
+ HeuristicOnly,
+ InferenceOnly
+ }
+
[HideInInspector]
[SerializeField]
- private BrainParameters m_BrainParameters = new BrainParameters();
- [HideInInspector] [SerializeField] private NNModel m_Model;
- [HideInInspector] [SerializeField] private InferenceDevice m_InferenceDevice;
- [HideInInspector] [SerializeField] private bool m_UseHeuristic;
- [HideInInspector] [SerializeField] private string m_BehaviorName = "My Behavior";
-
+ BrainParameters m_BrainParameters = new BrainParameters();
+ [HideInInspector]
+ [SerializeField]
+ NNModel m_Model;
+ [HideInInspector]
+ [SerializeField]
+ InferenceDevice m_InferenceDevice;
+ [HideInInspector]
+ [SerializeField]
+ BehaviorType m_BehaviorType;
[HideInInspector]
+ [SerializeField]
+ string m_BehaviorName = "My Behavior";
+
public BrainParameters brainParameters
{
get { return m_BrainParameters; }
}
- [HideInInspector]
public string behaviorName
{
get { return m_BehaviorName; }
@@ -33,21 +47,27 @@ public string behaviorName
public IPolicy GeneratePolicy(Func<float[]> heuristic)
{
- if (m_UseHeuristic)
- {
- return new HeuristicPolicy(heuristic);
- }
- if (FindObjectOfType<Academy>().IsCommunicatorOn)
- {
- return new RemotePolicy(m_BrainParameters, m_BehaviorName);
- }
- if (m_Model != null)
- {
- return new BarracudaPolicy(m_BrainParameters, m_Model, m_InferenceDevice);
- }
- else
+ switch (m_BehaviorType)
{
- return new HeuristicPolicy(heuristic);
+ case BehaviorType.HeuristicOnly:
+ return new HeuristicPolicy(heuristic);
+ case BehaviorType.InferenceOnly:
+ return new BarracudaPolicy(m_BrainParameters, m_Model, m_InferenceDevice);
+ case BehaviorType.Default:
+ if (FindObjectOfType<Academy>().IsCommunicatorOn)
+ {
+ return new RemotePolicy(m_BrainParameters, m_BehaviorName);
+ }
+ if (m_Model != null)
+ {
+ return new BarracudaPolicy(m_BrainParameters, m_Model, m_InferenceDevice);
+ }
+ else
+ {
+ return new HeuristicPolicy(heuristic);
+ }
+ default:
+ return new HeuristicPolicy(heuristic);
}
}
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/Policy/HeuristicPolicy.cs b/UnitySDK/Assets/ML-Agents/Scripts/Policy/HeuristicPolicy.cs
index 03f98e9cd4..4c014c389c 100644
--- a/UnitySDK/Assets/ML-Agents/Scripts/Policy/HeuristicPolicy.cs
+++ b/UnitySDK/Assets/ML-Agents/Scripts/Policy/HeuristicPolicy.cs
@@ -1,6 +1,4 @@
using UnityEngine;
-using Barracuda;
-using MLAgents.InferenceBrain;
using System;
namespace MLAgents
@@ -13,8 +11,8 @@ namespace MLAgents
///
public class HeuristicPolicy : IPolicy
{
- private Func<float[]> m_Heuristic;
- private Agent m_Agent;
+ Func<float[]> m_Heuristic;
+ Agent m_Agent;
///
public HeuristicPolicy(Func<float[]> heuristic)
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/Policy/IPolicy.cs b/UnitySDK/Assets/ML-Agents/Scripts/Policy/IPolicy.cs
index 35c5bc13c0..f33d73493e 100644
--- a/UnitySDK/Assets/ML-Agents/Scripts/Policy/IPolicy.cs
+++ b/UnitySDK/Assets/ML-Agents/Scripts/Policy/IPolicy.cs
@@ -1,5 +1,4 @@
using System;
-using System.Collections.Generic;
using UnityEngine;
namespace MLAgents
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/Policy/RemotePolicy.cs b/UnitySDK/Assets/ML-Agents/Scripts/Policy/RemotePolicy.cs
index 6ad30fde75..d446cde218 100644
--- a/UnitySDK/Assets/ML-Agents/Scripts/Policy/RemotePolicy.cs
+++ b/UnitySDK/Assets/ML-Agents/Scripts/Policy/RemotePolicy.cs
@@ -9,12 +9,11 @@ namespace MLAgents
///
public class RemotePolicy : IPolicy
{
-
- private string m_BehaviorName;
- protected IBatchedDecisionMaker m_BatchedDecisionMaker;
+ string m_BehaviorName;
+ protected ICommunicator m_Communicator;
///
- /// Sensor shapes for the associated Agents. All Agents must have the same shapes for their sensors.
+ /// Sensor shapes for the associated Agents. All Agents must have the same shapes for their Sensors.
///
List<int[]> m_SensorShapes;
@@ -24,9 +23,9 @@ public RemotePolicy(
string behaviorName)
{
m_BehaviorName = behaviorName;
- var aca = GameObject.FindObjectOfType<Academy>();
+ var aca = Object.FindObjectOfType<Academy>();
aca.LazyInitialization();
- m_BatchedDecisionMaker = aca.Communicator;
+ m_Communicator = aca.Communicator;
aca.Communicator.SubscribeBrain(m_BehaviorName, brainParameters);
}
@@ -36,40 +35,40 @@ public void RequestDecision(Agent agent)
#if DEBUG
ValidateAgentSensorShapes(agent);
#endif
- m_BatchedDecisionMaker?.PutObservations(m_BehaviorName, agent);
+ m_Communicator?.PutObservations(m_BehaviorName, agent);
}
///
public void DecideAction()
{
- m_BatchedDecisionMaker?.DecideBatch();
+ m_Communicator?.DecideBatch();
}
///
- /// Check that the Agent sensors are the same shape as the the other Agents using the same Brain.
+ /// Check that the Agent Sensors are the same shape as the the other Agents using the same Brain.
/// If this is the first Agent being checked, its Sensor sizes will be saved.
///
/// The Agent to check
- private void ValidateAgentSensorShapes(Agent agent)
+ void ValidateAgentSensorShapes(Agent agent)
{
if (m_SensorShapes == null)
{
- m_SensorShapes = new List<int[]>(agent.m_Sensors.Count);
+ m_SensorShapes = new List<int[]>(agent.sensors.Count);
// First agent, save the sensor sizes
- foreach (var sensor in agent.m_Sensors)
+ foreach (var sensor in agent.sensors)
{
m_SensorShapes.Add(sensor.GetFloatObservationShape());
}
}
else
{
- // Check for compatibility with the other Agents' sensors
+ // Check for compatibility with the other Agents' Sensors
// TODO make sure this only checks once per agent
- Debug.Assert(m_SensorShapes.Count == agent.m_Sensors.Count, $"Number of sensors must match. {m_SensorShapes.Count} != {agent.m_Sensors.Count}");
+ Debug.Assert(m_SensorShapes.Count == agent.sensors.Count, $"Number of Sensors must match. {m_SensorShapes.Count} != {agent.sensors.Count}");
for (var i = 0; i < m_SensorShapes.Count; i++)
{
var cachedShape = m_SensorShapes[i];
- var sensorShape = agent.m_Sensors[i].GetFloatObservationShape();
+ var sensorShape = agent.sensors[i].GetFloatObservationShape();
Debug.Assert(cachedShape.Length == sensorShape.Length, "Sensor dimensions must match.");
for (var j = 0; j < cachedShape.Length; j++)
{
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/ResetParameters.cs b/UnitySDK/Assets/ML-Agents/Scripts/ResetParameters.cs
index 19ecb2c456..7f79fc7e87 100644
--- a/UnitySDK/Assets/ML-Agents/Scripts/ResetParameters.cs
+++ b/UnitySDK/Assets/ML-Agents/Scripts/ResetParameters.cs
@@ -22,7 +22,7 @@ public ResetParameters(IDictionary dict) : base(dict)
UpdateResetParameters();
}
- private void UpdateResetParameters()
+ void UpdateResetParameters()
{
m_ResetParameters.Clear();
foreach (var pair in this)
@@ -32,7 +32,8 @@ private void UpdateResetParameters()
}
[FormerlySerializedAs("resetParameters")]
- [SerializeField] private List<ResetParameter> m_ResetParameters = new List<ResetParameter>();
+ [SerializeField]
+ List<ResetParameter> m_ResetParameters = new List<ResetParameter>();
public void OnBeforeSerialize()
{
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/Sensor/CameraSensor.cs b/UnitySDK/Assets/ML-Agents/Scripts/Sensor/CameraSensor.cs
index 574f226a81..4d3f337440 100644
--- a/UnitySDK/Assets/ML-Agents/Scripts/Sensor/CameraSensor.cs
+++ b/UnitySDK/Assets/ML-Agents/Scripts/Sensor/CameraSensor.cs
@@ -1,17 +1,16 @@
using System;
-using MLAgents.InferenceBrain;
using UnityEngine;
namespace MLAgents.Sensor
{
public class CameraSensor : ISensor
{
- private Camera m_Camera;
- private int m_Width;
- private int m_Height;
- private bool m_Grayscale;
- private string m_Name;
- private int[] m_Shape;
+ Camera m_Camera;
+ int m_Width;
+ int m_Height;
+ bool m_Grayscale;
+ string m_Name;
+ int[] m_Shape;
public CameraSensor(Camera camera, int width, int height, bool grayscale, string name)
{
@@ -20,7 +19,7 @@ public CameraSensor(Camera camera, int width, int height, bool grayscale, string
m_Height = height;
m_Grayscale = grayscale;
m_Name = name;
- m_Shape = new[] { width, height, grayscale ? 1 : 3 };
+ m_Shape = new[] { height, width, grayscale ? 1 : 3 };
}
public string GetName()
@@ -45,16 +44,19 @@ public byte[] GetCompressedObservation()
}
}
- public void WriteToTensor(TensorProxy tensorProxy, int agentIndex)
+ public int Write(WriteAdapter adapter)
{
using (TimerStack.Instance.Scoped("CameraSensor.WriteToTensor"))
{
var texture = ObservationToTexture(m_Camera, m_Width, m_Height);
- Utilities.TextureToTensorProxy(texture, tensorProxy, m_Grayscale, agentIndex);
+ var numWritten = Utilities.TextureToTensorProxy(texture, adapter, m_Grayscale);
UnityEngine.Object.Destroy(texture);
+ return numWritten;
}
}
+ public void Update() { }
+
public SensorCompressionType GetCompressionType()
{
return SensorCompressionType.PNG;
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/Sensor/CameraSensorComponent.cs b/UnitySDK/Assets/ML-Agents/Scripts/Sensor/CameraSensorComponent.cs
index 5afdb8d159..09d8f2c230 100644
--- a/UnitySDK/Assets/ML-Agents/Scripts/Sensor/CameraSensorComponent.cs
+++ b/UnitySDK/Assets/ML-Agents/Scripts/Sensor/CameraSensorComponent.cs
@@ -9,11 +9,16 @@ public class CameraSensorComponent : SensorComponent
public string sensorName = "CameraSensor";
public int width = 84;
public int height = 84;
- public bool grayscale = false;
+ public bool grayscale;
public override ISensor CreateSensor()
{
return new CameraSensor(camera, width, height, grayscale, sensorName);
}
+
+ public override int[] GetObservationShape()
+ {
+ return new[] { height, width, grayscale ? 1 : 3 };
+ }
}
}
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/Sensor/ISensor.cs b/UnitySDK/Assets/ML-Agents/Scripts/Sensor/ISensor.cs
index 49f8068955..2f31772cd9 100644
--- a/UnitySDK/Assets/ML-Agents/Scripts/Sensor/ISensor.cs
+++ b/UnitySDK/Assets/ML-Agents/Scripts/Sensor/ISensor.cs
@@ -1,11 +1,9 @@
-using MLAgents.InferenceBrain;
-
namespace MLAgents.Sensor
{
public enum SensorCompressionType
{
None,
- PNG,
+ PNG
}
///
@@ -22,12 +20,14 @@ public interface ISensor {
int[] GetFloatObservationShape();
///
- /// Write the observation data directly to the TensorProxy.
+ /// Write the observation data directly to the WriteAdapter.
/// This is considered an advanced interface; for a simpler approach, use SensorBase and override WriteFloats instead.
+ /// Note that this (and GetCompressedObservation) may be called multiple times per agent step, so should not
+ /// mutate any internal state.
///
- ///
- ///
- void WriteToTensor(TensorProxy tensorProxy, int agentIndex);
+ ///
+ /// The number of elements written
+ int Write(WriteAdapter adapter);
///
/// Return a compressed representation of the observation. For small observations, this should generally not be
@@ -37,6 +37,11 @@ public interface ISensor {
///
byte[] GetCompressedObservation();
+ ///
+ /// Update any internal state of the sensor. This is called once per each agent step.
+ ///
+ void Update();
+
///
/// Return the compression type being used. If no compression is used, return SensorCompressionType.None
///
@@ -51,4 +56,24 @@ public interface ISensor {
string GetName();
}
+ public static class SensorExtensions
+ {
+ ///
+ /// Get the total number of elements in the ISensor's observation (i.e. the product of the shape elements).
+ ///
+ ///
+ ///
+ public static int ObservationSize(this ISensor sensor)
+ {
+ var shape = sensor.GetFloatObservationShape();
+ int count = 1;
+ for (var i = 0; i < shape.Length; i++)
+ {
+ count *= shape[i];
+ }
+
+ return count;
+ }
+ }
+
}
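
For orientation, here is a minimal sketch of an implementation of the updated interface (the `ConstantSensor` class and its values are illustrative, not part of this change); it also shows how the `ObservationSize` extension composes with `GetFloatObservationShape`:

```csharp
using MLAgents.Sensor;

// Illustrative sensor that always reports the same 3-element observation.
public class ConstantSensor : ISensor
{
    static readonly float[] k_Observation = { 0.1f, 0.2f, 0.3f };

    public int[] GetFloatObservationShape()
    {
        return new[] { k_Observation.Length };
    }

    public int Write(WriteAdapter adapter)
    {
        // May be called multiple times per agent step, so no internal state is mutated here.
        adapter.AddRange(k_Observation);
        return k_Observation.Length;
    }

    public void Update() { }  // nothing to advance per step
    public byte[] GetCompressedObservation() { return null; }
    public SensorCompressionType GetCompressionType() { return SensorCompressionType.None; }
    public string GetName() { return "ConstantSensor"; }
}

// sensor.ObservationSize() == 3 here, i.e. the product of the shape elements.
```
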
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/Sensor/CompressedObservation.cs b/UnitySDK/Assets/ML-Agents/Scripts/Sensor/Observation.cs
similarity index 55%
rename from UnitySDK/Assets/ML-Agents/Scripts/Sensor/CompressedObservation.cs
rename to UnitySDK/Assets/ML-Agents/Scripts/Sensor/Observation.cs
index 5dfe6f85bf..3a5f88e120 100644
--- a/UnitySDK/Assets/ML-Agents/Scripts/Sensor/CompressedObservation.cs
+++ b/UnitySDK/Assets/ML-Agents/Scripts/Sensor/Observation.cs
@@ -3,12 +3,17 @@
namespace MLAgents.Sensor
{
- public struct CompressedObservation
+ public struct Observation
{
///
- /// The compressed data.
+ /// The compressed sensor data. Assumed to be non-null if CompressionType != CompressionType.None
///
- public byte[] Data;
+ public byte[] CompressedData;
+
+ ///
+ /// Uncompressed sensor data. Assumed to be non-empty if CompressionType == CompressionType.None
+ ///
+ public ArraySegment<float> FloatData;
///
/// The format of the compressed data
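
A rough sketch of how the renamed struct would be populated for an uncompressed sensor, assuming the trailing `CompressionType` field keeps the `SensorCompressionType` type it had in `CompressedObservation` (the helper below is illustrative only):

```csharp
using System;
using MLAgents.Sensor;

static class ObservationExample
{
    // Sketch: wrap raw float data in an Observation with no compression applied.
    static Observation MakeVectorObservation(float[] floats)
    {
        return new Observation
        {
            FloatData = new ArraySegment<float>(floats),  // payload when CompressionType == None
            CompressedData = null,                        // only used by compressed sensors (e.g. PNG)
            CompressionType = SensorCompressionType.None
        };
    }
}
```
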
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/Sensor/CompressedObservation.cs.meta b/UnitySDK/Assets/ML-Agents/Scripts/Sensor/Observation.cs.meta
similarity index 100%
rename from UnitySDK/Assets/ML-Agents/Scripts/Sensor/CompressedObservation.cs.meta
rename to UnitySDK/Assets/ML-Agents/Scripts/Sensor/Observation.cs.meta
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/Sensor/RayPerceptionSensor.cs b/UnitySDK/Assets/ML-Agents/Scripts/Sensor/RayPerceptionSensor.cs
new file mode 100644
index 0000000000..ee41a47a58
--- /dev/null
+++ b/UnitySDK/Assets/ML-Agents/Scripts/Sensor/RayPerceptionSensor.cs
@@ -0,0 +1,318 @@
+using System;
+using System.Collections.Generic;
+using UnityEngine;
+
+namespace MLAgents.Sensor
+{
+ public class RayPerceptionSensor : ISensor
+ {
+ public enum CastType
+ {
+ Cast2D,
+ Cast3D,
+ }
+
+ float[] m_Observations;
+ int[] m_Shape;
+ string m_Name;
+
+ float m_RayDistance;
+ List<string> m_DetectableObjects;
+ float[] m_Angles;
+
+ float m_StartOffset;
+ float m_EndOffset;
+ float m_CastRadius;
+ CastType m_CastType;
+ Transform m_Transform;
+
+ ///
+ /// Debug information for the raycast hits. This is used by the RayPerceptionSensorComponent.
+ ///
+ public class DebugDisplayInfo
+ {
+ public struct RayInfo
+ {
+ public Vector3 localStart;
+ public Vector3 localEnd;
+ public Vector3 worldStart;
+ public Vector3 worldEnd;
+ public bool castHit;
+ public float hitFraction;
+ }
+
+ public void Reset()
+ {
+ m_Frame = Time.frameCount;
+ }
+
+ ///
+ /// "Age" of the results in number of frames. This is used to adjust the alpha when drawing.
+ ///
+ public int age
+ {
+ get { return Time.frameCount - m_Frame; }
+ }
+
+ public RayInfo[] rayInfos;
+
+ int m_Frame;
+ }
+
+ DebugDisplayInfo m_DebugDisplayInfo;
+
+ public DebugDisplayInfo debugDisplayInfo
+ {
+ get { return m_DebugDisplayInfo; }
+ }
+
+ public RayPerceptionSensor(string name, float rayDistance, List<string> detectableObjects, float[] angles,
+ Transform transform, float startOffset, float endOffset, float castRadius, CastType castType)
+ {
+ var numObservations = (detectableObjects.Count + 2) * angles.Length;
+ m_Shape = new[] { numObservations };
+ m_Name = name;
+
+ m_Observations = new float[numObservations];
+
+ m_RayDistance = rayDistance;
+ m_DetectableObjects = detectableObjects;
+ // TODO - preprocess angles, save ray directions instead?
+ m_Angles = angles;
+ m_Transform = transform;
+ m_StartOffset = startOffset;
+ m_EndOffset = endOffset;
+ m_CastRadius = castRadius;
+ m_CastType = castType;
+
+ if (Application.isEditor)
+ {
+ m_DebugDisplayInfo = new DebugDisplayInfo();
+ }
+ }
+
+ public int Write(WriteAdapter adapter)
+ {
+ using (TimerStack.Instance.Scoped("RayPerceptionSensor.Perceive"))
+ {
+ PerceiveStatic(
+ m_RayDistance, m_Angles, m_DetectableObjects, m_StartOffset, m_EndOffset,
+ m_CastRadius, m_Transform, m_CastType, m_Observations, false, m_DebugDisplayInfo
+ );
+ adapter.AddRange(m_Observations);
+ }
+ return m_Observations.Length;
+ }
+
+ public void Update()
+ {
+ }
+
+ public int[] GetFloatObservationShape()
+ {
+ return m_Shape;
+ }
+
+ public string GetName()
+ {
+ return m_Name;
+ }
+
+ public virtual byte[] GetCompressedObservation()
+ {
+ return null;
+ }
+
+ public virtual SensorCompressionType GetCompressionType()
+ {
+ return SensorCompressionType.None;
+ }
+
+ ///
+ /// Evaluates a perception vector to be used as part of an observation of an agent.
+ /// Each element in the rayAngles array determines a sublist of data to the observation.
+ /// The sublist contains the observation data for a single cast. The list is composed of the following:
+ /// 1. A one-hot encoding for detectable objects. For example, if detectableObjects.Length = n, the
+ /// first n elements of the sublist will be a one-hot encoding of the detectableObject that was hit, or
+ /// all zeroes otherwise.
+ /// 2. The 'length' element of the sublist will be 1 if the ray missed everything, or 0 if it hit
+ /// something (detectable or not).
+ /// 3. The 'length+1' element of the sublist will contain the normalised distance to the object hit, or 1 if
+ /// nothing was hit.
+ ///
+ /// The legacyHitFractionBehavior changes the behavior to be backwards compatible but has some
+ /// counter-intuitive behavior:
+ /// * if the cast hits an object that's not in the detectableObjects list, all results are 0
+ /// * if the cast doesn't hit, the hit fraction field is 0
+ ///
+ ///
+ /// List of angles (in degrees) used to define the rays. 90 degrees is considered
+ /// "forward" relative to the game object
+ /// List of tags which correspond to object types agent can see
+ /// Starting height offset of ray from center of agent.
+ /// Ending height offset of ray from center of agent.
+ /// Radius of the sphere to use for spherecasting. If 0 or less, rays are used
+ /// instead - this may be faster, especially for complex environments.
+ /// Transform of the GameObject
+ /// Whether to perform the casts in 2D or 3D.
+ /// Output array of floats. Must be (num rays) * (num tags + 2) in size.
+ /// Whether to use the legacy behavior for hit fractions.
+ /// Optional debug information output, only used by RayPerceptionSensor.
+ ///
+ public static void PerceiveStatic(float rayLength,
+ IReadOnlyList<float> rayAngles, IReadOnlyList<string> detectableObjects,
+ float startOffset, float endOffset, float castRadius,
+ Transform transform, CastType castType, float[] perceptionBuffer,
+ bool legacyHitFractionBehavior = false,
+ DebugDisplayInfo debugInfo = null)
+ {
+ Array.Clear(perceptionBuffer, 0, perceptionBuffer.Length);
+ if (debugInfo != null)
+ {
+ debugInfo.Reset();
+ if (debugInfo.rayInfos == null || debugInfo.rayInfos.Length != rayAngles.Count)
+ {
+ debugInfo.rayInfos = new DebugDisplayInfo.RayInfo[rayAngles.Count];
+ }
+ }
+
+ // For each ray sublist stores categorical information on detected object
+ // along with object distance.
+ int bufferOffset = 0;
+ for (var rayIndex = 0; rayIndex < rayAngles.Count; rayIndex++)
+ {
+ var angle = rayAngles[rayIndex];
+ Vector3 startPositionLocal, endPositionLocal;
+
+ if (castType == CastType.Cast3D)
+ {
+ startPositionLocal = new Vector3(0, startOffset, 0);
+ endPositionLocal = PolarToCartesian3D(rayLength, angle);
+ endPositionLocal.y += endOffset;
+ }
+ else
+ {
+ // Vector2s here get converted to Vector3s (and back to Vector2s for casting)
+ startPositionLocal = new Vector2();
+ endPositionLocal = PolarToCartesian2D(rayLength, angle);
+ }
+
+ var startPositionWorld = transform.TransformPoint(startPositionLocal);
+ var endPositionWorld = transform.TransformPoint(endPositionLocal);
+ var rayDirection = endPositionWorld - startPositionWorld;
+
+ bool castHit;
+ float hitFraction;
+ GameObject hitObject;
+
+ if (castType == CastType.Cast3D)
+ {
+ RaycastHit rayHit;
+ if (castRadius > 0f)
+ {
+ castHit = Physics.SphereCast(startPositionWorld, castRadius, rayDirection, out rayHit, rayLength);
+ }
+ else
+ {
+ castHit = Physics.Raycast(startPositionWorld, rayDirection, out rayHit, rayLength);
+ }
+
+ hitFraction = castHit ? rayHit.distance / rayLength : 1.0f;
+ hitObject = castHit ? rayHit.collider.gameObject : null;
+ }
+ else
+ {
+ RaycastHit2D rayHit;
+ if (castRadius > 0f)
+ {
+ rayHit = Physics2D.CircleCast(startPositionWorld, castRadius, rayDirection, rayLength);
+ }
+ else
+ {
+ rayHit = Physics2D.Raycast(startPositionWorld, rayDirection, rayLength);
+ }
+
+ castHit = rayHit;
+ hitFraction = castHit ? rayHit.fraction : 1.0f;
+ hitObject = castHit ? rayHit.collider.gameObject : null;
+ }
+
+ if (debugInfo != null)
+ {
+ debugInfo.rayInfos[rayIndex].localStart = startPositionLocal;
+ debugInfo.rayInfos[rayIndex].localEnd = endPositionLocal;
+ debugInfo.rayInfos[rayIndex].worldStart = startPositionWorld;
+ debugInfo.rayInfos[rayIndex].worldEnd = endPositionWorld;
+ debugInfo.rayInfos[rayIndex].castHit = castHit;
+ debugInfo.rayInfos[rayIndex].hitFraction = hitFraction;
+ }
+ else if (Application.isEditor)
+ {
+ // Legacy drawing
+ Debug.DrawRay(startPositionWorld,rayDirection, Color.black, 0.01f, true);
+ }
+
+ if (castHit)
+ {
+ for (var i = 0; i < detectableObjects.Count; i++)
+ {
+ if (hitObject.CompareTag(detectableObjects[i]))
+ {
+ perceptionBuffer[bufferOffset + i] = 1;
+ perceptionBuffer[bufferOffset + detectableObjects.Count + 1] = hitFraction;
+ break;
+ }
+
+ if (!legacyHitFractionBehavior)
+ {
+ // Something was hit but not on the list. Still set the hit fraction.
+ perceptionBuffer[bufferOffset + detectableObjects.Count + 1] = hitFraction;
+ }
+ }
+ }
+ else
+ {
+ perceptionBuffer[bufferOffset + detectableObjects.Count] = 1f;
+ if (!legacyHitFractionBehavior)
+ {
+ // Nothing was hit, so there's full clearance in front of the agent.
+ perceptionBuffer[bufferOffset + detectableObjects.Count + 1] = 1.0f;
+ }
+ }
+
+ bufferOffset += detectableObjects.Count + 2;
+ }
+ }
+
+ ///
+ /// Converts polar coordinate to cartesian coordinate.
+ ///
+ static Vector3 PolarToCartesian3D(float radius, float angleDegrees)
+ {
+ var x = radius * Mathf.Cos(Mathf.Deg2Rad * angleDegrees);
+ var z = radius * Mathf.Sin(Mathf.Deg2Rad * angleDegrees);
+ return new Vector3(x, 0f, z);
+ }
+
+ ///
+ /// Converts polar coordinate to cartesian coordinate.
+ ///
+ static Vector2 PolarToCartesian2D(float radius, float angleDegrees)
+ {
+ var x = radius * Mathf.Cos(Mathf.Deg2Rad * angleDegrees);
+ var y = radius * Mathf.Sin(Mathf.Deg2Rad * angleDegrees);
+ return new Vector2(x, y);
+ }
+ }
+}
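
To make the buffer layout described in the `PerceiveStatic` summary concrete, here is a hedged usage sketch with two hypothetical tags ("wall" and "goal"); each ray then contributes 2 + 2 = 4 floats, in the order one-hot tag hits, miss flag, hit fraction:

```csharp
using System.Collections.Generic;
using UnityEngine;
using MLAgents.Sensor;

static class RayExample
{
    // Per ray (non-legacy behavior): [ hitWall, hitGoal, missedEverything, hitFraction ]
    //   hit a "goal" collider at 40% of rayLength -> [ 0, 1, 0, 0.4 ]
    //   hit nothing at all                        -> [ 0, 0, 1, 1.0 ]
    static float[] Perceive(Transform agentTransform)
    {
        var angles = new[] { 20f, 90f, 160f };            // three rays; 90 is straight ahead
        var tags = new List<string> { "wall", "goal" };   // hypothetical detectable tags
        var buffer = new float[angles.Length * (tags.Count + 2)];

        RayPerceptionSensor.PerceiveStatic(
            20f, angles, tags, 0f, 0f, 0.5f,
            agentTransform, RayPerceptionSensor.CastType.Cast3D, buffer);

        return buffer;
    }
}
```
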
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/Sensor/RayPerceptionSensor.cs.meta b/UnitySDK/Assets/ML-Agents/Scripts/Sensor/RayPerceptionSensor.cs.meta
new file mode 100644
index 0000000000..4c7247977c
--- /dev/null
+++ b/UnitySDK/Assets/ML-Agents/Scripts/Sensor/RayPerceptionSensor.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: 71417cdf8dd542e19ec22822b001b884
+timeCreated: 1573089052
\ No newline at end of file
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/Sensor/RayPerceptionSensorComponent2D.cs b/UnitySDK/Assets/ML-Agents/Scripts/Sensor/RayPerceptionSensorComponent2D.cs
new file mode 100644
index 0000000000..abb750128e
--- /dev/null
+++ b/UnitySDK/Assets/ML-Agents/Scripts/Sensor/RayPerceptionSensorComponent2D.cs
@@ -0,0 +1,10 @@
+namespace MLAgents.Sensor
+{
+ public class RayPerceptionSensorComponent2D : RayPerceptionSensorComponentBase
+ {
+ public override RayPerceptionSensor.CastType GetCastType()
+ {
+ return RayPerceptionSensor.CastType.Cast2D;
+ }
+ }
+}
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/Sensor/RayPerceptionSensorComponent2D.cs.meta b/UnitySDK/Assets/ML-Agents/Scripts/Sensor/RayPerceptionSensorComponent2D.cs.meta
new file mode 100644
index 0000000000..947a0904d3
--- /dev/null
+++ b/UnitySDK/Assets/ML-Agents/Scripts/Sensor/RayPerceptionSensorComponent2D.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: f67c7e722ba14acd9153bb4488bff6e4
+timeCreated: 1573769662
\ No newline at end of file
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/Sensor/RayPerceptionSensorComponent3D.cs b/UnitySDK/Assets/ML-Agents/Scripts/Sensor/RayPerceptionSensorComponent3D.cs
new file mode 100644
index 0000000000..6d7718971d
--- /dev/null
+++ b/UnitySDK/Assets/ML-Agents/Scripts/Sensor/RayPerceptionSensorComponent3D.cs
@@ -0,0 +1,32 @@
+using System;
+using UnityEngine;
+
+namespace MLAgents.Sensor
+{
+ public class RayPerceptionSensorComponent3D : RayPerceptionSensorComponentBase
+ {
+ [Header("3D Properties", order = 100)]
+ [Range(-10f, 10f)]
+ [Tooltip("Ray start is offset up or down by this amount.")]
+ public float startVerticalOffset;
+
+ [Range(-10f, 10f)]
+ [Tooltip("Ray end is offset up or down by this amount.")]
+ public float endVerticalOffset;
+
+ public override RayPerceptionSensor.CastType GetCastType()
+ {
+ return RayPerceptionSensor.CastType.Cast3D;
+ }
+
+ public override float GetStartVerticalOffset()
+ {
+ return startVerticalOffset;
+ }
+
+ public override float GetEndVerticalOffset()
+ {
+ return endVerticalOffset;
+ }
+ }
+}
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/Sensor/RayPerceptionSensorComponent3D.cs.meta b/UnitySDK/Assets/ML-Agents/Scripts/Sensor/RayPerceptionSensorComponent3D.cs.meta
new file mode 100644
index 0000000000..51ec4e5b16
--- /dev/null
+++ b/UnitySDK/Assets/ML-Agents/Scripts/Sensor/RayPerceptionSensorComponent3D.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: 6bb6b867a41448888c1cd4f99643ad71
+timeCreated: 1573764567
\ No newline at end of file
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/Sensor/RayPerceptionSensorComponentBase.cs b/UnitySDK/Assets/ML-Agents/Scripts/Sensor/RayPerceptionSensorComponentBase.cs
new file mode 100644
index 0000000000..7ae5392d0a
--- /dev/null
+++ b/UnitySDK/Assets/ML-Agents/Scripts/Sensor/RayPerceptionSensorComponentBase.cs
@@ -0,0 +1,143 @@
+using System;
+using System.Collections.Generic;
+using UnityEngine;
+
+namespace MLAgents.Sensor
+{
+ public abstract class RayPerceptionSensorComponentBase : SensorComponent
+ {
+ public string sensorName = "RayPerceptionSensor";
+
+ [Tooltip("List of tags in the scene to compare against.")]
+ public List<string> detectableTags;
+
+ [Range(0, 50)]
+ [Tooltip("Number of rays to the left and right of center.")]
+ public int raysPerDirection = 3;
+
+ [Range(0, 180)]
+ [Tooltip("Cone size for rays. Using 90 degrees will cast rays to the left and right. Greater than 90 degrees will go backwards.")]
+ public float maxRayDegrees = 70;
+
+ [Range(0f, 10f)]
+ [Tooltip("Radius of sphere to cast. Set to zero for raycasts.")]
+ public float sphereCastRadius = 0.5f;
+
+ [Range(1, 1000)]
+ [Tooltip("Length of the rays to cast.")]
+ public float rayLength = 20f;
+
+ [Range(1, 50)]
+ [Tooltip("Whether to stack previous observations. Using 1 means no previous observations.")]
+ public int observationStacks = 1;
+
+ [Header("Debug Gizmos", order = 999)]
+ public Color rayHitColor = Color.red;
+ public Color rayMissColor = Color.white;
+ [Tooltip("Whether to draw the raycasts in the world space of when they happened, or using the Agent's current transform'")]
+ public bool useWorldPositions = true;
+
+
+ [NonSerialized]
+ RayPerceptionSensor m_RaySensor;
+
+ public abstract RayPerceptionSensor.CastType GetCastType();
+
+ public virtual float GetStartVerticalOffset()
+ {
+ return 0f;
+ }
+
+ public virtual float GetEndVerticalOffset()
+ {
+ return 0f;
+ }
+
+ public override ISensor CreateSensor()
+ {
+ var rayAngles = GetRayAngles(raysPerDirection, maxRayDegrees);
+ m_RaySensor = new RayPerceptionSensor(sensorName, rayLength, detectableTags, rayAngles,
+ transform, GetStartVerticalOffset(), GetEndVerticalOffset(), sphereCastRadius, GetCastType()
+ );
+
+ if (observationStacks != 1)
+ {
+ var stackingSensor = new StackingSensor(m_RaySensor, observationStacks);
+ return stackingSensor;
+ }
+
+ return m_RaySensor;
+ }
+
+ public static float[] GetRayAngles(int raysPerDirection, float maxRayDegrees)
+ {
+ // Example:
+ // { 90, 90 - delta, 90 + delta, 90 - 2*delta, 90 + 2*delta }
+ var anglesOut = new float[2 * raysPerDirection + 1];
+ var delta = maxRayDegrees / raysPerDirection;
+ anglesOut[0] = 90f;
+ for (var i = 0; i < raysPerDirection; i++)
+ {
+ anglesOut[2 * i + 1] = 90 - (i+1) * delta;
+ anglesOut[2 * i + 2] = 90 + (i+1) * delta;
+ }
+ return anglesOut;
+ }
+
+ public override int[] GetObservationShape()
+ {
+ var numRays = 2 * raysPerDirection + 1;
+ var numTags = detectableTags == null ? 0 : detectableTags.Count;
+ var obsSize = (numTags + 2) * numRays;
+ var stacks = observationStacks > 1 ? observationStacks : 1;
+ return new[] { obsSize * stacks };
+ }
+
+ ///
+ /// Draw the debug information from the sensor (if available).
+ ///
+ public void OnDrawGizmos()
+ {
+ if (m_RaySensor?.debugDisplayInfo?.rayInfos == null)
+ {
+ return;
+ }
+ var debugInfo = m_RaySensor.debugDisplayInfo;
+
+ // Draw "old" observations in a lighter color.
+ // Since the agent may not step every frame, this helps de-emphasize "stale" hit information.
+ var alpha = Mathf.Pow(.5f, debugInfo.age);
+
+ foreach (var rayInfo in debugInfo.rayInfos)
+ {
+ // Either use the original world-space coordinates of the raycast, or transform the agent-local
+ // coordinates of the rays to the current transform of the agent. If the agent acts every frame,
+ // these should be the same.
+ var startPositionWorld = rayInfo.worldStart;
+ var endPositionWorld = rayInfo.worldEnd;
+ if (!useWorldPositions)
+ {
+ startPositionWorld = transform.TransformPoint(rayInfo.localStart);
+ endPositionWorld = transform.TransformPoint(rayInfo.localEnd);
+ }
+ var rayDirection = endPositionWorld - startPositionWorld;
+ rayDirection *= rayInfo.hitFraction;
+
+ // hit fraction ^2 will shift "far" hits closer to the hit color
+ var lerpT = rayInfo.hitFraction * rayInfo.hitFraction;
+ var color = Color.Lerp(rayHitColor, rayMissColor, lerpT);
+ color.a = alpha;
+ Gizmos.color = color;
+ Gizmos.DrawRay(startPositionWorld,rayDirection);
+
+ // Draw the hit point as a sphere. If using rays to cast (0 radius), use a small sphere.
+ if (rayInfo.castHit)
+ {
+ var hitRadius = Mathf.Max(sphereCastRadius, .05f);
+ Gizmos.DrawWireSphere(startPositionWorld + rayDirection, hitRadius);
+ }
+ }
+
+ }
+ }
+}
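
As a quick sanity check of the interleaved ordering produced by `GetRayAngles`, the component defaults (`raysPerDirection = 3`, `maxRayDegrees = 70`) fan out from 90 degrees (straight ahead) in alternating left/right pairs; the snippet below is only an illustration:

```csharp
using UnityEngine;
using MLAgents.Sensor;

static class RayAngleExample
{
    static void CheckDefaults()
    {
        // delta = 70 / 3, so the angles come out as
        // { 90, 66.67, 113.33, 43.33, 136.67, 20, 160 } (center first, then left/right pairs).
        var angles = RayPerceptionSensorComponentBase.GetRayAngles(3, 70f);
        Debug.Log(string.Join(", ", angles));
    }
}
```
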
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/Sensor/RayPerceptionSensorComponentBase.cs.meta b/UnitySDK/Assets/ML-Agents/Scripts/Sensor/RayPerceptionSensorComponentBase.cs.meta
new file mode 100644
index 0000000000..97f40e582f
--- /dev/null
+++ b/UnitySDK/Assets/ML-Agents/Scripts/Sensor/RayPerceptionSensorComponentBase.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: 45243967d8c0419b953c02bccb7c2768
+timeCreated: 1573087062
\ No newline at end of file
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/Sensor/RenderTextureSensor.cs b/UnitySDK/Assets/ML-Agents/Scripts/Sensor/RenderTextureSensor.cs
index 871854c809..f60fab4658 100644
--- a/UnitySDK/Assets/ML-Agents/Scripts/Sensor/RenderTextureSensor.cs
+++ b/UnitySDK/Assets/ML-Agents/Scripts/Sensor/RenderTextureSensor.cs
@@ -1,27 +1,23 @@
using System;
-using System.Threading;
-using MLAgents.InferenceBrain;
using UnityEngine;
namespace MLAgents.Sensor
{
- class RenderTextureSensor : ISensor
+ public class RenderTextureSensor : ISensor
{
- private RenderTexture m_RenderTexture;
- private int m_Width;
- private int m_Height;
- private bool m_Grayscale;
- private string m_Name;
- private int[] m_Shape;
+ RenderTexture m_RenderTexture;
+ bool m_Grayscale;
+ string m_Name;
+ int[] m_Shape;
- public RenderTextureSensor(RenderTexture renderTexture, int width, int height, bool grayscale, string name)
+ public RenderTextureSensor(RenderTexture renderTexture, bool grayscale, string name)
{
m_RenderTexture = renderTexture;
- m_Width = width;
- m_Height = height;
+ var width = renderTexture != null ? renderTexture.width : 0;
+ var height = renderTexture != null ? renderTexture.height : 0;
m_Grayscale = grayscale;
m_Name = name;
- m_Shape = new[] { width, height, grayscale ? 1 : 3 };
+ m_Shape = new[] { height, width, grayscale ? 1 : 3 };
}
public string GetName()
@@ -38,7 +34,7 @@ public byte[] GetCompressedObservation()
{
using(TimerStack.Instance.Scoped("RenderTexSensor.GetCompressedObservation"))
{
- var texture = ObservationToTexture(m_RenderTexture, m_Width, m_Height);
+ var texture = ObservationToTexture(m_RenderTexture);
// TODO support more types here, e.g. JPG
var compressed = texture.EncodeToPNG();
UnityEngine.Object.Destroy(texture);
@@ -46,16 +42,19 @@ public byte[] GetCompressedObservation()
}
}
- public void WriteToTensor(TensorProxy tensorProxy, int index)
+ public int Write(WriteAdapter adapter)
{
using (TimerStack.Instance.Scoped("RenderTexSensor.GetCompressedObservation"))
{
- var texture = ObservationToTexture(m_RenderTexture, m_Width, m_Height);
- Utilities.TextureToTensorProxy(texture, tensorProxy, m_Grayscale, index);
+ var texture = ObservationToTexture(m_RenderTexture);
+ var numWritten = Utilities.TextureToTensorProxy(texture, adapter, m_Grayscale);
UnityEngine.Object.Destroy(texture);
+ return numWritten;
}
}
+ public void Update() { }
+
public SensorCompressionType GetCompressionType()
{
return SensorCompressionType.PNG;
@@ -66,25 +65,13 @@ public SensorCompressionType GetCompressionType()
///
/// The 2D texture.
/// RenderTexture.
- /// Width of resulting 2D texture.
- /// Height of resulting 2D texture.
/// Texture2D to render to.
- public static Texture2D ObservationToTexture(RenderTexture obsTexture, int width, int height)
+ public static Texture2D ObservationToTexture(RenderTexture obsTexture)
{
+ var height = obsTexture.height;
+ var width = obsTexture.width;
var texture2D = new Texture2D(width, height, TextureFormat.RGB24, false);
- if (width != texture2D.width || height != texture2D.height)
- {
- texture2D.Resize(width, height);
- }
-
- if (width != obsTexture.width || height != obsTexture.height)
- {
- throw new UnityAgentsException(string.Format(
- "RenderTexture {0} : width/height is {1}/{2} brain is expecting {3}/{4}.",
- obsTexture.name, obsTexture.width, obsTexture.height, width, height));
- }
-
var prevActiveRt = RenderTexture.active;
RenderTexture.active = obsTexture;
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/Sensor/RenderTextureSensorComponent.cs b/UnitySDK/Assets/ML-Agents/Scripts/Sensor/RenderTextureSensorComponent.cs
index a19a532052..b192e33d92 100644
--- a/UnitySDK/Assets/ML-Agents/Scripts/Sensor/RenderTextureSensorComponent.cs
+++ b/UnitySDK/Assets/ML-Agents/Scripts/Sensor/RenderTextureSensorComponent.cs
@@ -7,13 +7,19 @@ public class RenderTextureSensorComponent : SensorComponent
{
public RenderTexture renderTexture;
public string sensorName = "RenderTextureSensor";
- public int width = 84;
- public int height = 84;
- public bool grayscale = false;
+ public bool grayscale;
public override ISensor CreateSensor()
{
- return new RenderTextureSensor(renderTexture, width, height, grayscale, sensorName);
+ return new RenderTextureSensor(renderTexture, grayscale, sensorName);
+ }
+
+ public override int[] GetObservationShape()
+ {
+ var width = renderTexture != null ? renderTexture.width : 0;
+ var height = renderTexture != null ? renderTexture.height : 0;
+
+ return new[] { height, width, grayscale ? 1 : 3 };
}
}
}
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/Sensor/SensorBase.cs b/UnitySDK/Assets/ML-Agents/Scripts/Sensor/SensorBase.cs
index 15eb53bee6..c61b9e8a21 100644
--- a/UnitySDK/Assets/ML-Agents/Scripts/Sensor/SensorBase.cs
+++ b/UnitySDK/Assets/ML-Agents/Scripts/Sensor/SensorBase.cs
@@ -1,4 +1,3 @@
-using MLAgents.InferenceBrain;
using UnityEngine;
namespace MLAgents.Sensor
@@ -17,30 +16,24 @@ public abstract class SensorBase : ISensor
public abstract string GetName();
///
- /// Default implementation of WriteToTensor interface. This creates a temporary array, calls WriteObservation,
- /// and then writes the results to the TensorProxy.
+ /// Default implementation of Write interface. This creates a temporary array, calls WriteObservation,
+ /// and then writes the results to the WriteAdapter.
///
- ///
- ///
- public virtual void WriteToTensor(TensorProxy tensorProxy, int agentIndex)
+ ///
+ public virtual int Write(WriteAdapter adapter)
{
// TODO reuse buffer for similar agents, don't call GetFloatObservationShape()
- int[] shape = GetFloatObservationShape();
- int numFloats = 1;
- foreach (var dim in shape)
- {
- numFloats *= dim;
- }
-
+ var numFloats = this.ObservationSize();
float[] buffer = new float[numFloats];
WriteObservation(buffer);
- for (var i = 0; i < numFloats; i++)
- {
- tensorProxy.data[agentIndex, i] = buffer[i];
- }
+ adapter.AddRange(buffer);
+
+ return numFloats;
}
+ public void Update() { }
+
public virtual byte[] GetCompressedObservation()
{
return null;
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/Sensor/SensorComponent.cs b/UnitySDK/Assets/ML-Agents/Scripts/Sensor/SensorComponent.cs
index 177846a7b5..63c1328377 100644
--- a/UnitySDK/Assets/ML-Agents/Scripts/Sensor/SensorComponent.cs
+++ b/UnitySDK/Assets/ML-Agents/Scripts/Sensor/SensorComponent.cs
@@ -14,5 +14,23 @@ public abstract class SensorComponent : MonoBehaviour
///
///
public abstract ISensor CreateSensor();
+
+ ///
+ /// Returns the shape of the sensor observations that will be created.
+ ///
+ ///
+ public abstract int[] GetObservationShape();
+
+ public virtual bool IsVisual()
+ {
+ var shape = GetObservationShape();
+ return shape.Length == 3;
+ }
+
+ public virtual bool IsVector()
+ {
+ var shape = GetObservationShape();
+ return shape.Length == 1;
+ }
}
}
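
Here is a sketch of a user-defined component built on this base class; `HealthSensorComponent` and its field are hypothetical, and the point is that the shape reported by `GetObservationShape` has to match what the created sensor will actually write:

```csharp
using MLAgents.Sensor;

public class HealthSensorComponent : SensorComponent
{
    public int valuesPerObservation = 2;   // e.g. health and stamina

    public override ISensor CreateSensor()
    {
        return new VectorSensor(valuesPerObservation, "HealthSensor");
    }

    public override int[] GetObservationShape()
    {
        return new[] { valuesPerObservation };   // rank 1, so IsVector() is true and IsVisual() is false
    }
}
```
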
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/Sensor/StackingSensor.cs b/UnitySDK/Assets/ML-Agents/Scripts/Sensor/StackingSensor.cs
new file mode 100644
index 0000000000..9e5e001971
--- /dev/null
+++ b/UnitySDK/Assets/ML-Agents/Scripts/Sensor/StackingSensor.cs
@@ -0,0 +1,115 @@
+namespace MLAgents.Sensor
+{
+ ///
+ /// Sensor that wraps around another Sensor to provide temporal stacking.
+ /// Conceptually, consecutive observations are stored left-to-right, which is how they're output
+ /// For example, 4 stacked sets of observations would be output like
+ /// | t = now - 3 | t = now - 2 | t = now - 1 | t = now |
+ /// Internally, a circular buffer of arrays is used. The m_CurrentIndex represents the most recent observation.
+ ///
+ public class StackingSensor : ISensor
+ {
+ ///
+ /// The wrapped sensor.
+ ///
+ ISensor m_WrappedSensor;
+
+ ///
+ /// Number of stacks to save
+ ///
+ int m_NumStackedObservations;
+ int m_UnstackedObservationSize;
+
+ string m_Name;
+ int[] m_Shape;
+
+ ///
+ /// Buffer of previous observations
+ ///
+ float[][] m_StackedObservations;
+
+ int m_CurrentIndex;
+ WriteAdapter m_LocalAdapter = new WriteAdapter();
+
+ ///
+ ///
+ ///
+ /// The wrapped sensor
+ /// Number of stacked observations to keep
+ public StackingSensor(ISensor wrapped, int numStackedObservations)
+ {
+ // TODO ensure numStackedObservations > 1
+ m_WrappedSensor = wrapped;
+ m_NumStackedObservations = numStackedObservations;
+
+ m_Name = $"StackingSensor_size{numStackedObservations}_{wrapped.GetName()}";
+
+ var shape = wrapped.GetFloatObservationShape();
+ m_Shape = new int[shape.Length];
+
+ m_UnstackedObservationSize = wrapped.ObservationSize();
+ for (int d = 0; d < shape.Length; d++)
+ {
+ m_Shape[d] = shape[d];
+ }
+
+ // TODO support arbitrary stacking dimension
+ m_Shape[0] *= numStackedObservations;
+ m_StackedObservations = new float[numStackedObservations][];
+ for (var i = 0; i < numStackedObservations; i++)
+ {
+ m_StackedObservations[i] = new float[m_UnstackedObservationSize];
+ }
+ }
+
+ public int Write(WriteAdapter adapter)
+ {
+ // First, call the wrapped sensor's write method. Make sure to use our own adapter, not the one passed in.
+ m_LocalAdapter.SetTarget(m_StackedObservations[m_CurrentIndex], 0);
+ m_WrappedSensor.Write(m_LocalAdapter);
+
+ // Now write the saved observations (oldest first)
+ var numWritten = 0;
+ for (var i = 0; i < m_NumStackedObservations; i++)
+ {
+ var obsIndex = (m_CurrentIndex + 1 + i) % m_NumStackedObservations;
+ adapter.AddRange(m_StackedObservations[obsIndex], numWritten);
+ numWritten += m_UnstackedObservationSize;
+ }
+
+ return numWritten;
+ }
+
+ ///
+ /// Updates the index of the "current" buffer.
+ ///
+ public void Update()
+ {
+ m_WrappedSensor.Update();
+ m_CurrentIndex = (m_CurrentIndex + 1) % m_NumStackedObservations;
+ }
+
+ public int[] GetFloatObservationShape()
+ {
+ return m_Shape;
+ }
+
+ public string GetName()
+ {
+ return m_Name;
+ }
+
+ public virtual byte[] GetCompressedObservation()
+ {
+ return null;
+ }
+
+ public virtual SensorCompressionType GetCompressionType()
+ {
+ return SensorCompressionType.None;
+ }
+
+ // TODO support stacked compressed observations (byte stream)
+
+ }
+}
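
A small hedged example of wrapping another sensor, mirroring what `RayPerceptionSensorComponentBase.CreateSensor` does when `observationStacks > 1`:

```csharp
using MLAgents.Sensor;

static class StackingExample
{
    static ISensor MakeStackedSensor()
    {
        // Stack the last 4 observations of a 2-element vector sensor.
        // The wrapped shape {2} becomes {8}, with the oldest observation written first.
        var wrapped = new VectorSensor(2, "Velocity");
        return new StackingSensor(wrapped, 4);
    }
}
```
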
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/Sensor/StackingSensor.cs.meta b/UnitySDK/Assets/ML-Agents/Scripts/Sensor/StackingSensor.cs.meta
new file mode 100644
index 0000000000..f0289542ff
--- /dev/null
+++ b/UnitySDK/Assets/ML-Agents/Scripts/Sensor/StackingSensor.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: 8b7a6e88d47d4438ad67e1862566462c
+timeCreated: 1572299581
\ No newline at end of file
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/Sensor/VectorSensor.cs b/UnitySDK/Assets/ML-Agents/Scripts/Sensor/VectorSensor.cs
new file mode 100644
index 0000000000..067c556029
--- /dev/null
+++ b/UnitySDK/Assets/ML-Agents/Scripts/Sensor/VectorSensor.cs
@@ -0,0 +1,172 @@
+using System.Collections.Generic;
+using UnityEngine;
+
+namespace MLAgents.Sensor
+{
+ public class VectorSensor : ISensor
+ {
+ // TODO use float[] instead
+ // TODO allow setting float[]
+ List<float> m_Observations;
+ int[] m_Shape;
+ string m_Name;
+
+ public VectorSensor(int observationSize, string name = null)
+ {
+ if (name == null)
+ {
+ name = $"VectorSensor_size{observationSize}";
+ }
+
+ m_Observations = new List<float>(observationSize);
+ m_Name = name;
+ m_Shape = new[] { observationSize };
+ }
+
+ public int Write(WriteAdapter adapter)
+ {
+ var expectedObservations = m_Shape[0];
+ if (m_Observations.Count > expectedObservations)
+ {
+ // Too many observations, truncate
+ Debug.LogWarningFormat(
+ "More observations ({0}) made than vector observation size ({1}). The observations will be truncated.",
+ m_Observations.Count, expectedObservations
+ );
+ m_Observations.RemoveRange(expectedObservations, m_Observations.Count - expectedObservations);
+ }
+ else if (m_Observations.Count < expectedObservations)
+ {
+ // Not enough observations; pad with zeros.
+ Debug.LogWarningFormat(
+ "Fewer observations ({0}) made than vector observation size ({1}). The observations will be padded.",
+ m_Observations.Count, expectedObservations
+ );
+ for (int i = m_Observations.Count; i < expectedObservations; i++)
+ {
+ m_Observations.Add(0);
+ }
+ }
+ adapter.AddRange(m_Observations);
+ return expectedObservations;
+ }
+
+ public void Update()
+ {
+ Clear();
+ }
+
+ public int[] GetFloatObservationShape()
+ {
+ return m_Shape;
+ }
+
+ public string GetName()
+ {
+ return m_Name;
+ }
+
+ public virtual byte[] GetCompressedObservation()
+ {
+ return null;
+ }
+
+ public virtual SensorCompressionType GetCompressionType()
+ {
+ return SensorCompressionType.None;
+ }
+
+ void Clear()
+ {
+ m_Observations.Clear();
+ }
+
+ void AddFloatObs(float obs)
+ {
+ m_Observations.Add(obs);
+ }
+
+ // Compatibility methods with Agent observation. These should be removed eventually.
+
+ ///
+ /// Adds a float observation to the vector observations of the agent.
+ ///
+ /// Observation.
+ public void AddObservation(float observation)
+ {
+ AddFloatObs(observation);
+ }
+
+ ///
+ /// Adds an integer observation to the vector observations of the agent.
+ ///
+ /// Observation.
+ public void AddObservation(int observation)
+ {
+ AddFloatObs(observation);
+ }
+
+ ///
+ /// Adds a Vector3 observation to the vector observations of the agent.
+ ///
+ /// Observation.
+ public void AddObservation(Vector3 observation)
+ {
+ AddFloatObs(observation.x);
+ AddFloatObs(observation.y);
+ AddFloatObs(observation.z);
+ }
+
+ ///
+ /// Adds a Vector2 observation to the vector observations of the agent.
+ ///
+ /// Observation.
+ public void AddObservation(Vector2 observation)
+ {
+ AddFloatObs(observation.x);
+ AddFloatObs(observation.y);
+ }
+
+ ///
+ /// Adds a collection of float observations to the vector observations of the agent.
+ ///
+ /// Observation.
+ public void AddObservation(IEnumerable<float> observation)
+ {
+ foreach (var f in observation)
+ {
+ AddFloatObs(f);
+ }
+ }
+
+ ///
+ /// Adds a quaternion observation to the vector observations of the agent.
+ ///
+ /// Observation.
+ public void AddObservation(Quaternion observation)
+ {
+ AddFloatObs(observation.x);
+ AddFloatObs(observation.y);
+ AddFloatObs(observation.z);
+ AddFloatObs(observation.w);
+ }
+
+ ///
+ /// Adds a boolean observation to the vector observation of the agent.
+ ///
+ ///
+ public void AddObservation(bool observation)
+ {
+ AddFloatObs(observation ? 1f : 0f);
+ }
+
+
+ public void AddOneHotObservation(int observation, int range)
+ {
+ for (var i = 0; i < range; i++)
+ {
+ AddFloatObs(i == observation ? 1.0f : 0.0f);
+ }
+ }
+ }
+}
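
A short usage sketch of the compatibility methods (the sensor name, size, and observations below are illustrative): the declared size must equal the number of floats added per step, otherwise `Write` pads or truncates with a warning:

```csharp
using UnityEngine;
using MLAgents.Sensor;

static class VectorSensorExample
{
    static VectorSensor MakeSensor(Transform target)
    {
        // 3 (Vector3) + 1 (bool) + 4 (one-hot of range 4) = 8 floats, so declare size 8.
        var sensor = new VectorSensor(8, "TargetSensor");
        sensor.AddObservation(target.position);               // 3 floats
        sensor.AddObservation(target.gameObject.activeSelf);  // 1 float (1 or 0)
        sensor.AddOneHotObservation(2, 4);                    // {0, 0, 1, 0}
        return sensor;
    }
}
```
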
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/Sensor/VectorSensor.cs.meta b/UnitySDK/Assets/ML-Agents/Scripts/Sensor/VectorSensor.cs.meta
new file mode 100644
index 0000000000..277ef0d59e
--- /dev/null
+++ b/UnitySDK/Assets/ML-Agents/Scripts/Sensor/VectorSensor.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: e3966c9961b343108808d91a4d140a68
+timeCreated: 1572300800
\ No newline at end of file
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/Sensor/WriteAdapter.cs b/UnitySDK/Assets/ML-Agents/Scripts/Sensor/WriteAdapter.cs
new file mode 100644
index 0000000000..918274758a
--- /dev/null
+++ b/UnitySDK/Assets/ML-Agents/Scripts/Sensor/WriteAdapter.cs
@@ -0,0 +1,105 @@
+using System.Collections.Generic;
+using MLAgents.InferenceBrain;
+
+namespace MLAgents.Sensor
+{
+ ///
+ /// Allows sensors to write to both TensorProxy and float arrays/lists.
+ ///
+ public class WriteAdapter
+ {
+ IList<float> m_Data;
+ int m_Offset;
+
+ TensorProxy m_Proxy;
+ int m_Batch;
+
+ ///
+ /// Set the adapter to write to an IList at the given channelOffset.
+ ///
+ ///
+ ///
+ public void SetTarget(IList<float> data, int offset)
+ {
+ m_Data = data;
+ m_Offset = offset;
+ m_Proxy = null;
+ m_Batch = -1;
+ }
+
+ ///
+ /// Set the adapter to write to a TensorProxy at the given batch and channel offset.
+ ///
+ ///
+ ///
+ ///
+ public void SetTarget(TensorProxy tensorProxy, int batchIndex, int channelOffset)
+ {
+ m_Proxy = tensorProxy;
+ m_Batch = batchIndex;
+ m_Offset = channelOffset;
+ m_Data = null;
+ }
+
+ ///
+ /// 1D write access at a specified index. Use AddRange if possible instead.
+ ///
+ /// Index to write to
+ public float this[int index]
+ {
+ set
+ {
+ if (m_Data != null)
+ {
+ m_Data[index + m_Offset] = value;
+ }
+ else
+ {
+ m_Proxy.data[m_Batch, index + m_Offset] = value;
+ }
+ }
+ }
+
+ ///
+ /// 3D write access at the specified height, width, and channel. Only usable with a TensorProxy target.
+ ///
+ ///
+ ///
+ ///
+ public float this[int h, int w, int ch]
+ {
+ set
+ {
+ // Only TensorProxy supports 3D access
+ m_Proxy.data[m_Batch, h, w, ch + m_Offset] = value;
+ }
+ }
+
+ ///
+ /// Write the range of floats
+ ///
+ ///
+ /// Optional write offset
+ public void AddRange(IEnumerable<float> data, int writeOffset = 0)
+ {
+ if (m_Data != null)
+ {
+ int index = 0;
+ foreach (var val in data)
+ {
+ m_Data[index + m_Offset + writeOffset] = val;
+ index++;
+ }
+ }
+ else
+ {
+ int index = 0;
+ foreach (var val in data)
+ {
+ m_Proxy.data[m_Batch, index + m_Offset + writeOffset] = val;
+ index++;
+ }
+ }
+ }
+ }
+}
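
Finally, a hedged sketch of driving a sensor through a `WriteAdapter` pointed at a plain float array; the same adapter instance could instead target a `TensorProxy` batch row via the other `SetTarget` overload:

```csharp
using MLAgents.Sensor;

static class WriteAdapterExample
{
    static float[] Flatten(ISensor sensor)
    {
        // Write the sensor's observation into a flat buffer starting at offset 0.
        var buffer = new float[sensor.ObservationSize()];
        var adapter = new WriteAdapter();
        adapter.SetTarget(buffer, 0);
        sensor.Write(adapter);
        return buffer;
    }
}
```
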
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/Sensor/WriteAdapter.cs.meta b/UnitySDK/Assets/ML-Agents/Scripts/Sensor/WriteAdapter.cs.meta
new file mode 100644
index 0000000000..62fc3b1aba
--- /dev/null
+++ b/UnitySDK/Assets/ML-Agents/Scripts/Sensor/WriteAdapter.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: 86bad2e6dded4a62853752a1713981f2
+timeCreated: 1572540197
\ No newline at end of file
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/Startup.cs b/UnitySDK/Assets/ML-Agents/Scripts/Startup.cs
index 75ed10eb7b..650ade7c1c 100644
--- a/UnitySDK/Assets/ML-Agents/Scripts/Startup.cs
+++ b/UnitySDK/Assets/ML-Agents/Scripts/Startup.cs
@@ -6,15 +6,15 @@ namespace MLAgents
{
public class Startup : MonoBehaviour
{
- private const string k_SceneVariableName = "SCENE_NAME";
+ const string k_SceneVariableName = "SCENE_NAME";
- private void Awake()
+ void Awake()
{
var sceneName = Environment.GetEnvironmentVariable(k_SceneVariableName);
SwitchScene(sceneName);
}
- private static void SwitchScene(string sceneName)
+ static void SwitchScene(string sceneName)
{
if (sceneName == null)
{
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/Timer.cs b/UnitySDK/Assets/ML-Agents/Scripts/Timer.cs
index edf7f3d15b..8227a9e68f 100644
--- a/UnitySDK/Assets/ML-Agents/Scripts/Timer.cs
+++ b/UnitySDK/Assets/ML-Agents/Scripts/Timer.cs
@@ -5,11 +5,6 @@
using UnityEngine.Profiling;
using System.Runtime.Serialization;
using System.Runtime.Serialization.Json;
-#if UNITY_EDITOR
-using UnityEditor;
-
-#endif
-
namespace MLAgents
{
@@ -33,7 +28,7 @@ public class TimerNode
///
/// Custom sampler used to add timings to the profiler.
///
- private CustomSampler m_Sampler;
+ CustomSampler m_Sampler;
///
/// Number of total ticks elapsed for this node.
@@ -235,7 +230,7 @@ public string DebugGetTimerString(string parentName = "", int level = 0)
///
public class TimerStack : System.IDisposable
{
- private static readonly TimerStack k_Instance = new TimerStack();
+ static readonly TimerStack k_Instance = new TimerStack();
Stack<TimerNode> m_Stack;
TimerNode m_RootNode;
@@ -246,7 +241,7 @@ static TimerStack()
{
}
- private TimerStack()
+ TimerStack()
{
Reset();
}
@@ -268,7 +263,7 @@ public TimerNode RootNode
get { return m_RootNode; }
}
- private void Push(string name)
+ void Push(string name)
{
var current = m_Stack.Peek();
var next = current.GetChild(name);
@@ -276,7 +271,7 @@ private void Push(string name)
next.Begin();
}
- private void Pop()
+ void Pop()
{
var node = m_Stack.Pop();
node.End();
diff --git a/UnitySDK/Assets/ML-Agents/Scripts/Utilities.cs b/UnitySDK/Assets/ML-Agents/Scripts/Utilities.cs
index 9e7cf3d589..39660aacf8 100644
--- a/UnitySDK/Assets/ML-Agents/Scripts/Utilities.cs
+++ b/UnitySDK/Assets/ML-Agents/Scripts/Utilities.cs
@@ -1,70 +1,34 @@
using UnityEngine;
using System.Collections.Generic;
-using MLAgents.InferenceBrain;
+using MLAgents.Sensor;
namespace MLAgents
{
public static class Utilities
{
- ///
- /// Converts a list of Texture2D into a TensorProxy.
- ///
- ///
- /// The list of textures to be put into the tensor.
- /// Note that the textures must have same width and height.
- ///
- ///
- /// TensorProxy to fill with Texture data.
- ///
- ///
- /// If set to true the textures will be converted to grayscale before
- /// being stored in the tensor.
- ///
- public static void TextureToTensorProxy(
- List<Texture2D> textures,
- TensorProxy tensorProxy,
- bool grayScale)
- {
- var numTextures = textures.Count;
- var width = textures[0].width;
- var height = textures[0].height;
-
- for (var t = 0; t < numTextures; t++)
- {
- var texture = textures[t];
- Debug.Assert(width == texture.width, "All Textures must have the same dimension");
- Debug.Assert(height == texture.height, "All Textures must have the same dimension");
- TextureToTensorProxy(texture, tensorProxy, grayScale, t);
- }
- }
///
- /// Puts a Texture2D into a TensorProxy.
+ /// Puts a Texture2D into a WriteAdapter.
///
///
/// The texture to be put into the tensor.
///
- ///
- /// TensorProxy to fill with Texture data.
+ ///
+ /// Adapter to fill with Texture data.
///
///
/// If set to true the textures will be converted to grayscale before
/// being stored in the tensor.
///
- ///
- /// Index of the texture being written.
- ///
- public static void TextureToTensorProxy(
+ /// The number of floats written
+ public static int TextureToTensorProxy(
Texture2D texture,
- TensorProxy tensorProxy,
- bool grayScale,
- int textureOffset = 0)
+ WriteAdapter adapter,
+ bool grayScale)
{
var width = texture.width;
var height = texture.height;
- var data = tensorProxy.data;
- var t = textureOffset;
var texturePixels = texture.GetPixels32();
// During training, we convert from Texture to PNG before sending to the trainer, which has the
// effect of flipping the image. We need another flip here at inference time to match this.
@@ -75,19 +39,20 @@ public static void TextureToTensorProxy(
var currentPixel = texturePixels[(height - h - 1) * width + w];
if (grayScale)
{
- data[t, h, w, 0] =
+ adapter[h, w, 0] =
(currentPixel.r + currentPixel.g + currentPixel.b) / 3f / 255.0f;
}
else
{
// For Color32, the r, g and b values are between 0 and 255.
- data[t, h, w, 0] = currentPixel.r / 255.0f;
- data[t, h, w, 1] = currentPixel.g / 255.0f;
- data[t, h, w, 2] = currentPixel.b / 255.0f;
+ adapter[h, w, 0] = currentPixel.r / 255.0f;
+ adapter[h, w, 1] = currentPixel.g / 255.0f;
+ adapter[h, w, 2] = currentPixel.b / 255.0f;
}
}
}
+ return height * width * (grayScale ? 1 : 3);
}
///
diff --git a/UnitySDK/UnityPackageManager/manifest.json b/UnitySDK/UnityPackageManager/manifest.json
new file mode 100644
index 0000000000..f1a80b4b48
--- /dev/null
+++ b/UnitySDK/UnityPackageManager/manifest.json
@@ -0,0 +1,5 @@
+{
+ "dependencies": {
+ "com.unity.barracuda": "0.3.2-preview"
+ }
+}
diff --git a/UnitySDK/UnitySDK.sln.DotSettings b/UnitySDK/UnitySDK.sln.DotSettings
index 2bc3f33d99..0e3be1bcd0 100644
--- a/UnitySDK/UnitySDK.sln.DotSettings
+++ b/UnitySDK/UnitySDK.sln.DotSettings
@@ -3,8 +3,10 @@
CPU
GPU
NN
+ PNG
RL
True
+ True
True
diff --git a/config/offline_bc_config.yaml b/config/offline_bc_config.yaml
index 5c0fd4b6c0..7d071505c5 100644
--- a/config/offline_bc_config.yaml
+++ b/config/offline_bc_config.yaml
@@ -12,6 +12,20 @@ default:
memory_size: 256
demo_path: ./UnitySDK/Assets/Demonstrations/.demo
+FoodCollector:
+ trainer: offline_bc
+ batch_size: 64
+ summary_freq: 1000
+ max_steps: 5.0e4
+ batches_per_epoch: 10
+ use_recurrent: false
+ hidden_units: 128
+ learning_rate: 3.0e-4
+ num_layers: 2
+ sequence_length: 32
+ memory_size: 256
+ demo_path: ./demos/ExpertFood.demo
+
Hallway:
trainer: offline_bc
max_steps: 5.0e5
@@ -25,3 +39,4 @@ Hallway:
memory_size: 256
sequence_length: 32
demo_path: ./demos/ExpertHallway.demo
+
diff --git a/config/sac_trainer_config.yaml b/config/sac_trainer_config.yaml
index 718778e4f6..809827854c 100644
--- a/config/sac_trainer_config.yaml
+++ b/config/sac_trainer_config.yaml
@@ -251,7 +251,7 @@ GridWorld:
init_entcoef: 0.5
buffer_init_steps: 1000
buffer_size: 50000
- max_steps: 5.0e5
+ max_steps: 50000
summary_freq: 2000
time_horizon: 5
reward_signals:
diff --git a/config/trainer_config.yaml b/config/trainer_config.yaml
index 9480e3bb49..db2cbab016 100644
--- a/config/trainer_config.yaml
+++ b/config/trainer_config.yaml
@@ -264,7 +264,7 @@ GridWorld:
hidden_units: 256
beta: 5.0e-3
buffer_size: 256
- max_steps: 5.0e5
+ max_steps: 50000
summary_freq: 2000
time_horizon: 5
reward_signals:
diff --git a/demos/Expert3DBall.demo b/demos/Expert3DBall.demo
index 2e8447a4dd..2c87eff290 100644
Binary files a/demos/Expert3DBall.demo and b/demos/Expert3DBall.demo differ
diff --git a/demos/Expert3DBallHard.demo b/demos/Expert3DBallHard.demo
index e3c976ce25..e3c5cdf0d9 100644
Binary files a/demos/Expert3DBallHard.demo and b/demos/Expert3DBallHard.demo differ
diff --git a/demos/ExpertBasic.demo b/demos/ExpertBasic.demo
index e2b62d34d2..979acb95f9 100644
Binary files a/demos/ExpertBasic.demo and b/demos/ExpertBasic.demo differ
diff --git a/demos/ExpertBouncer.demo b/demos/ExpertBouncer.demo
index 742009b82d..4b186c3a7e 100644
Binary files a/demos/ExpertBouncer.demo and b/demos/ExpertBouncer.demo differ
diff --git a/demos/ExpertCrawlerDyn.demo b/demos/ExpertCrawlerDyn.demo
index a653f93c11..3a8c3b6d9e 100644
Binary files a/demos/ExpertCrawlerDyn.demo and b/demos/ExpertCrawlerDyn.demo differ
diff --git a/demos/ExpertCrawlerSta.demo b/demos/ExpertCrawlerSta.demo
index 0ff7e205b6..3e6f5622c1 100644
Binary files a/demos/ExpertCrawlerSta.demo and b/demos/ExpertCrawlerSta.demo differ
diff --git a/demos/ExpertFood.demo b/demos/ExpertFood.demo
index ac396fa615..5490b52f23 100644
Binary files a/demos/ExpertFood.demo and b/demos/ExpertFood.demo differ
diff --git a/demos/ExpertGrid.demo b/demos/ExpertGrid.demo
index 1e587c8372..826fc63274 100644
Binary files a/demos/ExpertGrid.demo and b/demos/ExpertGrid.demo differ
diff --git a/demos/ExpertHallway.demo b/demos/ExpertHallway.demo
index 329314a7bc..b362d4dc81 100644
Binary files a/demos/ExpertHallway.demo and b/demos/ExpertHallway.demo differ
diff --git a/demos/ExpertPush.demo b/demos/ExpertPush.demo
index 182fc352b3..f91b4fa92f 100644
Binary files a/demos/ExpertPush.demo and b/demos/ExpertPush.demo differ
diff --git a/demos/ExpertPyramid.demo b/demos/ExpertPyramid.demo
index f0bdf40d72..f2c3f66816 100644
Binary files a/demos/ExpertPyramid.demo and b/demos/ExpertPyramid.demo differ
diff --git a/demos/ExpertReacher.demo b/demos/ExpertReacher.demo
index badb920d80..621064d519 100644
Binary files a/demos/ExpertReacher.demo and b/demos/ExpertReacher.demo differ
diff --git a/demos/ExpertTennis.demo b/demos/ExpertTennis.demo
index 3e193f14f1..922402b4d3 100644
Binary files a/demos/ExpertTennis.demo and b/demos/ExpertTennis.demo differ
diff --git a/demos/ExpertWalker.demo b/demos/ExpertWalker.demo
index 881fc535e2..5fc3c87f34 100644
Binary files a/demos/ExpertWalker.demo and b/demos/ExpertWalker.demo differ
diff --git a/docs/Basic-Guide.md b/docs/Basic-Guide.md
index ac9b03101a..56fbc7f629 100644
--- a/docs/Basic-Guide.md
+++ b/docs/Basic-Guide.md
@@ -158,10 +158,9 @@ like this:
INFO:mlagents.envs:
'Ball3DAcademy' started successfully!
Unity Academy name: Ball3DAcademy
- Number of Brains: 1
- Number of Training Brains : 1
- Reset Parameters :
+ Reset Parameters : {}
+INFO:mlagents.envs:Connected new brain:
Unity brain name: 3DBallLearning
Number of Visual Observations (per agent): 0
Vector Observation space size (per agent): 8
diff --git a/docs/Creating-Custom-Protobuf-Messages.md b/docs/Creating-Custom-Protobuf-Messages.md
deleted file mode 100644
index 2b22acef5a..0000000000
--- a/docs/Creating-Custom-Protobuf-Messages.md
+++ /dev/null
@@ -1,171 +0,0 @@
-# Disclaimer
-*NOTE:* `CustomAction` and `CustomObservation` are meant for researchers who intend to use the resulting environments with their own training code. In addition to implementing a custom message, you will also need to make extensive modifications to the trainer in order to produce custom actions or consume custom observations; we don't recommend modifying our trainer code, or using this feature unless you know what you are doing and have a very specific use-case in mind. *Proceed at your own risk*.
-
-# Creating Custom Protobuf Messages
-
-Unity and Python communicate by sending protobuf messages to and from each other. You can create custom protobuf messages if you want to exchange structured data beyond what is included by default.
-
-## Implementing a Custom Message
-
-Whenever you change the fields of a custom message, you must follow the directions in [this file](../protobuf-definitions/README.md) to create C# and Python files corresponding to the new message and re-install the mlagents Python package.
-
-## Custom Message Types
-
-There are three custom message types currently supported - Custom Actions, Custom Reset Parameters, and Custom Observations. In each case, `env` is an instance of a `UnityEnvironment` in Python.
-
-### Custom Actions
-
-By default, the Python API sends actions to Unity in the form of a floating point list and an optional string-valued text action for each agent.
-
-You can define a custom action type, to either replace or augment the default, by adding fields to the `CustomAction` message, which you can do by editing the file `protobuf-definitions/proto/mlagents/envs/communicator_objects/custom_action.proto`.
-
-Instances of custom actions are set via the `custom_action` parameter of the `env.step`. An agent receives a custom action by defining a method with the signature:
-
-```csharp
-public virtual void AgentAction(float[] vectorAction, string textAction, CommunicatorObjects.CustomAction customAction)
-```
-
-Below is an example of creating a custom action that instructs an agent to choose a cardinal direction to walk in and how far to walk.
-
-The `custom_action.proto` file looks like:
-
-```protobuf
-syntax = "proto3";
-
-option csharp_namespace = "MLAgents.CommunicatorObjects";
-package communicator_objects;
-
-message CustomAction {
- enum Direction {
- NORTH=0;
- SOUTH=1;
- EAST=2;
- WEST=3;
- }
- float walkAmount = 1;
- Direction direction = 2;
-}
-```
-
-The Python instance of the custom action looks like:
-
-```python
-from mlagents.envs.communicator_objects import CustomAction
-env = mlagents.envs.UnityEnvironment(...)
-...
-action = CustomAction(direction=CustomAction.NORTH, walkAmount=2.0)
-env.step(custom_action=action)
-```
-
-And the agent code looks like:
-
-```csharp
-...
-using MLAgents;
-using MLAgents.CommunicatorObjects;
-
-class MyAgent : Agent {
- ...
- override public void AgentAction(float[] vectorAction, string textAction, CustomAction customAction) {
- switch(customAction.Direction) {
- case CustomAction.Types.Direction.North:
- transform.Translate(0, 0, customAction.WalkAmount);
- break;
- ...
- }
- }
-}
-```
-
-Keep in mind that the protobuffer compiler automatically configures the capitalization scheme of the C# version of the custom field names you defined in the `CustomAction` message to match C# conventions - "NORTH" becomes "North", "walkAmount" becomes "WalkAmount", etc.
-
-### Custom Reset Parameters
-
-By default, you can configure an environment `env` in the Python API by specifying a `config` parameter that is a dictionary mapping strings to floats.
-
-You can also configure the environment reset using a custom protobuf message. To do this, add fields to the `CustomResetParameters` protobuf message in `custom_reset_parameters.proto`, analogously to `CustomAction` above. Then pass an instance of the message to `env.reset` via the `custom_reset_parameters` keyword parameter.
-
-In Unity, you can then access the `customResetParameters` field of your academy to accesss the values set in your Python script.
-
-In this example, the academy is setting the initial position of a box based on custom reset parameters. The `custom_reset_parameters.proto` would look like:
-
-```protobuf
-message CustomResetParameters {
- message Position {
- float x = 1;
- float y = 2;
- float z = 3;
- }
- message Color {
- float r = 1;
- float g = 2;
- float b = 3;
- }
- Position initialPos = 1;
- Color color = 2;
-}
-```
-
-The Python instance of the custom reset parameter looks like
-
-```python
-from mlagents.envs.communicator_objects import CustomResetParameters
-env = ...
-pos = CustomResetParameters.Position(x=1, y=1, z=2)
-color = CustomResetParameters.Color(r=.5, g=.1, b=1.0)
-params = CustomResetParameters(initialPos=pos, color=color)
-env.reset(custom_reset_parameters=params)
-```
-
-The academy looks like
-
-```csharp
-public class MyAcademy : Academy
-{
- public GameObject box; // This would be connected to a game object in your scene in the Unity editor.
-
- override public void AcademyReset()
- {
- var boxParams = customResetParameters;
- if (boxParams != null)
- {
- var pos = boxParams.InitialPos;
- var color = boxParams.Color;
- box.transform.position = new Vector3(pos.X, pos.Y, pos.Z);
- box.GetComponent<Renderer>().material.color = new Color(color.R, color.G, color.B);
- }
- }
-}
-```
-
-### Custom Observations
-
-By default, Unity returns observations to Python in the form of a floating-point vector.
-
-You can define a custom observation message to supplement that. To do so, add fields to the `CustomObservation` protobuf message in `custom_observation.proto`.
-
-Then in your agent, create an instance of a custom observation via `new CommunicatorObjects.CustomObservation`. Then in `CollectObservations`, call `SetCustomObservation` with the custom observation instance as the parameter.
-
-In Python, the custom observation can be accessed by calling `env.step` or `env.reset` and accessing the `custom_observations` property of the return value. It will contain a list with one `CustomObservation` instance per agent.
-
-For example, if you have added a field called `customField` to the `CustomObservation` message, the agent code looks like:
-
-```csharp
-class MyAgent : Agent {
- override public void CollectObservations() {
- var obs = new CustomObservation();
- obs.CustomField = 1.0;
- SetCustomObservation(obs);
- }
-}
-```
-
-In Python, the custom field would be accessed like:
-
-```python
-...
-result = env.step(...)
-result[behavior_name].custom_observations[0].customField
-```
-
-where `behavior_name` is the `Behavior Name` property of the Agent.
diff --git a/docs/Glossary.md b/docs/Glossary.md
index 055f470f52..7db920ee51 100644
--- a/docs/Glossary.md
+++ b/docs/Glossary.md
@@ -18,7 +18,7 @@
* **Frame** - An instance of rendering the main camera for the display.
Corresponds to each `Update` call of the game engine.
* **Observation** - Partial information describing the state of the environment
- available to a given agent. (e.g. Vector, Visual, Text)
+ available to a given agent. (e.g. Vector, Visual)
* **Policy** - Function for producing decisions from observations.
* **Reward** - Signal provided at every step used to indicate desirability of an
agent’s action within the current state of the environment.
diff --git a/docs/Installation-Windows.md b/docs/Installation-Windows.md
new file mode 100644
index 0000000000..883ab7fb27
--- /dev/null
+++ b/docs/Installation-Windows.md
@@ -0,0 +1,354 @@
+# Installing ML-Agents Toolkit for Windows (Deprecated)
+
+Note: We no longer use this guide ourselves and so it may not work correctly. We've decided to
+ keep it up just in case it is helpful to you.
+
+The ML-Agents toolkit supports Windows 10. While it might be possible to run the
+ML-Agents toolkit using other versions of Windows, it has not been tested on
+other versions. Furthermore, the ML-Agents toolkit has not been tested on a
+Windows VM such as Bootcamp or Parallels.
+
+To use the ML-Agents toolkit, you install Python and the required Python
+packages as outlined below. This guide also covers how set up GPU-based training
+(for advanced users). GPU-based training is not currently required for the
+ML-Agents toolkit. However, training on a GPU might be required by future
+versions and features.
+
+## Step 1: Install Python via Anaconda
+
+[Download](https://www.anaconda.com/download/#windows) and install Anaconda for
+Windows. By using Anaconda, you can manage separate environments for different
+distributions of Python. Python 3.6.1 or higher is required as we no longer support
+Python 2. In this guide, we are using Python version 3.6 and Anaconda version
+5.1
+([64-bit](https://repo.continuum.io/archive/Anaconda3-5.1.0-Windows-x86_64.exe)
+or [32-bit](https://repo.continuum.io/archive/Anaconda3-5.1.0-Windows-x86.exe)
+direct links).
+
+
+
+
+
+We recommend the default _advanced installation options_. However, select the
+options appropriate for your specific situation.
+
+
+
+
+
+After installation, you must open __Anaconda Navigator__ to finish the setup.
+From the Windows search bar, type _anaconda navigator_. You can close Anaconda
+Navigator after it opens.
+
+If environment variables were not created, you will see the error "conda is not
+recognized as internal or external command" when you type `conda` into the
+command line. To solve this, you will need to set the environment variables
+correctly.
+
+Type `environment variables` in the search bar (this can be reached by hitting
+the Windows key or the bottom left Windows button). You should see an option
+called __Edit the system environment variables__.
+
+
+
+
+
+From here, click the __Environment Variables__ button. Double click "Path" under
+__System variables__ to edit the "Path" variable, then click __New__ to add the
+following new paths.
+
+```console
+%UserProfile%\Anaconda3\Scripts
+%UserProfile%\Anaconda3\Scripts\conda.exe
+%UserProfile%\Anaconda3
+%UserProfile%\Anaconda3\python.exe
+```
+
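+As a quick sanity check (assuming Anaconda was installed to the default
+`%UserProfile%\Anaconda3` location shown above), you can open a new command
+prompt and confirm that `conda` is now found:
+
+```console
+conda --version
+```
+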
+## Step 2: Setup and Activate a New Conda Environment
+
+You will create a new [Conda environment](https://conda.io/docs/) to be used
+with the ML-Agents toolkit. This means that all the packages that you install
+are localized to just this environment. It will not affect any other
+installation of Python or other environments. Whenever you want to run
+ML-Agents, you will need to activate this Conda environment.
+
+To create a new Conda environment, open a new Anaconda Prompt (_Anaconda Prompt_
+in the search bar) and type in the following command:
+
+```sh
+conda create -n ml-agents python=3.6
+```
+
+You may be asked to install new packages. Type `y` and press enter _(make sure
+you are connected to the Internet)_. You must install these required packages.
+The new Conda environment is called ml-agents and uses Python version 3.6.
+
+To use this environment, you must activate it. _(To use this environment in the
+future, you can run the same command)_. In the same Anaconda Prompt, type in the
+following command:
+
+```sh
+activate ml-agents
+```
+
+You should see `(ml-agents)` prepended on the last line.
+
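+For example, the prompt might look something like this (the path shown here is
+only an illustration and will differ on your machine):
+
+```console
+(ml-agents) C:\Users\yourname>
+```
+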
+Next, install `tensorflow`. Install this package using `pip`, which is a
+package management system used to install Python packages. The latest versions of
+TensorFlow will not work, so you will need to make sure that you install version
+1.7.1. In the same Anaconda Prompt, type in the following command _(make sure
+you are connected to the Internet)_:
+
+```sh
+pip install tensorflow==1.7.1
+```
+
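+To confirm that the expected version was installed, you can optionally run the
+following one-liner in the same prompt; it should print `1.7.1`:
+
+```sh
+python -c "import tensorflow as tf; print(tf.__version__)"
+```
+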
+## Step 3: Install Required Python Packages
+
+The ML-Agents toolkit depends on a number of Python packages. Use `pip` to
+install these Python dependencies.
+
+If you haven't already, clone the ML-Agents Toolkit Github repository to your
+local computer. You can do this using Git ([download
+here](https://git-scm.com/download/win)) and running the following commands in
+an Anaconda Prompt _(if you open a new prompt, be sure to activate the ml-agents
+Conda environment by typing `activate ml-agents`)_:
+
+```sh
+git clone https://github.com/Unity-Technologies/ml-agents.git
+```
+
+If you don't want to use Git, you can always directly download all the files
+[here](https://github.com/Unity-Technologies/ml-agents/archive/master.zip).
+
+The `UnitySDK` subdirectory contains the Unity Assets to add to your projects.
+It also contains many [example environments](Learning-Environment-Examples.md)
+to help you get started.
+
+The `ml-agents` subdirectory contains a Python package which provides deep reinforcement
+learning trainers to use with Unity environments.
+
+The `ml-agents-envs` subdirectory contains a Python API to interface with Unity, which
+the `ml-agents` package depends on.
+
+The `gym-unity` subdirectory contains a package to interface with OpenAI Gym.
+
+Keep in mind where the files were downloaded, as you will need the
+trainer config files in this directory when running `mlagents-learn`.
+Make sure you are connected to the Internet and then type in the Anaconda
+Prompt:
+
+```console
+pip install mlagents
+```
+
+This will complete the installation of all the required Python packages to run
+the ML-Agents toolkit.
+
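+If the installation succeeded, the `mlagents-learn` command should now be
+available in the same Anaconda Prompt. Running it with `--help` is a quick way
+to confirm this:
+
+```console
+mlagents-learn --help
+```
+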
+Sometimes on Windows, when you use pip to install certain Python packages, pip can get stuck trying to read the package cache. If you see this, you can try:
+
+```console
+pip install mlagents --no-cache-dir
+```
+
+The `--no-cache-dir` option tells pip to disable the cache.
+
+### Installing for Development
+
+If you intend to make modifications to `ml-agents` or `ml-agents-envs`, you should install
+the packages from the cloned repo rather than from PyPI. To do this, you will need to install
+`ml-agents` and `ml-agents-envs` separately.
+
+In our example, the files are located in `C:\Downloads`. After you have either
+cloned or downloaded the files, from the Anaconda Prompt, change into the root
+of the cloned ml-agents repository:
+
+```console
+cd C:\Downloads\ml-agents
+```
+
+From the repo's main directory, now run:
+
+```console
+cd ml-agents-envs
+pip install -e .
+cd ..
+cd ml-agents
+pip install -e .
+```
+
+Running pip with the `-e` flag will let you make changes to the Python files directly and have those
+reflected when you run `mlagents-learn`. It is important to install these packages in this order as the
+`mlagents` package depends on `mlagents_envs`, and installing them in the other
+order will download `mlagents_envs` from PyPI.
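+
+To double-check that the editable installs are being picked up (an optional
+step; the exact output will vary), you can ask pip where each package is
+installed from and confirm that the `Location` field points at your cloned
+repo:
+
+```console
+pip show mlagents mlagents-envs
+```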
+
+## (Optional) Step 4: GPU Training using The ML-Agents Toolkit
+
+A GPU is not required for the ML-Agents toolkit and will not significantly speed
+up PPO training at the moment (though future versions and features may benefit
+from a GPU). This is a guide for advanced users who want to train using GPUs.
+Additionally, you will need to check if your GPU is CUDA compatible. Please
+check Nvidia's page [here](https://developer.nvidia.com/cuda-gpus).
+
+Currently for the ML-Agents toolkit, only CUDA v9.0 and cuDNN v7.0.5 are supported.
+
+### Install Nvidia CUDA toolkit
+
+[Download](https://developer.nvidia.com/cuda-toolkit-archive) and install the
+CUDA toolkit 9.0 from Nvidia's archive. The toolkit includes GPU-accelerated
+libraries, debugging and optimization tools, a C/C++ compiler (Visual Studio
+2017) and a runtime library, and is needed to run the ML-Agents toolkit. In
+this guide, we are using version
+[9.0.176](https://developer.nvidia.com/compute/cuda/9.0/Prod/network_installers/cuda_9.0.176_win10_network-exe).
+
+Before installing, please make sure you __close any running instances of Unity
+or Visual Studio__.
+
+Run the installer and select the Express option. Note the directory where you
+installed the CUDA toolkit. In this guide, we installed it in the directory
+`C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v9.0`.
+
+### Install Nvidia cuDNN library
+
+[Download](https://developer.nvidia.com/cudnn) and install the cuDNN library
+from Nvidia. cuDNN is a GPU-accelerated library of primitives for deep neural
+networks. Before you can download, you will need to sign up for free to the
+Nvidia Developer Program.
+
+Once you've signed up, go back to the cuDNN
+[downloads page](https://developer.nvidia.com/cudnn).
+You may or may not be asked to fill out a short survey. When you get to the list
+of cuDNN releases, __make sure you are downloading the right version for the CUDA
+toolkit you installed above.__ In this guide, we are using version 7.0.5 for
+CUDA toolkit version 9.0
+([direct link](https://developer.nvidia.com/compute/machine-learning/cudnn/secure/v7.0.5/prod/9.0_20171129/cudnn-9.0-windows10-x64-v7)).
+
+After you have downloaded the cuDNN files, you will need to extract the files
+into the CUDA toolkit directory. In the cuDNN zip file, there are three folders
+called `bin`, `include`, and `lib`.
+
+Copy these three folders into the CUDA toolkit directory. The CUDA toolkit
+directory is located at
+`C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v9.0`.
+
+### Set Environment Variables
+
+You will need to add one environment variable and two path variables.
+
+To set the environment variable, type `environment variables` in the search bar
+(this can be reached by hitting the Windows key or the bottom left Windows
+button). You should see an option called __Edit the system environment
+variables__.
+
+From here, click the __Environment Variables__ button. Click __New__ to add a
+new system variable _(make sure you do this under __System variables__ and not
+User variables)_.
+
+For __Variable Name__, enter `CUDA_HOME`. For the variable value, put the
+directory location for the CUDA toolkit. In this guide, the directory location
+is `C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v9.0`. Press __OK__ once.
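+
+If you prefer the command line, the same variable can be set from a command
+prompt running as administrator (this is simply an alternative to the GUI steps
+above; adjust the path if you installed the CUDA toolkit elsewhere):
+
+```console
+setx /M CUDA_HOME "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v9.0"
+```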
+
+To set the two path variables, inside the same __Environment Variables__ window
+and under the second box called __System Variables__, find a variable called
+`Path` and click __Edit__. You will add two directories to the list. For this
+guide, the two entries would look like:
+
+```console
+C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v9.0\lib\x64
+C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v9.0\extras\CUPTI\libx64
+```
+
+Make sure to replace the relevant directory location with the one you have
+installed. _Please note that case sensitivity matters_.
+
+### Install TensorFlow GPU
+
+Next, install `tensorflow-gpu` using `pip`. You'll need version 1.7.1. In an
+Anaconda Prompt with the Conda environment ml-agents activated, type in the
+following command to uninstall TensorFlow for CPU and install TensorFlow
+for GPU _(make sure you are connected to the Internet)_:
+
+```sh
+pip uninstall tensorflow
+pip install tensorflow-gpu==1.7.1
+```
+
+Lastly, you should test to see if everything installed properly and that
+TensorFlow can identify your GPU. In the same Anaconda Prompt, open Python by
+calling:
+
+```sh
+python
+```
+
+And then type the following commands:
+
+```python
+import tensorflow as tf
+
+sess = tf.Session(config=tf.ConfigProto(log_device_placement=True))
+```
+
+You should see something similar to:
+
+```console
+Found device 0 with properties ...
+```
+
+## Acknowledgments
+
+We would like to thank
+[Jason Weimann](https://unity3d.college/2017/10/25/machine-learning-in-unity3d-setting-up-the-environment-tensorflow-for-agentml-on-windows-10/)
+and
+[Nitish S. Mutha](http://blog.nitishmutha.com/tensorflow/2017/01/22/TensorFlow-with-gpu-for-windows.html)
+for writing the original articles which were used to create this guide.
diff --git a/docs/Learning-Environment-Create-New.md b/docs/Learning-Environment-Create-New.md
index 1fa2f297da..0deee59a76 100644
--- a/docs/Learning-Environment-Create-New.md
+++ b/docs/Learning-Environment-Create-New.md
@@ -367,7 +367,7 @@ With the action and reward logic outlined above, the final version of the
```csharp
public float speed = 10;
-public override void AgentAction(float[] vectorAction, string textAction)
+public override void AgentAction(float[] vectorAction)
{
// Actions, size = 2
Vector3 controlSignal = Vector3.zero;
@@ -411,7 +411,8 @@ with our Agent code.
2. Change **Decision Interval** from `1` to `10`.
3. Drag the Target GameObject from the Hierarchy window to the RollerAgent
Target field.
-4. Modify the Behavior Parameters of the Agent :
+4. Add the Behavior Parameters script with the Add Component button from the RollerAgent Inspector.
+5. Modify the Behavior Parameters of the Agent:
* `Behavior Name` to *RollerBallBrain*
* `Vector Observation` `Space Size` = 8
* `Vector Action` `Space Type` = **Continuous**
diff --git a/docs/Learning-Environment-Design-Agents.md b/docs/Learning-Environment-Design-Agents.md
index 8e6f42a539..6651678fcf 100644
--- a/docs/Learning-Environment-Design-Agents.md
+++ b/docs/Learning-Environment-Design-Agents.md
@@ -245,6 +245,49 @@ as observations directly, this is done automatically by the Agent.
![Agent RenderTexture Debug](images/gridworld.png)
+### Raycast Observations
+Raycasts are an alternative system for the Agent to provide observations based on
+the physical environment. This can be easily implemented by adding a
+RayPerceptionSensorComponent3D (or RayPerceptionSensorComponent2D) to the Agent.
+
+During observations, several rays (or spheres, depending on settings) are cast into
+the physics world, and the objects that are hit determine the observation vector that
+is produced.
+
+![Agent with two RayPerceptionSensorComponent3Ds](images/ray_perception.png)
+
+Both sensor components have several settings:
+ * _Detectable Tags_ A list of strings corresponding to the types of objects that the
+ Agent should be able to distinguish between. For example, in the WallJump example,
+ we use "wall", "goal", and "block" as the list of objects to detect.
+ * _Rays Per Direction_ Determines the number of rays that are cast. One ray is
+ always cast forward, and this many rays are cast to the left and right.
+ * _Max Ray Degrees_ The angle (in degrees) for the outermost rays. 90 degrees
+ corresponds to the left and right of the agent.
+ * _Sphere Cast Radius_ The size of the sphere used for sphere casting. If set
+ to 0, rays will be used instead of spheres. Rays may be more efficient,
+ especially in complex scenes.
+ * _Ray Length_ The length of the casts.
+ * _Observation Stacks_ The number of previous results to "stack" with the cast
+ results. Note that this can be independent of the "Stacked Vectors" setting
+ in `Behavior Parameters`.
+ * _Start Vertical Offset_ (3D only) The vertical offset of the ray start point.
+ * _End Vertical Offset_ (3D only) The vertical offset of the ray end point.
+
+In the example image above, the Agent has two RayPerceptionSensorComponent3Ds.
+Both use 3 Rays Per Direction and 90 Max Ray Degrees. One of the components
+has a vertical offset, so the Agent can tell whether it's clear to jump over
+the wall.
+
+The total size of the created observations is
+```
+(Observation Stacks) * (1 + 2 * Rays Per Direction) * (Num Detectable Tags + 2)
+```
+so the number of rays and tags should be kept as small as possible to reduce the
+amount of data used. Note that this is separate from the State Size defined in
+`Behavior Parameters`, so you don't need to worry about the formula above when
+setting the State Size.
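+
+As a concrete example, a single sensor configured with 1 Observation Stack,
+3 Rays Per Direction and the 3 Detectable Tags used in the WallJump example
+above would produce an observation of size
+```
+1 * (1 + 2 * 3) * (3 + 2) = 35
+```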
+
## Vector Actions
An action is an instruction from the Policy that the agent carries out. The
diff --git a/docs/Learning-Environment-Executable.md b/docs/Learning-Environment-Executable.md
index c6d527324b..7a566fc32e 100644
--- a/docs/Learning-Environment-Executable.md
+++ b/docs/Learning-Environment-Executable.md
@@ -143,13 +143,10 @@ Mono path[0] = '/Users/dericp/workspace/ml-agents/3DBall.app/Contents/Resources/
Mono config path = '/Users/dericp/workspace/ml-agents/3DBall.app/Contents/MonoBleedingEdge/etc'
INFO:mlagents.envs:
'Ball3DAcademy' started successfully!
-INFO:mlagents.envs:
-'Ball3DAcademy' started successfully!
Unity Academy name: Ball3DAcademy
- Number of Brains: 1
- Number of Training Brains : 1
- Reset Parameters :
+ Reset Parameters : {}
+INFO:mlagents.envs:Connected new brain:
Unity brain name: Ball3DLearning
Number of Visual Observations (per agent): 0
Vector Observation space size (per agent): 8
diff --git a/docs/Migrating.md b/docs/Migrating.md
index 409646cb19..e87714297b 100644
--- a/docs/Migrating.md
+++ b/docs/Migrating.md
@@ -7,13 +7,30 @@ The versions can be found in
# Migrating
-## Migrating from ML-Agents toolkit v0.10 to v0.11
+## Migrating from ML-Agents toolkit v0.11.0 to v0.12.0
+
+### Important Changes
+* Text actions and observations, and custom action and observation protos have been removed.
+* RayPerception3D and RayPerception2D are marked deprecated, and will be removed in a future release. They can be replaced by RayPerceptionSensorComponent3D and RayPerceptionSensorComponent2D.
+* The `Use Heuristic` checkbox in Behavior Parameters has been replaced with a `Behavior Type` dropdown menu. This has the following options:
+ * `Default` corresponds to the previous unchecked behavior, meaning that Agents will train if they connect to a Python trainer; otherwise, they will perform inference.
+ * `Heuristic Only` means the Agent will always use the `Heuristic()` method. This corresponds to having "Use Heuristic" selected in 0.11.0.
+ * `Inference Only` means the Agent will always perform inference.
+
+### Steps to Migrate
+* We [fixed a bug](https://github.com/Unity-Technologies/ml-agents/pull/2823) in `RayPerception3D.Perceive()` that was causing the `endOffset` to be used incorrectly. However, this may produce different behavior from previous versions if you use a non-zero `startOffset`. To reproduce the old behavior, you should increase the value of `endOffset` by `startOffset`. You can verify your raycasts are performing as expected in scene view using the debug rays.
+* If you use RayPerception3D, replace it with RayPerceptionSensorComponent3D (and similarly for 2D). The settings, such as ray angles and detectable tags, are configured on the component now.
+RayPerception3D would contribute `(# of rays) * (# of tags + 2)` to the State Size in Behavior Parameters, but this is no longer necessary, so you should reduce the State Size by this amount.
+Making this change will require retraining your model, since the observations that RayPerceptionSensorComponent3D produces are different from the old behavior.
+
+## Migrating from ML-Agents toolkit v0.10 to v0.11.0
### Important Changes
* The definition of the gRPC service has changed.
* The online BC training feature has been removed.
* The BroadcastHub has been deprecated. If there is a training Python process, all LearningBrains in the scene will automatically be trained. If there is no Python process, inference will be used.
* The Brain ScriptableObjects have been deprecated. The Brain Parameters are now on the Agent and are referred to as Behavior Parameters. Make sure the Behavior Parameters is attached to the Agent GameObject.
+* To use a heuristic behavior, implement the `Heuristic()` method in the Agent class and check the `Use Heuristic` checkbox in the Behavior Parameters.
* Several changes were made to the setup for visual observations (i.e. using Cameras or RenderTextures):
* Camera resolutions are no longer stored in the Brain Parameters.
* AgentParameters no longer stores lists of Cameras and RenderTextures
diff --git a/docs/Python-API.md b/docs/Python-API.md
index c97a1e554e..6369430182 100644
--- a/docs/Python-API.md
+++ b/docs/Python-API.md
@@ -76,10 +76,6 @@ A BrainInfo object contains the following fields:
the list corresponds to the nth observation of the Brain.
- **`vector_observations`** : A two dimensional numpy array of dimension `(batch
size, vector observation size)`.
-- **`text_observations`** : A list of string corresponding to the Agents text
- observations.
-- **`memories`** : A two dimensional numpy array of dimension `(batch size,
- memory size)` which corresponds to the memories sent at the previous step.
- **`rewards`** : A list as long as the number of Agents using the Brain
containing the rewards they each obtained at the previous step.
- **`local_done`** : A list as long as the number of Agents using the Brain
@@ -87,9 +83,6 @@ A BrainInfo object contains the following fields:
- **`max_reached`** : A list as long as the number of Agents using the Brain
containing true if the Agents reached their max steps.
- **`agents`** : A list of the unique ids of the Agents using the Brain.
-- **`previous_actions`** : A two dimensional numpy array of dimension `(batch
- size, vector action size)` if the vector action space is continuous and
- `(batch size, number of branches)` if the vector action space is discrete.
Once loaded, you can use your UnityEnvironment object, which referenced by a
variable named `env` in this example, can be used in the following way:
@@ -108,14 +101,10 @@ variable named `env` in this example, can be used in the following way:
`resetParameters` and the values are their corresponding float values.
Define the reset parameters on the Academy Inspector window in the Unity
Editor.
-- **Step : `env.step(action, memory=None, text_action=None)`**
+- **Step : `env.step(action)`**
Sends a step signal to the environment using the actions. For each Brain :
- `action` can be one dimensional arrays or two dimensional arrays if you have
multiple Agents per Brain.
- - `memory` is an optional input that can be used to send a list of floats per
- Agents to be retrieved at the next step.
- - `text_action` is an optional input that be used to send a single string per
- Agent.
Returns a dictionary mapping Brain names to BrainInfo objects.
diff --git a/docs/Readme.md b/docs/Readme.md
index d88962a043..d033511850 100644
--- a/docs/Readme.md
+++ b/docs/Readme.md
@@ -28,7 +28,6 @@
* [Using the Monitor](Feature-Monitor.md)
* [Using the Video Recorder](https://github.com/Unity-Technologies/video-recorder)
* [Using an Executable Environment](Learning-Environment-Executable.md)
- * [Creating Custom Protobuf Messages](Creating-Custom-Protobuf-Messages.md)
## Training
@@ -45,14 +44,6 @@
* [Training with LSTM](Feature-Memory.md)
* [Training Generalized Reinforcement Learning Agents](Training-Generalized-Reinforcement-Learning-Agents.md)
-### Cloud Training (Deprecated)
-Here are the cloud training set-up guides for Azure and AWS. We no longer use them ourselves and
-so they may not be work correctly. We've decided to keep them up just in case they are helpful to
-you.
-
-* [Training on the Cloud with Amazon Web Services](Training-on-Amazon-Web-Service.md)
-* [Training on the Cloud with Microsoft Azure](Training-on-Microsoft-Azure.md)
-
## Inference
* [Unity Inference Engine](Unity-Inference-Engine.md)
@@ -69,4 +60,12 @@ you.
* [API Reference](API-Reference.md)
* [How to use the Python API](Python-API.md)
* [Wrapping Learning Environment as a Gym (+Baselines/Dopamine Integration)](../gym-unity/README.md)
-* [Creating custom protobuf messages](Creating-Custom-Protobuf-Messages.md)
+
+## Deprecated Docs
+We no longer use them ourselves and so they may not be up-to-date.
+We've decided to keep them up just in case they are helpful to you.
+
+* [Training on the Cloud with Amazon Web Services](Training-on-Amazon-Web-Service.md)
+* [Training on the Cloud with Microsoft Azure](Training-on-Microsoft-Azure.md)
+* [Using Docker](Using-Docker.md)
+* [Installation-Windows](Installation-Windows.md)
diff --git a/docs/Training-ML-Agents.md b/docs/Training-ML-Agents.md
index 5cbe56f793..7954b900ac 100644
--- a/docs/Training-ML-Agents.md
+++ b/docs/Training-ML-Agents.md
@@ -195,7 +195,6 @@ example environments are included in the provided config file.
| memory_size | The size of the memory an agent must keep. Used for training with a recurrent neural network. See [Using Recurrent Neural Networks](Feature-Memory.md). | PPO, SAC, BC |
| normalize | Whether to automatically normalize observations. | PPO, SAC |
| num_epoch | The number of passes to make through the experience buffer when performing gradient descent optimization. | PPO |
-<<<<<<< HEAD
| num_layers | The number of hidden layers in the neural network. | PPO, SAC, BC |
| pretraining | Use demonstrations to bootstrap the policy neural network. See [Pretraining Using Demonstrations](Training-PPO.md#optional-pretraining-using-demonstrations). | PPO, SAC |
| reward_signals | The reward signals used to train the policy. Enable Curiosity and GAIL here. See [Reward Signals](Reward-Signals.md) for configuration options. | PPO, SAC, BC |
diff --git a/docs/Using-Docker.md b/docs/Using-Docker.md
new file mode 100644
index 0000000000..f965fc01eb
--- /dev/null
+++ b/docs/Using-Docker.md
@@ -0,0 +1,169 @@
+# Using Docker For ML-Agents (Deprecated)
+
+Note: We no longer use this guide ourselves and so it may not work correctly. We've decided to
+ keep it up just in case it is helpful to you.
+
+We currently offer a solution for Windows and Mac users who would like to do
+training or inference using Docker. This option may be appealing to those who
+would like to avoid installing Python and TensorFlow themselves. The current
+setup forces both TensorFlow and Unity to _only_ rely on the CPU for
+computations. Consequently, our Docker simulation does not use a GPU and uses
+[`Xvfb`](https://en.wikipedia.org/wiki/Xvfb) to do visual rendering. `Xvfb` is a
+utility that enables `ML-Agents` (or any other application) to do rendering
+virtually, i.e. it does not assume that the machine running `ML-Agents` has a GPU
+or a display attached to it. This means that rich environments which involve
+agents using camera-based visual observations might be slower.
+
+## Requirements
+
+- Unity _Linux Build Support_ Component
+- [Docker](https://www.docker.com)
+
+## Setup
+
+- [Download](https://unity3d.com/get-unity/download) the Unity Installer and add
+ the _Linux Build Support_ Component
+
+- [Download](https://www.docker.com/community-edition#/download) and install
+ Docker if you don't have it setup on your machine.
+
+- Since Docker runs a container in an environment that is isolated from the host
+ machine, a mounted directory in your host machine is used to share data, e.g.
+ the trainer configuration file, Unity executable, curriculum files and
+ TensorFlow graph. For convenience, we created an empty `unity-volume`
+ directory at the root of the repository for this purpose, but feel free to use
+ any other directory. The remainder of this guide assumes that the
+ `unity-volume` directory is the one used.
+
+## Usage
+
+Using Docker for ML-Agents involves three steps: building the Unity environment
+with specific flags, building a Docker container and, finally, running the
+container. If you are not familiar with building a Unity environment for
+ML-Agents, please read through our [Getting Started with the 3D Balance Ball
+Example](Getting-Started-with-Balance-Ball.md) guide first.
+
+### Build the Environment (Optional)
+
+_If you want to use the Editor to perform training, you can skip this step._
+
+Since Docker typically runs a container sharing a (linux) kernel with the host
+machine, the Unity environment **has** to be built for the **linux platform**.
+When building a Unity environment, please select the following options from the
+Build Settings window:
+
+- Set the _Target Platform_ to `Linux`
+- Set the _Architecture_ to `x86_64`
+- If the environment does not contain visual observations, you can select the
+ `headless` option here.
+
+Then click `Build`, pick an environment name (e.g. `3DBall`) and set the output
+directory to `unity-volume`. After building, ensure that the file
+`<environment-name>.x86_64` and subdirectory `<environment-name>_Data/` are
+created under `unity-volume`.
+
+![Build Settings For Docker](images/docker_build_settings.png)
+
+### Build the Docker Container
+
+First, make sure the Docker engine is running on your machine. Then build the
+Docker container by calling the following command at the top-level of the
+repository:
+
+```sh
+docker build -t <image-name> .
+```
+
+Replace `<image-name>` with a name for the Docker image, e.g.
+`balance.ball.v0.1`.
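+
+For example, using the image name above, the command would be:
+
+```sh
+docker build -t balance.ball.v0.1 .
+```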
+
+### Run the Docker Container
+
+Run the Docker container by calling the following command at the top-level of
+the repository:
+
+```sh
+docker run -it --name <container-name> \
+ --mount type=bind,source="$(pwd)"/unity-volume,target=/unity-volume \
+ -p 5005:5005 \
+ -p 6006:6006 \
+ <image-name>:latest \
+ --docker-target-name=unity-volume \
+ <trainer-config-file> \
+ --env=<environment-name> \
+ --train \
+ --run-id=<run-id>
+```
+
+Notes on argument values:
+
+- `<container-name>` is used to identify the container (in case you want to
+ interrupt and terminate it). This is optional and Docker will generate a
+ random name if this is not set. _Note that this must be unique for every run
+ of a Docker image._
+- `<image-name>` references the image name used when building the container.
+- `<environment-name>` __(Optional)__: If you are training with a linux
+ executable, this is the name of the executable. If you are training in the
+ Editor, do not pass a `<environment-name>` argument and press the
+ :arrow_forward: button in Unity when the message _"Start training by pressing
+ the Play button in the Unity Editor"_ is displayed on the screen.
+- `source`: Reference to the path in your host OS where you will store the Unity
+ executable.
+- `target`: Tells Docker to mount the `source` path as a disk with this name.
+- `docker-target-name`: Tells the ML-Agents Python package what the name of the
+ disk where it can read the Unity executable and store the graph. **This should
+ therefore be identical to `target`.**
+- `trainer-config-file`, `train`, `run-id`: ML-Agents arguments passed to
+ `mlagents-learn`. `trainer-config-file` is the filename of the trainer config
+ file, `train` trains the algorithm, and `run-id` is used to tag each
+ experiment with a unique identifier. We recommend placing the trainer-config
+ file inside `unity-volume` so that the container has access to the file.
+
+To train with a `3DBall` environment executable, the command would be:
+
+```sh
+docker run -it --name 3DBallContainer.first.trial \
+ --mount type=bind,source="$(pwd)"/unity-volume,target=/unity-volume \
+ -p 5005:5005 \
+ -p 6006:6006 \
+ balance.ball.v0.1:latest 3DBall \
+ --docker-target-name=unity-volume \
+ trainer_config.yaml \
+ --env=3DBall \
+ --train \
+ --run-id=3dball_first_trial
+```
+
+For more detail on Docker mounts, check out
+[these](https://docs.docker.com/storage/bind-mounts/) docs from Docker.
+
+**NOTE** If you are training using docker for environments that use visual observations, you may need to increase the default memory that Docker allocates for the container. For example, see [here](https://docs.docker.com/docker-for-mac/#advanced) for instructions for Docker for Mac.
+
+### Running Tensorboard
+
+You can run Tensorboard to monitor your training instance on http://localhost:6006:
+
+```sh
+docker exec -it <container-name> tensorboard --logdir=/unity-volume/summaries --host=0.0.0.0
+```
+
+With our previous 3DBall example, this command would look like this:
+```sh
+docker exec -it 3DBallContainer.first.trial tensorboard --logdir=/unity-volume/summaries --host=0.0.0.0
+```
+
+For more details on Tensorboard, check out the documentation about [Using Tensorboard](Using-Tensorboard.md).
+
+### Stopping Container and Saving State
+
+If you are satisfied with the training progress, you can stop the Docker
+container while saving state by either using `Ctrl+C` or `⌘+C` (Mac) or by using
+the following command:
+
+```sh
+docker kill --signal=SIGINT <container-name>
+```
+
+`<container-name>` is the name of the container specified in the earlier `docker
+run` command. If you didn't specify one, you can find the randomly generated
+identifier by running `docker container ls`.
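+
+Continuing the earlier `3DBall` example, this would be:
+
+```sh
+docker kill --signal=SIGINT 3DBallContainer.first.trial
+```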
diff --git a/docs/images/ray_perception.png b/docs/images/ray_perception.png
new file mode 100644
index 0000000000..6eef39dcd6
Binary files /dev/null and b/docs/images/ray_perception.png differ
diff --git a/gym-unity/gym_unity/__init__.py b/gym-unity/gym_unity/__init__.py
index e69de29bb2..ea370a8e55 100644
--- a/gym-unity/gym_unity/__init__.py
+++ b/gym-unity/gym_unity/__init__.py
@@ -0,0 +1 @@
+__version__ = "0.12.0"
diff --git a/gym-unity/gym_unity/envs/__init__.py b/gym-unity/gym_unity/envs/__init__.py
index 68e6206adf..33daeb3114 100644
--- a/gym-unity/gym_unity/envs/__init__.py
+++ b/gym-unity/gym_unity/envs/__init__.py
@@ -62,9 +62,8 @@ def __init__(
self._n_agents = None
self._multiagent = multiagent
self._flattener = None
- self.game_over = (
- False
- ) # Hidden flag used by Atari environments to determine if the game is over
+ # Hidden flag used by Atari environments to determine if the game is over
+ self.game_over = False
self._allow_multiple_visual_obs = allow_multiple_visual_obs
# Check brain configuration
@@ -103,12 +102,6 @@ def __init__(
"Otherwise, please note that only the first will be provided in the observation."
)
- if brain.num_stacked_vector_observations != 1:
- raise UnityGymException(
- "There can only be one stacked vector observation in a UnityEnvironment "
- "if it is wrapped in a gym."
- )
-
# Check for number of agents in scene.
initial_info = self._env.reset()[self.brain_name]
self._check_agents(len(initial_info.agents))
@@ -241,7 +234,7 @@ def _single_step(self, info):
default_observation,
info.rewards[0],
info.local_done[0],
- {"text_observation": info.text_observations[0], "brain_info": info},
+ {"text_observation": None, "brain_info": info},
)
def _preprocess_single(self, single_visual_obs):
@@ -260,7 +253,7 @@ def _multi_step(self, info):
list(default_observation),
info.rewards,
info.local_done,
- {"text_observation": info.text_observations, "brain_info": info},
+ {"text_observation": None, "brain_info": info},
)
def _preprocess_multi(self, multiple_visual_obs):
@@ -289,7 +282,7 @@ def seed(self, seed=None):
"""Sets the seed for this env's random number generator(s).
Currently not implemented.
"""
- logger.warn("Could not seed environment %s", self.name)
+ logger.warning("Could not seed environment %s", self.name)
return
def _check_agents(self, n_agents):
diff --git a/gym-unity/gym_unity/tests/test_gym.py b/gym-unity/gym_unity/tests/test_gym.py
index cbc35454c5..e77a8bee49 100644
--- a/gym-unity/gym_unity/tests/test_gym.py
+++ b/gym-unity/gym_unity/tests/test_gym.py
@@ -89,7 +89,6 @@ def test_gym_wrapper_visual(mock_env, use_uint8):
def create_mock_brainparams(
number_visual_observations=0,
- num_stacked_vector_observations=1,
vector_action_space_type="continuous",
vector_observation_space_size=3,
vector_action_space_size=None,
@@ -107,9 +106,7 @@ def create_mock_brainparams(
CameraResolution(width=8, height=8, num_channels=3)
for _ in range(number_visual_observations)
]
- mock_brain.return_value.num_stacked_vector_observations = (
- num_stacked_vector_observations
- )
+
mock_brain.return_value.vector_action_space_type = vector_action_space_type
mock_brain.return_value.vector_observation_space_size = (
vector_observation_space_size
@@ -131,7 +128,6 @@ def create_mock_vector_braininfo(num_agents=1, number_visual_observations=0):
mock_braininfo.return_value.visual_observations = [[np.zeros(shape=(8, 8, 3))]]
mock_braininfo.return_value.rewards = num_agents * [1.0]
mock_braininfo.return_value.local_done = num_agents * [False]
- mock_braininfo.return_value.text_observations = num_agents * [""]
mock_braininfo.return_value.agents = range(0, num_agents)
return mock_braininfo()
diff --git a/gym-unity/setup.py b/gym-unity/setup.py
index f724e0dae8..c78d541203 100755
--- a/gym-unity/setup.py
+++ b/gym-unity/setup.py
@@ -4,8 +4,9 @@
import sys
from setuptools import setup, find_packages
from setuptools.command.install import install
+import gym_unity
-VERSION = "0.11.0"
+VERSION = gym_unity.__version__
class VerifyVersionCommand(install):
diff --git a/markdown-link-check.fast.json b/markdown-link-check.fast.json
new file mode 100644
index 0000000000..3f16635c8a
--- /dev/null
+++ b/markdown-link-check.fast.json
@@ -0,0 +1,16 @@
+{
+ "ignorePatterns": [
+ {
+ "pattern": "^http://localhost",
+ "comment": "Ignore local tensorboard links"
+ },
+ {
+ "pattern": "^https://developer.nvidia.com/compute/machine-learning/cudnn/secure",
+ "comment": "Requires login"
+ },
+ {
+ "pattern": "^https?://",
+ "comment": "Skip external links for fast runs."
+ }
+ ]
+}
diff --git a/markdown-link-check.config.json b/markdown-link-check.full.json
similarity index 100%
rename from markdown-link-check.config.json
rename to markdown-link-check.full.json
diff --git a/ml-agents-envs/mlagents/envs/__init__.py b/ml-agents-envs/mlagents/envs/__init__.py
index e69de29bb2..ea370a8e55 100644
--- a/ml-agents-envs/mlagents/envs/__init__.py
+++ b/ml-agents-envs/mlagents/envs/__init__.py
@@ -0,0 +1 @@
+__version__ = "0.12.0"
diff --git a/ml-agents-envs/mlagents/envs/action_info.py b/ml-agents-envs/mlagents/envs/action_info.py
index f6bd4561fc..782223648f 100644
--- a/ml-agents-envs/mlagents/envs/action_info.py
+++ b/ml-agents-envs/mlagents/envs/action_info.py
@@ -5,7 +5,5 @@
class ActionInfo(NamedTuple):
action: Any
- memory: Any
- text: Any
value: Any
outputs: ActionInfoOutputs
diff --git a/ml-agents-envs/mlagents/envs/base_unity_environment.py b/ml-agents-envs/mlagents/envs/base_unity_environment.py
index b588f31cb9..1d4e68be0d 100644
--- a/ml-agents-envs/mlagents/envs/base_unity_environment.py
+++ b/ml-agents-envs/mlagents/envs/base_unity_environment.py
@@ -7,11 +7,7 @@
class BaseUnityEnvironment(ABC):
@abstractmethod
def step(
- self,
- vector_action: Optional[Dict] = None,
- memory: Optional[Dict] = None,
- text_action: Optional[Dict] = None,
- value: Optional[Dict] = None,
+ self, vector_action: Optional[Dict] = None, value: Optional[Dict] = None
) -> AllBrainInfo:
pass
diff --git a/ml-agents-envs/mlagents/envs/brain.py b/ml-agents-envs/mlagents/envs/brain.py
index 1c33f7ec4d..2e2c9bea8f 100644
--- a/ml-agents-envs/mlagents/envs/brain.py
+++ b/ml-agents-envs/mlagents/envs/brain.py
@@ -4,6 +4,7 @@
from mlagents.envs.communicator_objects.agent_info_pb2 import AgentInfoProto
from mlagents.envs.communicator_objects.brain_parameters_pb2 import BrainParametersProto
+from mlagents.envs.communicator_objects.observation_pb2 import ObservationProto
from mlagents.envs.timers import hierarchical_timer, timed
from typing import Dict, List, NamedTuple, Optional
from PIL import Image
@@ -20,13 +21,15 @@ class CameraResolution(NamedTuple):
def gray_scale(self) -> bool:
return self.num_channels == 1
+ def __str__(self):
+ return f"CameraResolution({self.height}, {self.width}, {self.num_channels})"
+
class BrainParameters:
def __init__(
self,
brain_name: str,
vector_observation_space_size: int,
- num_stacked_vector_observations: int,
camera_resolutions: List[CameraResolution],
vector_action_space_size: List[int],
vector_action_descriptions: List[str],
@@ -37,7 +40,6 @@ def __init__(
"""
self.brain_name = brain_name
self.vector_observation_space_size = vector_observation_space_size
- self.num_stacked_vector_observations = num_stacked_vector_observations
self.number_visual_observations = len(camera_resolutions)
self.camera_resolutions = camera_resolutions
self.vector_action_space_size = vector_action_space_size
@@ -49,15 +51,15 @@ def __init__(
def __str__(self):
return """Unity brain name: {}
Number of Visual Observations (per agent): {}
+ Camera Resolutions: {}
Vector Observation space size (per agent): {}
- Number of stacked Vector Observation: {}
Vector Action space type: {}
Vector Action space size (per agent): {}
Vector Action descriptions: {}""".format(
self.brain_name,
str(self.number_visual_observations),
+ str([str(cr) for cr in self.camera_resolutions]),
str(self.vector_observation_space_size),
- str(self.num_stacked_vector_observations),
self.vector_action_space_type,
str(self.vector_action_space_size),
", ".join(self.vector_action_descriptions),
@@ -73,18 +75,24 @@ def from_proto(
:return: BrainParameter object.
"""
resolutions = [
- CameraResolution(x.shape[0], x.shape[1], x.shape[2])
- for x in agent_info.compressed_observations
+ CameraResolution(obs.shape[0], obs.shape[1], obs.shape[2])
+ for obs in agent_info.observations
+ if len(obs.shape) >= 3
]
+ total_vector_obs = sum(
+ obs.shape[0] for obs in agent_info.observations if len(obs.shape) == 1
+ )
+
brain_params = BrainParameters(
- brain_param_proto.brain_name,
- brain_param_proto.vector_observation_size,
- brain_param_proto.num_stacked_vector_observations,
- resolutions,
- list(brain_param_proto.vector_action_size),
- list(brain_param_proto.vector_action_descriptions),
- brain_param_proto.vector_action_space_type,
+ brain_name=brain_param_proto.brain_name,
+ vector_observation_space_size=total_vector_obs,
+ camera_resolutions=resolutions,
+ vector_action_space_size=list(brain_param_proto.vector_action_size),
+ vector_action_descriptions=list(
+ brain_param_proto.vector_action_descriptions
+ ),
+ vector_action_space_type=brain_param_proto.vector_action_space_type,
)
return brain_params
@@ -94,59 +102,22 @@ def __init__(
self,
visual_observation,
vector_observation,
- text_observations,
- memory=None,
reward=None,
agents=None,
local_done=None,
- vector_action=None,
- text_action=None,
max_reached=None,
action_mask=None,
- custom_observations=None,
):
"""
Describes experience at current step of all agents linked to a brain.
"""
self.visual_observations = visual_observation
self.vector_observations = vector_observation
- self.text_observations = text_observations
- self.memories = memory
self.rewards = reward
self.local_done = local_done
self.max_reached = max_reached
self.agents = agents
- self.previous_vector_actions = vector_action
- self.previous_text_actions = text_action
self.action_masks = action_mask
- self.custom_observations = custom_observations
-
- def merge(self, other):
- for i in range(len(self.visual_observations)):
- self.visual_observations[i].extend(other.visual_observations[i])
- self.vector_observations = np.append(
- self.vector_observations, other.vector_observations, axis=0
- )
- self.text_observations.extend(other.text_observations)
- self.memories = self.merge_memories(
- self.memories, other.memories, self.agents, other.agents
- )
- self.rewards = safe_concat_lists(self.rewards, other.rewards)
- self.local_done = safe_concat_lists(self.local_done, other.local_done)
- self.max_reached = safe_concat_lists(self.max_reached, other.max_reached)
- self.agents = safe_concat_lists(self.agents, other.agents)
- self.previous_vector_actions = safe_concat_np_ndarray(
- self.previous_vector_actions, other.previous_vector_actions
- )
- self.previous_text_actions = safe_concat_lists(
- self.previous_text_actions, other.previous_text_actions
- )
- self.action_masks = safe_concat_np_ndarray(
- self.action_masks, other.action_masks
- )
- self.custom_observations = safe_concat_lists(
- self.custom_observations, other.custom_observations
- )
@staticmethod
def merge_memories(m1, m2, agents1, agents2):
@@ -194,28 +165,8 @@ def from_agent_proto(
"""
Converts list of agent infos to BrainInfo.
"""
- vis_obs: List[np.ndarray] = []
- for i in range(brain_params.number_visual_observations):
- obs = [
- BrainInfo.process_pixels(
- x.compressed_observations[i].data,
- brain_params.camera_resolutions[i].gray_scale,
- )
- for x in agent_info_list
- ]
- vis_obs += [obs]
- if len(agent_info_list) == 0:
- memory_size = 0
- else:
- memory_size = max(len(x.memories) for x in agent_info_list)
- if memory_size == 0:
- memory = np.zeros((0, 0))
- else:
- [
- x.memories.extend([0] * (memory_size - len(x.memories)))
- for x in agent_info_list
- ]
- memory = np.array([list(x.memories) for x in agent_info_list])
+ vis_obs = BrainInfo._process_visual_observations(brain_params, agent_info_list)
+
total_num_actions = sum(brain_params.vector_action_space_size)
mask_actions = np.ones((len(agent_info_list), total_num_actions))
for agent_index, agent_info in enumerate(agent_info_list):
@@ -230,20 +181,72 @@ def from_agent_proto(
"An agent had a NaN reward for brain " + brain_params.brain_name
)
- if len(agent_info_list) == 0:
- vector_obs = np.zeros(
- (
- 0,
- brain_params.vector_observation_space_size
- * brain_params.num_stacked_vector_observations,
+ vector_obs = BrainInfo._process_vector_observations(
+ brain_params, agent_info_list
+ )
+
+ agents = [f"${worker_id}-{x.id}" for x in agent_info_list]
+ brain_info = BrainInfo(
+ visual_observation=vis_obs,
+ vector_observation=vector_obs,
+ reward=[x.reward if not np.isnan(x.reward) else 0 for x in agent_info_list],
+ agents=agents,
+ local_done=[x.done for x in agent_info_list],
+ max_reached=[x.max_step_reached for x in agent_info_list],
+ action_mask=mask_actions,
+ )
+ return brain_info
+
+ @staticmethod
+ def _process_visual_observations(
+ brain_params: BrainParameters, agent_info_list: List[AgentInfoProto]
+ ) -> List[np.ndarray]:
+
+ visual_observation_protos: List[List[ObservationProto]] = []
+
+ # Grab the visual observations - need this together so we can iterate with the camera observations
+ for agent in agent_info_list:
+ agent_vis: List[ObservationProto] = []
+ for proto_obs in agent.observations:
+ is_visual = len(proto_obs.shape) == 3
+ if is_visual:
+ agent_vis.append(proto_obs)
+ visual_observation_protos.append(agent_vis)
+
+ vis_obs: List[np.ndarray] = []
+ for i in range(brain_params.number_visual_observations):
+ # TODO check compression type, handle uncompressed visuals
+ obs = [
+ BrainInfo.process_pixels(
+ agent_obs[i].compressed_data,
+ brain_params.camera_resolutions[i].gray_scale,
)
- )
+ for agent_obs in visual_observation_protos
+ ]
+ vis_obs += [obs]
+ return vis_obs
+
+ @staticmethod
+ def _process_vector_observations(
+ brain_params: BrainParameters, agent_info_list: List[AgentInfoProto]
+ ) -> np.ndarray:
+ if len(agent_info_list) == 0:
+ vector_obs = np.zeros((0, brain_params.vector_observation_space_size))
else:
stacked_obs = []
has_nan = False
has_inf = False
- for x in agent_info_list:
- np_obs = np.array(x.stacked_vector_observation)
+ for agent_info in agent_info_list:
+ vec_obs = [
+ obs for obs in agent_info.observations if len(obs.shape) == 1
+ ]
+ # Concatenate vector obs
+ proto_vector_obs: List[float] = []
+ for vo in vec_obs:
+ # TODO consider itertools.chain here
+ proto_vector_obs.extend(vo.float_data.data)
+ np_obs = np.array(proto_vector_obs)
+
# Check for NaNs or infs in the observations
# If there's a NaN in the observations, the dot() result will be NaN
# If there's an Inf (either sign) then the result will be Inf
@@ -264,23 +267,7 @@ def from_agent_proto(
logger.warning(
f"An agent had a NaN observation for brain {brain_params.brain_name}"
)
-
- agents = [f"${worker_id}-{x.id}" for x in agent_info_list]
- brain_info = BrainInfo(
- visual_observation=vis_obs,
- vector_observation=vector_obs,
- text_observations=[x.text_observation for x in agent_info_list],
- memory=memory,
- reward=[x.reward if not np.isnan(x.reward) else 0 for x in agent_info_list],
- agents=agents,
- local_done=[x.done for x in agent_info_list],
- vector_action=np.array([x.stored_vector_actions for x in agent_info_list]),
- text_action=[list(x.stored_text_actions) for x in agent_info_list],
- max_reached=[x.max_step_reached for x in agent_info_list],
- custom_observations=[x.custom_observation for x in agent_info_list],
- action_mask=mask_actions,
- )
- return brain_info
+ return vector_obs
def safe_concat_lists(l1: Optional[List], l2: Optional[List]) -> Optional[List]:
diff --git a/ml-agents-envs/mlagents/envs/communicator_objects/agent_action_pb2.py b/ml-agents-envs/mlagents/envs/communicator_objects/agent_action_pb2.py
index 9b2454e53d..b1eb87176b 100644
--- a/ml-agents-envs/mlagents/envs/communicator_objects/agent_action_pb2.py
+++ b/ml-agents-envs/mlagents/envs/communicator_objects/agent_action_pb2.py
@@ -13,16 +13,14 @@
_sym_db = _symbol_database.Default()
-from mlagents.envs.communicator_objects import custom_action_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_custom__action__pb2
DESCRIPTOR = _descriptor.FileDescriptor(
name='mlagents/envs/communicator_objects/agent_action.proto',
package='communicator_objects',
syntax='proto3',
- serialized_pb=_b('\n5mlagents/envs/communicator_objects/agent_action.proto\x12\x14\x63ommunicator_objects\x1a\x36mlagents/envs/communicator_objects/custom_action.proto\"\xa1\x01\n\x10\x41gentActionProto\x12\x16\n\x0evector_actions\x18\x01 \x03(\x02\x12\x14\n\x0ctext_actions\x18\x02 \x01(\t\x12\x10\n\x08memories\x18\x03 \x03(\x02\x12\r\n\x05value\x18\x04 \x01(\x02\x12>\n\rcustom_action\x18\x05 \x01(\x0b\x32\'.communicator_objects.CustomActionProtoB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3')
- ,
- dependencies=[mlagents_dot_envs_dot_communicator__objects_dot_custom__action__pb2.DESCRIPTOR,])
+ serialized_pb=_b('\n5mlagents/envs/communicator_objects/agent_action.proto\x12\x14\x63ommunicator_objects\"K\n\x10\x41gentActionProto\x12\x16\n\x0evector_actions\x18\x01 \x03(\x02\x12\r\n\x05value\x18\x04 \x01(\x02J\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04J\x04\x08\x05\x10\x06\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3')
+)
@@ -42,33 +40,12 @@
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
- name='text_actions', full_name='communicator_objects.AgentActionProto.text_actions', index=1,
- number=2, type=9, cpp_type=9, label=1,
- has_default_value=False, default_value=_b("").decode('utf-8'),
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='memories', full_name='communicator_objects.AgentActionProto.memories', index=2,
- number=3, type=2, cpp_type=6, label=3,
- has_default_value=False, default_value=[],
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='value', full_name='communicator_objects.AgentActionProto.value', index=3,
+ name='value', full_name='communicator_objects.AgentActionProto.value', index=1,
number=4, type=2, cpp_type=6, label=1,
has_default_value=False, default_value=float(0),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='custom_action', full_name='communicator_objects.AgentActionProto.custom_action', index=4,
- number=5, type=11, cpp_type=10, label=1,
- has_default_value=False, default_value=None,
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- options=None, file=DESCRIPTOR),
],
extensions=[
],
@@ -81,11 +58,10 @@
extension_ranges=[],
oneofs=[
],
- serialized_start=136,
- serialized_end=297,
+ serialized_start=79,
+ serialized_end=154,
)
-_AGENTACTIONPROTO.fields_by_name['custom_action'].message_type = mlagents_dot_envs_dot_communicator__objects_dot_custom__action__pb2._CUSTOMACTIONPROTO
DESCRIPTOR.message_types_by_name['AgentActionProto'] = _AGENTACTIONPROTO
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
diff --git a/ml-agents-envs/mlagents/envs/communicator_objects/agent_action_pb2.pyi b/ml-agents-envs/mlagents/envs/communicator_objects/agent_action_pb2.pyi
index d96652aee0..7b00efbf34 100644
--- a/ml-agents-envs/mlagents/envs/communicator_objects/agent_action_pb2.pyi
+++ b/ml-agents-envs/mlagents/envs/communicator_objects/agent_action_pb2.pyi
@@ -12,14 +12,9 @@ from google.protobuf.message import (
Message as google___protobuf___message___Message,
)
-from mlagents.envs.communicator_objects.custom_action_pb2 import (
- CustomActionProto as mlagents___envs___communicator_objects___custom_action_pb2___CustomActionProto,
-)
-
from typing import (
Iterable as typing___Iterable,
Optional as typing___Optional,
- Text as typing___Text,
)
from typing_extensions import (
@@ -36,28 +31,18 @@ builtin___int = int
class AgentActionProto(google___protobuf___message___Message):
DESCRIPTOR: google___protobuf___descriptor___Descriptor = ...
vector_actions = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[builtin___float]
- text_actions = ... # type: typing___Text
- memories = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[builtin___float]
value = ... # type: builtin___float
- @property
- def custom_action(self) -> mlagents___envs___communicator_objects___custom_action_pb2___CustomActionProto: ...
-
def __init__(self,
*,
vector_actions : typing___Optional[typing___Iterable[builtin___float]] = None,
- text_actions : typing___Optional[typing___Text] = None,
- memories : typing___Optional[typing___Iterable[builtin___float]] = None,
value : typing___Optional[builtin___float] = None,
- custom_action : typing___Optional[mlagents___envs___communicator_objects___custom_action_pb2___CustomActionProto] = None,
) -> None: ...
@classmethod
def FromString(cls, s: builtin___bytes) -> AgentActionProto: ...
def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
if sys.version_info >= (3,):
- def HasField(self, field_name: typing_extensions___Literal[u"custom_action"]) -> builtin___bool: ...
- def ClearField(self, field_name: typing_extensions___Literal[u"custom_action",u"memories",u"text_actions",u"value",u"vector_actions"]) -> None: ...
+ def ClearField(self, field_name: typing_extensions___Literal[u"value",u"vector_actions"]) -> None: ...
else:
- def HasField(self, field_name: typing_extensions___Literal[u"custom_action",b"custom_action"]) -> builtin___bool: ...
- def ClearField(self, field_name: typing_extensions___Literal[u"custom_action",b"custom_action",u"memories",b"memories",u"text_actions",b"text_actions",u"value",b"value",u"vector_actions",b"vector_actions"]) -> None: ...
+ def ClearField(self, field_name: typing_extensions___Literal[u"value",b"value",u"vector_actions",b"vector_actions"]) -> None: ...
diff --git a/ml-agents-envs/mlagents/envs/communicator_objects/agent_info_action_pair_pb2.py b/ml-agents-envs/mlagents/envs/communicator_objects/agent_info_action_pair_pb2.py
new file mode 100644
index 0000000000..45323335e7
--- /dev/null
+++ b/ml-agents-envs/mlagents/envs/communicator_objects/agent_info_action_pair_pb2.py
@@ -0,0 +1,83 @@
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# source: mlagents/envs/communicator_objects/agent_info_action_pair.proto
+
+import sys
+_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import message as _message
+from google.protobuf import reflection as _reflection
+from google.protobuf import symbol_database as _symbol_database
+from google.protobuf import descriptor_pb2
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+from mlagents.envs.communicator_objects import agent_info_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_agent__info__pb2
+from mlagents.envs.communicator_objects import agent_action_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_agent__action__pb2
+
+
+DESCRIPTOR = _descriptor.FileDescriptor(
+ name='mlagents/envs/communicator_objects/agent_info_action_pair.proto',
+ package='communicator_objects',
+ syntax='proto3',
+ serialized_pb=_b('\n?mlagents/envs/communicator_objects/agent_info_action_pair.proto\x12\x14\x63ommunicator_objects\x1a\x33mlagents/envs/communicator_objects/agent_info.proto\x1a\x35mlagents/envs/communicator_objects/agent_action.proto\"\x91\x01\n\x18\x41gentInfoActionPairProto\x12\x38\n\nagent_info\x18\x01 \x01(\x0b\x32$.communicator_objects.AgentInfoProto\x12;\n\x0b\x61\x63tion_info\x18\x02 \x01(\x0b\x32&.communicator_objects.AgentActionProtoB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3')
+ ,
+ dependencies=[mlagents_dot_envs_dot_communicator__objects_dot_agent__info__pb2.DESCRIPTOR,mlagents_dot_envs_dot_communicator__objects_dot_agent__action__pb2.DESCRIPTOR,])
+
+
+
+
+_AGENTINFOACTIONPAIRPROTO = _descriptor.Descriptor(
+ name='AgentInfoActionPairProto',
+ full_name='communicator_objects.AgentInfoActionPairProto',
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ _descriptor.FieldDescriptor(
+ name='agent_info', full_name='communicator_objects.AgentInfoActionPairProto.agent_info', index=0,
+ number=1, type=11, cpp_type=10, label=1,
+ has_default_value=False, default_value=None,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='action_info', full_name='communicator_objects.AgentInfoActionPairProto.action_info', index=1,
+ number=2, type=11, cpp_type=10, label=1,
+ has_default_value=False, default_value=None,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ options=None, file=DESCRIPTOR),
+ ],
+ extensions=[
+ ],
+ nested_types=[],
+ enum_types=[
+ ],
+ options=None,
+ is_extendable=False,
+ syntax='proto3',
+ extension_ranges=[],
+ oneofs=[
+ ],
+ serialized_start=198,
+ serialized_end=343,
+)
+
+_AGENTINFOACTIONPAIRPROTO.fields_by_name['agent_info'].message_type = mlagents_dot_envs_dot_communicator__objects_dot_agent__info__pb2._AGENTINFOPROTO
+_AGENTINFOACTIONPAIRPROTO.fields_by_name['action_info'].message_type = mlagents_dot_envs_dot_communicator__objects_dot_agent__action__pb2._AGENTACTIONPROTO
+DESCRIPTOR.message_types_by_name['AgentInfoActionPairProto'] = _AGENTINFOACTIONPAIRPROTO
+_sym_db.RegisterFileDescriptor(DESCRIPTOR)
+
+AgentInfoActionPairProto = _reflection.GeneratedProtocolMessageType('AgentInfoActionPairProto', (_message.Message,), dict(
+ DESCRIPTOR = _AGENTINFOACTIONPAIRPROTO,
+ __module__ = 'mlagents.envs.communicator_objects.agent_info_action_pair_pb2'
+ # @@protoc_insertion_point(class_scope:communicator_objects.AgentInfoActionPairProto)
+ ))
+_sym_db.RegisterMessage(AgentInfoActionPairProto)
+
+
+DESCRIPTOR.has_options = True
+DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects'))
+# @@protoc_insertion_point(module_scope)
diff --git a/ml-agents-envs/mlagents/envs/communicator_objects/agent_info_action_pair_pb2.pyi b/ml-agents-envs/mlagents/envs/communicator_objects/agent_info_action_pair_pb2.pyi
new file mode 100644
index 0000000000..4ebc76c4be
--- /dev/null
+++ b/ml-agents-envs/mlagents/envs/communicator_objects/agent_info_action_pair_pb2.pyi
@@ -0,0 +1,57 @@
+# @generated by generate_proto_mypy_stubs.py. Do not edit!
+import sys
+from google.protobuf.descriptor import (
+ Descriptor as google___protobuf___descriptor___Descriptor,
+)
+
+from google.protobuf.message import (
+ Message as google___protobuf___message___Message,
+)
+
+from mlagents.envs.communicator_objects.agent_action_pb2 import (
+ AgentActionProto as mlagents___envs___communicator_objects___agent_action_pb2___AgentActionProto,
+)
+
+from mlagents.envs.communicator_objects.agent_info_pb2 import (
+ AgentInfoProto as mlagents___envs___communicator_objects___agent_info_pb2___AgentInfoProto,
+)
+
+from typing import (
+ Optional as typing___Optional,
+)
+
+from typing_extensions import (
+ Literal as typing_extensions___Literal,
+)
+
+
+builtin___bool = bool
+builtin___bytes = bytes
+builtin___float = float
+builtin___int = int
+
+
+class AgentInfoActionPairProto(google___protobuf___message___Message):
+ DESCRIPTOR: google___protobuf___descriptor___Descriptor = ...
+
+ @property
+ def agent_info(self) -> mlagents___envs___communicator_objects___agent_info_pb2___AgentInfoProto: ...
+
+ @property
+ def action_info(self) -> mlagents___envs___communicator_objects___agent_action_pb2___AgentActionProto: ...
+
+ def __init__(self,
+ *,
+ agent_info : typing___Optional[mlagents___envs___communicator_objects___agent_info_pb2___AgentInfoProto] = None,
+ action_info : typing___Optional[mlagents___envs___communicator_objects___agent_action_pb2___AgentActionProto] = None,
+ ) -> None: ...
+ @classmethod
+ def FromString(cls, s: builtin___bytes) -> AgentInfoActionPairProto: ...
+ def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
+ def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
+ if sys.version_info >= (3,):
+ def HasField(self, field_name: typing_extensions___Literal[u"action_info",u"agent_info"]) -> builtin___bool: ...
+ def ClearField(self, field_name: typing_extensions___Literal[u"action_info",u"agent_info"]) -> None: ...
+ else:
+ def HasField(self, field_name: typing_extensions___Literal[u"action_info",b"action_info",u"agent_info",b"agent_info"]) -> builtin___bool: ...
+ def ClearField(self, field_name: typing_extensions___Literal[u"action_info",b"action_info",u"agent_info",b"agent_info"]) -> None: ...
diff --git a/ml-agents-envs/mlagents/envs/communicator_objects/agent_info_pb2.py b/ml-agents-envs/mlagents/envs/communicator_objects/agent_info_pb2.py
index 8818a369fd..170ef4c973 100644
--- a/ml-agents-envs/mlagents/envs/communicator_objects/agent_info_pb2.py
+++ b/ml-agents-envs/mlagents/envs/communicator_objects/agent_info_pb2.py
@@ -13,17 +13,16 @@
_sym_db = _symbol_database.Default()
-from mlagents.envs.communicator_objects import compressed_observation_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_compressed__observation__pb2
-from mlagents.envs.communicator_objects import custom_observation_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_custom__observation__pb2
+from mlagents.envs.communicator_objects import observation_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_observation__pb2
DESCRIPTOR = _descriptor.FileDescriptor(
name='mlagents/envs/communicator_objects/agent_info.proto',
package='communicator_objects',
syntax='proto3',
- serialized_pb=_b('\n3mlagents/envs/communicator_objects/agent_info.proto\x12\x14\x63ommunicator_objects\x1a?mlagents/envs/communicator_objects/compressed_observation.proto\x1a;mlagents/envs/communicator_objects/custom_observation.proto\"\x98\x03\n\x0e\x41gentInfoProto\x12\"\n\x1astacked_vector_observation\x18\x01 \x03(\x02\x12\x18\n\x10text_observation\x18\x03 \x01(\t\x12\x1d\n\x15stored_vector_actions\x18\x04 \x03(\x02\x12\x1b\n\x13stored_text_actions\x18\x05 \x01(\t\x12\x10\n\x08memories\x18\x06 \x03(\x02\x12\x0e\n\x06reward\x18\x07 \x01(\x02\x12\x0c\n\x04\x64one\x18\x08 \x01(\x08\x12\x18\n\x10max_step_reached\x18\t \x01(\x08\x12\n\n\x02id\x18\n \x01(\x05\x12\x13\n\x0b\x61\x63tion_mask\x18\x0b \x03(\x08\x12H\n\x12\x63ustom_observation\x18\x0c \x01(\x0b\x32,.communicator_objects.CustomObservationProto\x12Q\n\x17\x63ompressed_observations\x18\r \x03(\x0b\x32\x30.communicator_objects.CompressedObservationProtoJ\x04\x08\x02\x10\x03\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3')
+ serialized_pb=_b('\n3mlagents/envs/communicator_objects/agent_info.proto\x12\x14\x63ommunicator_objects\x1a\x34mlagents/envs/communicator_objects/observation.proto\"\xd1\x01\n\x0e\x41gentInfoProto\x12\x0e\n\x06reward\x18\x07 \x01(\x02\x12\x0c\n\x04\x64one\x18\x08 \x01(\x08\x12\x18\n\x10max_step_reached\x18\t \x01(\x08\x12\n\n\x02id\x18\n \x01(\x05\x12\x13\n\x0b\x61\x63tion_mask\x18\x0b \x03(\x08\x12<\n\x0cobservations\x18\r \x03(\x0b\x32&.communicator_objects.ObservationProtoJ\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04J\x04\x08\x04\x10\x05J\x04\x08\x05\x10\x06J\x04\x08\x06\x10\x07J\x04\x08\x0c\x10\rB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3')
,
- dependencies=[mlagents_dot_envs_dot_communicator__objects_dot_compressed__observation__pb2.DESCRIPTOR,mlagents_dot_envs_dot_communicator__objects_dot_custom__observation__pb2.DESCRIPTOR,])
+ dependencies=[mlagents_dot_envs_dot_communicator__objects_dot_observation__pb2.DESCRIPTOR,])
@@ -36,84 +35,42 @@
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
- name='stacked_vector_observation', full_name='communicator_objects.AgentInfoProto.stacked_vector_observation', index=0,
- number=1, type=2, cpp_type=6, label=3,
- has_default_value=False, default_value=[],
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='text_observation', full_name='communicator_objects.AgentInfoProto.text_observation', index=1,
- number=3, type=9, cpp_type=9, label=1,
- has_default_value=False, default_value=_b("").decode('utf-8'),
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='stored_vector_actions', full_name='communicator_objects.AgentInfoProto.stored_vector_actions', index=2,
- number=4, type=2, cpp_type=6, label=3,
- has_default_value=False, default_value=[],
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='stored_text_actions', full_name='communicator_objects.AgentInfoProto.stored_text_actions', index=3,
- number=5, type=9, cpp_type=9, label=1,
- has_default_value=False, default_value=_b("").decode('utf-8'),
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='memories', full_name='communicator_objects.AgentInfoProto.memories', index=4,
- number=6, type=2, cpp_type=6, label=3,
- has_default_value=False, default_value=[],
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='reward', full_name='communicator_objects.AgentInfoProto.reward', index=5,
+ name='reward', full_name='communicator_objects.AgentInfoProto.reward', index=0,
number=7, type=2, cpp_type=6, label=1,
has_default_value=False, default_value=float(0),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
- name='done', full_name='communicator_objects.AgentInfoProto.done', index=6,
+ name='done', full_name='communicator_objects.AgentInfoProto.done', index=1,
number=8, type=8, cpp_type=7, label=1,
has_default_value=False, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
- name='max_step_reached', full_name='communicator_objects.AgentInfoProto.max_step_reached', index=7,
+ name='max_step_reached', full_name='communicator_objects.AgentInfoProto.max_step_reached', index=2,
number=9, type=8, cpp_type=7, label=1,
has_default_value=False, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
- name='id', full_name='communicator_objects.AgentInfoProto.id', index=8,
+ name='id', full_name='communicator_objects.AgentInfoProto.id', index=3,
number=10, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
- name='action_mask', full_name='communicator_objects.AgentInfoProto.action_mask', index=9,
+ name='action_mask', full_name='communicator_objects.AgentInfoProto.action_mask', index=4,
number=11, type=8, cpp_type=7, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
- name='custom_observation', full_name='communicator_objects.AgentInfoProto.custom_observation', index=10,
- number=12, type=11, cpp_type=10, label=1,
- has_default_value=False, default_value=None,
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='compressed_observations', full_name='communicator_objects.AgentInfoProto.compressed_observations', index=11,
+ name='observations', full_name='communicator_objects.AgentInfoProto.observations', index=5,
number=13, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
@@ -131,12 +88,11 @@
extension_ranges=[],
oneofs=[
],
- serialized_start=204,
- serialized_end=612,
+ serialized_start=132,
+ serialized_end=341,
)
-_AGENTINFOPROTO.fields_by_name['custom_observation'].message_type = mlagents_dot_envs_dot_communicator__objects_dot_custom__observation__pb2._CUSTOMOBSERVATIONPROTO
-_AGENTINFOPROTO.fields_by_name['compressed_observations'].message_type = mlagents_dot_envs_dot_communicator__objects_dot_compressed__observation__pb2._COMPRESSEDOBSERVATIONPROTO
+_AGENTINFOPROTO.fields_by_name['observations'].message_type = mlagents_dot_envs_dot_communicator__objects_dot_observation__pb2._OBSERVATIONPROTO
DESCRIPTOR.message_types_by_name['AgentInfoProto'] = _AGENTINFOPROTO
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
diff --git a/ml-agents-envs/mlagents/envs/communicator_objects/agent_info_pb2.pyi b/ml-agents-envs/mlagents/envs/communicator_objects/agent_info_pb2.pyi
index 4aaa82c3a7..6c70bf2054 100644
--- a/ml-agents-envs/mlagents/envs/communicator_objects/agent_info_pb2.pyi
+++ b/ml-agents-envs/mlagents/envs/communicator_objects/agent_info_pb2.pyi
@@ -13,18 +13,13 @@ from google.protobuf.message import (
Message as google___protobuf___message___Message,
)
-from mlagents.envs.communicator_objects.compressed_observation_pb2 import (
- CompressedObservationProto as mlagents___envs___communicator_objects___compressed_observation_pb2___CompressedObservationProto,
-)
-
-from mlagents.envs.communicator_objects.custom_observation_pb2 import (
- CustomObservationProto as mlagents___envs___communicator_objects___custom_observation_pb2___CustomObservationProto,
+from mlagents.envs.communicator_objects.observation_pb2 import (
+ ObservationProto as mlagents___envs___communicator_objects___observation_pb2___ObservationProto,
)
from typing import (
Iterable as typing___Iterable,
Optional as typing___Optional,
- Text as typing___Text,
)
from typing_extensions import (
@@ -40,11 +35,6 @@ builtin___int = int
class AgentInfoProto(google___protobuf___message___Message):
DESCRIPTOR: google___protobuf___descriptor___Descriptor = ...
- stacked_vector_observation = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[builtin___float]
- text_observation = ... # type: typing___Text
- stored_vector_actions = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[builtin___float]
- stored_text_actions = ... # type: typing___Text
- memories = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[builtin___float]
reward = ... # type: builtin___float
done = ... # type: builtin___bool
max_step_reached = ... # type: builtin___bool
@@ -52,33 +42,22 @@ class AgentInfoProto(google___protobuf___message___Message):
action_mask = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[builtin___bool]
@property
- def custom_observation(self) -> mlagents___envs___communicator_objects___custom_observation_pb2___CustomObservationProto: ...
-
- @property
- def compressed_observations(self) -> google___protobuf___internal___containers___RepeatedCompositeFieldContainer[mlagents___envs___communicator_objects___compressed_observation_pb2___CompressedObservationProto]: ...
+ def observations(self) -> google___protobuf___internal___containers___RepeatedCompositeFieldContainer[mlagents___envs___communicator_objects___observation_pb2___ObservationProto]: ...
def __init__(self,
*,
- stacked_vector_observation : typing___Optional[typing___Iterable[builtin___float]] = None,
- text_observation : typing___Optional[typing___Text] = None,
- stored_vector_actions : typing___Optional[typing___Iterable[builtin___float]] = None,
- stored_text_actions : typing___Optional[typing___Text] = None,
- memories : typing___Optional[typing___Iterable[builtin___float]] = None,
reward : typing___Optional[builtin___float] = None,
done : typing___Optional[builtin___bool] = None,
max_step_reached : typing___Optional[builtin___bool] = None,
id : typing___Optional[builtin___int] = None,
action_mask : typing___Optional[typing___Iterable[builtin___bool]] = None,
- custom_observation : typing___Optional[mlagents___envs___communicator_objects___custom_observation_pb2___CustomObservationProto] = None,
- compressed_observations : typing___Optional[typing___Iterable[mlagents___envs___communicator_objects___compressed_observation_pb2___CompressedObservationProto]] = None,
+ observations : typing___Optional[typing___Iterable[mlagents___envs___communicator_objects___observation_pb2___ObservationProto]] = None,
) -> None: ...
@classmethod
def FromString(cls, s: builtin___bytes) -> AgentInfoProto: ...
def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
if sys.version_info >= (3,):
- def HasField(self, field_name: typing_extensions___Literal[u"custom_observation"]) -> builtin___bool: ...
- def ClearField(self, field_name: typing_extensions___Literal[u"action_mask",u"compressed_observations",u"custom_observation",u"done",u"id",u"max_step_reached",u"memories",u"reward",u"stacked_vector_observation",u"stored_text_actions",u"stored_vector_actions",u"text_observation"]) -> None: ...
+ def ClearField(self, field_name: typing_extensions___Literal[u"action_mask",u"done",u"id",u"max_step_reached",u"observations",u"reward"]) -> None: ...
else:
- def HasField(self, field_name: typing_extensions___Literal[u"custom_observation",b"custom_observation"]) -> builtin___bool: ...
- def ClearField(self, field_name: typing_extensions___Literal[u"action_mask",b"action_mask",u"compressed_observations",b"compressed_observations",u"custom_observation",b"custom_observation",u"done",b"done",u"id",b"id",u"max_step_reached",b"max_step_reached",u"memories",b"memories",u"reward",b"reward",u"stacked_vector_observation",b"stacked_vector_observation",u"stored_text_actions",b"stored_text_actions",u"stored_vector_actions",b"stored_vector_actions",u"text_observation",b"text_observation"]) -> None: ...
+ def ClearField(self, field_name: typing_extensions___Literal[u"action_mask",b"action_mask",u"done",b"done",u"id",b"id",u"max_step_reached",b"max_step_reached",u"observations",b"observations",u"reward",b"reward"]) -> None: ...
diff --git a/ml-agents-envs/mlagents/envs/communicator_objects/brain_parameters_pb2.py b/ml-agents-envs/mlagents/envs/communicator_objects/brain_parameters_pb2.py
index 31007b6869..97b4a08645 100644
--- a/ml-agents-envs/mlagents/envs/communicator_objects/brain_parameters_pb2.py
+++ b/ml-agents-envs/mlagents/envs/communicator_objects/brain_parameters_pb2.py
@@ -20,7 +20,7 @@
name='mlagents/envs/communicator_objects/brain_parameters.proto',
package='communicator_objects',
syntax='proto3',
- serialized_pb=_b('\n9mlagents/envs/communicator_objects/brain_parameters.proto\x12\x14\x63ommunicator_objects\x1a\x33mlagents/envs/communicator_objects/space_type.proto\"\x97\x02\n\x14\x42rainParametersProto\x12\x1f\n\x17vector_observation_size\x18\x01 \x01(\x05\x12\'\n\x1fnum_stacked_vector_observations\x18\x02 \x01(\x05\x12\x1a\n\x12vector_action_size\x18\x03 \x03(\x05\x12\"\n\x1avector_action_descriptions\x18\x05 \x03(\t\x12\x46\n\x18vector_action_space_type\x18\x06 \x01(\x0e\x32$.communicator_objects.SpaceTypeProto\x12\x12\n\nbrain_name\x18\x07 \x01(\t\x12\x13\n\x0bis_training\x18\x08 \x01(\x08J\x04\x08\x04\x10\x05\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3')
+ serialized_pb=_b('\n9mlagents/envs/communicator_objects/brain_parameters.proto\x12\x14\x63ommunicator_objects\x1a\x33mlagents/envs/communicator_objects/space_type.proto\"\xd9\x01\n\x14\x42rainParametersProto\x12\x1a\n\x12vector_action_size\x18\x03 \x03(\x05\x12\"\n\x1avector_action_descriptions\x18\x05 \x03(\t\x12\x46\n\x18vector_action_space_type\x18\x06 \x01(\x0e\x32$.communicator_objects.SpaceTypeProto\x12\x12\n\nbrain_name\x18\x07 \x01(\t\x12\x13\n\x0bis_training\x18\x08 \x01(\x08J\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03J\x04\x08\x04\x10\x05\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3')
,
dependencies=[mlagents_dot_envs_dot_communicator__objects_dot_space__type__pb2.DESCRIPTOR,])
@@ -35,49 +35,35 @@
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
- name='vector_observation_size', full_name='communicator_objects.BrainParametersProto.vector_observation_size', index=0,
- number=1, type=5, cpp_type=1, label=1,
- has_default_value=False, default_value=0,
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='num_stacked_vector_observations', full_name='communicator_objects.BrainParametersProto.num_stacked_vector_observations', index=1,
- number=2, type=5, cpp_type=1, label=1,
- has_default_value=False, default_value=0,
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='vector_action_size', full_name='communicator_objects.BrainParametersProto.vector_action_size', index=2,
+ name='vector_action_size', full_name='communicator_objects.BrainParametersProto.vector_action_size', index=0,
number=3, type=5, cpp_type=1, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
- name='vector_action_descriptions', full_name='communicator_objects.BrainParametersProto.vector_action_descriptions', index=3,
+ name='vector_action_descriptions', full_name='communicator_objects.BrainParametersProto.vector_action_descriptions', index=1,
number=5, type=9, cpp_type=9, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
- name='vector_action_space_type', full_name='communicator_objects.BrainParametersProto.vector_action_space_type', index=4,
+ name='vector_action_space_type', full_name='communicator_objects.BrainParametersProto.vector_action_space_type', index=2,
number=6, type=14, cpp_type=8, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
- name='brain_name', full_name='communicator_objects.BrainParametersProto.brain_name', index=5,
+ name='brain_name', full_name='communicator_objects.BrainParametersProto.brain_name', index=3,
number=7, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
- name='is_training', full_name='communicator_objects.BrainParametersProto.is_training', index=6,
+ name='is_training', full_name='communicator_objects.BrainParametersProto.is_training', index=4,
number=8, type=8, cpp_type=7, label=1,
has_default_value=False, default_value=False,
message_type=None, enum_type=None, containing_type=None,
@@ -96,7 +82,7 @@
oneofs=[
],
serialized_start=137,
- serialized_end=416,
+ serialized_end=354,
)
_BRAINPARAMETERSPROTO.fields_by_name['vector_action_space_type'].enum_type = mlagents_dot_envs_dot_communicator__objects_dot_space__type__pb2._SPACETYPEPROTO
diff --git a/ml-agents-envs/mlagents/envs/communicator_objects/brain_parameters_pb2.pyi b/ml-agents-envs/mlagents/envs/communicator_objects/brain_parameters_pb2.pyi
index f3d1c30597..0ac7699a0a 100644
--- a/ml-agents-envs/mlagents/envs/communicator_objects/brain_parameters_pb2.pyi
+++ b/ml-agents-envs/mlagents/envs/communicator_objects/brain_parameters_pb2.pyi
@@ -35,8 +35,6 @@ builtin___int = int
class BrainParametersProto(google___protobuf___message___Message):
DESCRIPTOR: google___protobuf___descriptor___Descriptor = ...
- vector_observation_size = ... # type: builtin___int
- num_stacked_vector_observations = ... # type: builtin___int
vector_action_size = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[builtin___int]
vector_action_descriptions = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[typing___Text]
vector_action_space_type = ... # type: mlagents___envs___communicator_objects___space_type_pb2___SpaceTypeProto
@@ -45,8 +43,6 @@ class BrainParametersProto(google___protobuf___message___Message):
def __init__(self,
*,
- vector_observation_size : typing___Optional[builtin___int] = None,
- num_stacked_vector_observations : typing___Optional[builtin___int] = None,
vector_action_size : typing___Optional[typing___Iterable[builtin___int]] = None,
vector_action_descriptions : typing___Optional[typing___Iterable[typing___Text]] = None,
vector_action_space_type : typing___Optional[mlagents___envs___communicator_objects___space_type_pb2___SpaceTypeProto] = None,
@@ -58,6 +54,6 @@ class BrainParametersProto(google___protobuf___message___Message):
def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
if sys.version_info >= (3,):
- def ClearField(self, field_name: typing_extensions___Literal[u"brain_name",u"is_training",u"num_stacked_vector_observations",u"vector_action_descriptions",u"vector_action_size",u"vector_action_space_type",u"vector_observation_size"]) -> None: ...
+ def ClearField(self, field_name: typing_extensions___Literal[u"brain_name",u"is_training",u"vector_action_descriptions",u"vector_action_size",u"vector_action_space_type"]) -> None: ...
else:
- def ClearField(self, field_name: typing_extensions___Literal[u"brain_name",b"brain_name",u"is_training",b"is_training",u"num_stacked_vector_observations",b"num_stacked_vector_observations",u"vector_action_descriptions",b"vector_action_descriptions",u"vector_action_size",b"vector_action_size",u"vector_action_space_type",b"vector_action_space_type",u"vector_observation_size",b"vector_observation_size"]) -> None: ...
+ def ClearField(self, field_name: typing_extensions___Literal[u"brain_name",b"brain_name",u"is_training",b"is_training",u"vector_action_descriptions",b"vector_action_descriptions",u"vector_action_size",b"vector_action_size",u"vector_action_space_type",b"vector_action_space_type"]) -> None: ...
diff --git a/ml-agents-envs/mlagents/envs/communicator_objects/compressed_observation_pb2.py b/ml-agents-envs/mlagents/envs/communicator_objects/compressed_observation_pb2.py
deleted file mode 100644
index 5ff3611f53..0000000000
--- a/ml-agents-envs/mlagents/envs/communicator_objects/compressed_observation_pb2.py
+++ /dev/null
@@ -1,113 +0,0 @@
-# Generated by the protocol buffer compiler. DO NOT EDIT!
-# source: mlagents/envs/communicator_objects/compressed_observation.proto
-
-import sys
-_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
-from google.protobuf.internal import enum_type_wrapper
-from google.protobuf import descriptor as _descriptor
-from google.protobuf import message as _message
-from google.protobuf import reflection as _reflection
-from google.protobuf import symbol_database as _symbol_database
-from google.protobuf import descriptor_pb2
-# @@protoc_insertion_point(imports)
-
-_sym_db = _symbol_database.Default()
-
-
-
-
-DESCRIPTOR = _descriptor.FileDescriptor(
- name='mlagents/envs/communicator_objects/compressed_observation.proto',
- package='communicator_objects',
- syntax='proto3',
- serialized_pb=_b('\n?mlagents/envs/communicator_objects/compressed_observation.proto\x12\x14\x63ommunicator_objects\"\x7f\n\x1a\x43ompressedObservationProto\x12\r\n\x05shape\x18\x01 \x03(\x05\x12\x44\n\x10\x63ompression_type\x18\x02 \x01(\x0e\x32*.communicator_objects.CompressionTypeProto\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c*)\n\x14\x43ompressionTypeProto\x12\x08\n\x04NONE\x10\x00\x12\x07\n\x03PNG\x10\x01\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3')
-)
-
-_COMPRESSIONTYPEPROTO = _descriptor.EnumDescriptor(
- name='CompressionTypeProto',
- full_name='communicator_objects.CompressionTypeProto',
- filename=None,
- file=DESCRIPTOR,
- values=[
- _descriptor.EnumValueDescriptor(
- name='NONE', index=0, number=0,
- options=None,
- type=None),
- _descriptor.EnumValueDescriptor(
- name='PNG', index=1, number=1,
- options=None,
- type=None),
- ],
- containing_type=None,
- options=None,
- serialized_start=218,
- serialized_end=259,
-)
-_sym_db.RegisterEnumDescriptor(_COMPRESSIONTYPEPROTO)
-
-CompressionTypeProto = enum_type_wrapper.EnumTypeWrapper(_COMPRESSIONTYPEPROTO)
-NONE = 0
-PNG = 1
-
-
-
-_COMPRESSEDOBSERVATIONPROTO = _descriptor.Descriptor(
- name='CompressedObservationProto',
- full_name='communicator_objects.CompressedObservationProto',
- filename=None,
- file=DESCRIPTOR,
- containing_type=None,
- fields=[
- _descriptor.FieldDescriptor(
- name='shape', full_name='communicator_objects.CompressedObservationProto.shape', index=0,
- number=1, type=5, cpp_type=1, label=3,
- has_default_value=False, default_value=[],
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='compression_type', full_name='communicator_objects.CompressedObservationProto.compression_type', index=1,
- number=2, type=14, cpp_type=8, label=1,
- has_default_value=False, default_value=0,
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='data', full_name='communicator_objects.CompressedObservationProto.data', index=2,
- number=3, type=12, cpp_type=9, label=1,
- has_default_value=False, default_value=_b(""),
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- options=None, file=DESCRIPTOR),
- ],
- extensions=[
- ],
- nested_types=[],
- enum_types=[
- ],
- options=None,
- is_extendable=False,
- syntax='proto3',
- extension_ranges=[],
- oneofs=[
- ],
- serialized_start=89,
- serialized_end=216,
-)
-
-_COMPRESSEDOBSERVATIONPROTO.fields_by_name['compression_type'].enum_type = _COMPRESSIONTYPEPROTO
-DESCRIPTOR.message_types_by_name['CompressedObservationProto'] = _COMPRESSEDOBSERVATIONPROTO
-DESCRIPTOR.enum_types_by_name['CompressionTypeProto'] = _COMPRESSIONTYPEPROTO
-_sym_db.RegisterFileDescriptor(DESCRIPTOR)
-
-CompressedObservationProto = _reflection.GeneratedProtocolMessageType('CompressedObservationProto', (_message.Message,), dict(
- DESCRIPTOR = _COMPRESSEDOBSERVATIONPROTO,
- __module__ = 'mlagents.envs.communicator_objects.compressed_observation_pb2'
- # @@protoc_insertion_point(class_scope:communicator_objects.CompressedObservationProto)
- ))
-_sym_db.RegisterMessage(CompressedObservationProto)
-
-
-DESCRIPTOR.has_options = True
-DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects'))
-# @@protoc_insertion_point(module_scope)
diff --git a/ml-agents-envs/mlagents/envs/communicator_objects/custom_action_pb2.py b/ml-agents-envs/mlagents/envs/communicator_objects/custom_action_pb2.py
deleted file mode 100644
index ecead71d76..0000000000
--- a/ml-agents-envs/mlagents/envs/communicator_objects/custom_action_pb2.py
+++ /dev/null
@@ -1,64 +0,0 @@
-# Generated by the protocol buffer compiler. DO NOT EDIT!
-# source: mlagents/envs/communicator_objects/custom_action.proto
-
-import sys
-_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
-from google.protobuf import descriptor as _descriptor
-from google.protobuf import message as _message
-from google.protobuf import reflection as _reflection
-from google.protobuf import symbol_database as _symbol_database
-from google.protobuf import descriptor_pb2
-# @@protoc_insertion_point(imports)
-
-_sym_db = _symbol_database.Default()
-
-
-
-
-DESCRIPTOR = _descriptor.FileDescriptor(
- name='mlagents/envs/communicator_objects/custom_action.proto',
- package='communicator_objects',
- syntax='proto3',
- serialized_pb=_b('\n6mlagents/envs/communicator_objects/custom_action.proto\x12\x14\x63ommunicator_objects\"\x13\n\x11\x43ustomActionProtoB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3')
-)
-
-
-
-
-_CUSTOMACTIONPROTO = _descriptor.Descriptor(
- name='CustomActionProto',
- full_name='communicator_objects.CustomActionProto',
- filename=None,
- file=DESCRIPTOR,
- containing_type=None,
- fields=[
- ],
- extensions=[
- ],
- nested_types=[],
- enum_types=[
- ],
- options=None,
- is_extendable=False,
- syntax='proto3',
- extension_ranges=[],
- oneofs=[
- ],
- serialized_start=80,
- serialized_end=99,
-)
-
-DESCRIPTOR.message_types_by_name['CustomActionProto'] = _CUSTOMACTIONPROTO
-_sym_db.RegisterFileDescriptor(DESCRIPTOR)
-
-CustomActionProto = _reflection.GeneratedProtocolMessageType('CustomActionProto', (_message.Message,), dict(
- DESCRIPTOR = _CUSTOMACTIONPROTO,
- __module__ = 'mlagents.envs.communicator_objects.custom_action_pb2'
- # @@protoc_insertion_point(class_scope:communicator_objects.CustomActionProto)
- ))
-_sym_db.RegisterMessage(CustomActionProto)
-
-
-DESCRIPTOR.has_options = True
-DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects'))
-# @@protoc_insertion_point(module_scope)
diff --git a/ml-agents-envs/mlagents/envs/communicator_objects/custom_action_pb2.pyi b/ml-agents-envs/mlagents/envs/communicator_objects/custom_action_pb2.pyi
deleted file mode 100644
index 2d834a3133..0000000000
--- a/ml-agents-envs/mlagents/envs/communicator_objects/custom_action_pb2.pyi
+++ /dev/null
@@ -1,23 +0,0 @@
-# @generated by generate_proto_mypy_stubs.py. Do not edit!
-import sys
-from google.protobuf.descriptor import (
- Descriptor as google___protobuf___descriptor___Descriptor,
-)
-
-from google.protobuf.message import (
- Message as google___protobuf___message___Message,
-)
-
-
-builtin___bytes = bytes
-
-
-class CustomActionProto(google___protobuf___message___Message):
- DESCRIPTOR: google___protobuf___descriptor___Descriptor = ...
-
- def __init__(self,
- ) -> None: ...
- @classmethod
- def FromString(cls, s: builtin___bytes) -> CustomActionProto: ...
- def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
- def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
diff --git a/ml-agents-envs/mlagents/envs/communicator_objects/custom_observation_pb2.py b/ml-agents-envs/mlagents/envs/communicator_objects/custom_observation_pb2.py
deleted file mode 100644
index d0f89db251..0000000000
--- a/ml-agents-envs/mlagents/envs/communicator_objects/custom_observation_pb2.py
+++ /dev/null
@@ -1,64 +0,0 @@
-# Generated by the protocol buffer compiler. DO NOT EDIT!
-# source: mlagents/envs/communicator_objects/custom_observation.proto
-
-import sys
-_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
-from google.protobuf import descriptor as _descriptor
-from google.protobuf import message as _message
-from google.protobuf import reflection as _reflection
-from google.protobuf import symbol_database as _symbol_database
-from google.protobuf import descriptor_pb2
-# @@protoc_insertion_point(imports)
-
-_sym_db = _symbol_database.Default()
-
-
-
-
-DESCRIPTOR = _descriptor.FileDescriptor(
- name='mlagents/envs/communicator_objects/custom_observation.proto',
- package='communicator_objects',
- syntax='proto3',
- serialized_pb=_b('\n;mlagents/envs/communicator_objects/custom_observation.proto\x12\x14\x63ommunicator_objects\"\x18\n\x16\x43ustomObservationProtoB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3')
-)
-
-
-
-
-_CUSTOMOBSERVATIONPROTO = _descriptor.Descriptor(
- name='CustomObservationProto',
- full_name='communicator_objects.CustomObservationProto',
- filename=None,
- file=DESCRIPTOR,
- containing_type=None,
- fields=[
- ],
- extensions=[
- ],
- nested_types=[],
- enum_types=[
- ],
- options=None,
- is_extendable=False,
- syntax='proto3',
- extension_ranges=[],
- oneofs=[
- ],
- serialized_start=85,
- serialized_end=109,
-)
-
-DESCRIPTOR.message_types_by_name['CustomObservationProto'] = _CUSTOMOBSERVATIONPROTO
-_sym_db.RegisterFileDescriptor(DESCRIPTOR)
-
-CustomObservationProto = _reflection.GeneratedProtocolMessageType('CustomObservationProto', (_message.Message,), dict(
- DESCRIPTOR = _CUSTOMOBSERVATIONPROTO,
- __module__ = 'mlagents.envs.communicator_objects.custom_observation_pb2'
- # @@protoc_insertion_point(class_scope:communicator_objects.CustomObservationProto)
- ))
-_sym_db.RegisterMessage(CustomObservationProto)
-
-
-DESCRIPTOR.has_options = True
-DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects'))
-# @@protoc_insertion_point(module_scope)
diff --git a/ml-agents-envs/mlagents/envs/communicator_objects/custom_observation_pb2.pyi b/ml-agents-envs/mlagents/envs/communicator_objects/custom_observation_pb2.pyi
deleted file mode 100644
index 3e5f324325..0000000000
--- a/ml-agents-envs/mlagents/envs/communicator_objects/custom_observation_pb2.pyi
+++ /dev/null
@@ -1,23 +0,0 @@
-# @generated by generate_proto_mypy_stubs.py. Do not edit!
-import sys
-from google.protobuf.descriptor import (
- Descriptor as google___protobuf___descriptor___Descriptor,
-)
-
-from google.protobuf.message import (
- Message as google___protobuf___message___Message,
-)
-
-
-builtin___bytes = bytes
-
-
-class CustomObservationProto(google___protobuf___message___Message):
- DESCRIPTOR: google___protobuf___descriptor___Descriptor = ...
-
- def __init__(self,
- ) -> None: ...
- @classmethod
- def FromString(cls, s: builtin___bytes) -> CustomObservationProto: ...
- def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
- def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
diff --git a/ml-agents-envs/mlagents/envs/communicator_objects/observation_pb2.py b/ml-agents-envs/mlagents/envs/communicator_objects/observation_pb2.py
new file mode 100644
index 0000000000..9f4e9491f6
--- /dev/null
+++ b/ml-agents-envs/mlagents/envs/communicator_objects/observation_pb2.py
@@ -0,0 +1,169 @@
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# source: mlagents/envs/communicator_objects/observation.proto
+
+import sys
+_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
+from google.protobuf.internal import enum_type_wrapper
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import message as _message
+from google.protobuf import reflection as _reflection
+from google.protobuf import symbol_database as _symbol_database
+from google.protobuf import descriptor_pb2
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+
+
+DESCRIPTOR = _descriptor.FileDescriptor(
+ name='mlagents/envs/communicator_objects/observation.proto',
+ package='communicator_objects',
+ syntax='proto3',
+ serialized_pb=_b('\n4mlagents/envs/communicator_objects/observation.proto\x12\x14\x63ommunicator_objects\"\xf9\x01\n\x10ObservationProto\x12\r\n\x05shape\x18\x01 \x03(\x05\x12\x44\n\x10\x63ompression_type\x18\x02 \x01(\x0e\x32*.communicator_objects.CompressionTypeProto\x12\x19\n\x0f\x63ompressed_data\x18\x03 \x01(\x0cH\x00\x12\x46\n\nfloat_data\x18\x04 \x01(\x0b\x32\x30.communicator_objects.ObservationProto.FloatDataH\x00\x1a\x19\n\tFloatData\x12\x0c\n\x04\x64\x61ta\x18\x01 \x03(\x02\x42\x12\n\x10observation_data*)\n\x14\x43ompressionTypeProto\x12\x08\n\x04NONE\x10\x00\x12\x07\n\x03PNG\x10\x01\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3')
+)
+
+_COMPRESSIONTYPEPROTO = _descriptor.EnumDescriptor(
+ name='CompressionTypeProto',
+ full_name='communicator_objects.CompressionTypeProto',
+ filename=None,
+ file=DESCRIPTOR,
+ values=[
+ _descriptor.EnumValueDescriptor(
+ name='NONE', index=0, number=0,
+ options=None,
+ type=None),
+ _descriptor.EnumValueDescriptor(
+ name='PNG', index=1, number=1,
+ options=None,
+ type=None),
+ ],
+ containing_type=None,
+ options=None,
+ serialized_start=330,
+ serialized_end=371,
+)
+_sym_db.RegisterEnumDescriptor(_COMPRESSIONTYPEPROTO)
+
+CompressionTypeProto = enum_type_wrapper.EnumTypeWrapper(_COMPRESSIONTYPEPROTO)
+NONE = 0
+PNG = 1
+
+
+
+_OBSERVATIONPROTO_FLOATDATA = _descriptor.Descriptor(
+ name='FloatData',
+ full_name='communicator_objects.ObservationProto.FloatData',
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ _descriptor.FieldDescriptor(
+ name='data', full_name='communicator_objects.ObservationProto.FloatData.data', index=0,
+ number=1, type=2, cpp_type=6, label=3,
+ has_default_value=False, default_value=[],
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ options=None, file=DESCRIPTOR),
+ ],
+ extensions=[
+ ],
+ nested_types=[],
+ enum_types=[
+ ],
+ options=None,
+ is_extendable=False,
+ syntax='proto3',
+ extension_ranges=[],
+ oneofs=[
+ ],
+ serialized_start=283,
+ serialized_end=308,
+)
+
+_OBSERVATIONPROTO = _descriptor.Descriptor(
+ name='ObservationProto',
+ full_name='communicator_objects.ObservationProto',
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ _descriptor.FieldDescriptor(
+ name='shape', full_name='communicator_objects.ObservationProto.shape', index=0,
+ number=1, type=5, cpp_type=1, label=3,
+ has_default_value=False, default_value=[],
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='compression_type', full_name='communicator_objects.ObservationProto.compression_type', index=1,
+ number=2, type=14, cpp_type=8, label=1,
+ has_default_value=False, default_value=0,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='compressed_data', full_name='communicator_objects.ObservationProto.compressed_data', index=2,
+ number=3, type=12, cpp_type=9, label=1,
+ has_default_value=False, default_value=_b(""),
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='float_data', full_name='communicator_objects.ObservationProto.float_data', index=3,
+ number=4, type=11, cpp_type=10, label=1,
+ has_default_value=False, default_value=None,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ options=None, file=DESCRIPTOR),
+ ],
+ extensions=[
+ ],
+ nested_types=[_OBSERVATIONPROTO_FLOATDATA, ],
+ enum_types=[
+ ],
+ options=None,
+ is_extendable=False,
+ syntax='proto3',
+ extension_ranges=[],
+ oneofs=[
+ _descriptor.OneofDescriptor(
+ name='observation_data', full_name='communicator_objects.ObservationProto.observation_data',
+ index=0, containing_type=None, fields=[]),
+ ],
+ serialized_start=79,
+ serialized_end=328,
+)
+
+_OBSERVATIONPROTO_FLOATDATA.containing_type = _OBSERVATIONPROTO
+_OBSERVATIONPROTO.fields_by_name['compression_type'].enum_type = _COMPRESSIONTYPEPROTO
+_OBSERVATIONPROTO.fields_by_name['float_data'].message_type = _OBSERVATIONPROTO_FLOATDATA
+_OBSERVATIONPROTO.oneofs_by_name['observation_data'].fields.append(
+ _OBSERVATIONPROTO.fields_by_name['compressed_data'])
+_OBSERVATIONPROTO.fields_by_name['compressed_data'].containing_oneof = _OBSERVATIONPROTO.oneofs_by_name['observation_data']
+_OBSERVATIONPROTO.oneofs_by_name['observation_data'].fields.append(
+ _OBSERVATIONPROTO.fields_by_name['float_data'])
+_OBSERVATIONPROTO.fields_by_name['float_data'].containing_oneof = _OBSERVATIONPROTO.oneofs_by_name['observation_data']
+DESCRIPTOR.message_types_by_name['ObservationProto'] = _OBSERVATIONPROTO
+DESCRIPTOR.enum_types_by_name['CompressionTypeProto'] = _COMPRESSIONTYPEPROTO
+_sym_db.RegisterFileDescriptor(DESCRIPTOR)
+
+ObservationProto = _reflection.GeneratedProtocolMessageType('ObservationProto', (_message.Message,), dict(
+
+ FloatData = _reflection.GeneratedProtocolMessageType('FloatData', (_message.Message,), dict(
+ DESCRIPTOR = _OBSERVATIONPROTO_FLOATDATA,
+ __module__ = 'mlagents.envs.communicator_objects.observation_pb2'
+ # @@protoc_insertion_point(class_scope:communicator_objects.ObservationProto.FloatData)
+ ))
+ ,
+ DESCRIPTOR = _OBSERVATIONPROTO,
+ __module__ = 'mlagents.envs.communicator_objects.observation_pb2'
+ # @@protoc_insertion_point(class_scope:communicator_objects.ObservationProto)
+ ))
+_sym_db.RegisterMessage(ObservationProto)
+_sym_db.RegisterMessage(ObservationProto.FloatData)
+
+
+DESCRIPTOR.has_options = True
+DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects'))
+# @@protoc_insertion_point(module_scope)
diff --git a/ml-agents-envs/mlagents/envs/communicator_objects/compressed_observation_pb2.pyi b/ml-agents-envs/mlagents/envs/communicator_objects/observation_pb2.pyi
similarity index 53%
rename from ml-agents-envs/mlagents/envs/communicator_objects/compressed_observation_pb2.pyi
rename to ml-agents-envs/mlagents/envs/communicator_objects/observation_pb2.pyi
index 41a3e719e4..79681430fb 100644
--- a/ml-agents-envs/mlagents/envs/communicator_objects/compressed_observation_pb2.pyi
+++ b/ml-agents-envs/mlagents/envs/communicator_objects/observation_pb2.pyi
@@ -50,23 +50,47 @@ class CompressionTypeProto(builtin___int):
NONE = typing___cast('CompressionTypeProto', 0)
PNG = typing___cast('CompressionTypeProto', 1)
-class CompressedObservationProto(google___protobuf___message___Message):
+class ObservationProto(google___protobuf___message___Message):
DESCRIPTOR: google___protobuf___descriptor___Descriptor = ...
+ class FloatData(google___protobuf___message___Message):
+ DESCRIPTOR: google___protobuf___descriptor___Descriptor = ...
+ data = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[builtin___float]
+
+ def __init__(self,
+ *,
+ data : typing___Optional[typing___Iterable[builtin___float]] = None,
+ ) -> None: ...
+ @classmethod
+ def FromString(cls, s: builtin___bytes) -> ObservationProto.FloatData: ...
+ def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
+ def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
+ if sys.version_info >= (3,):
+ def ClearField(self, field_name: typing_extensions___Literal[u"data"]) -> None: ...
+ else:
+ def ClearField(self, field_name: typing_extensions___Literal[u"data",b"data"]) -> None: ...
+
shape = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[builtin___int]
compression_type = ... # type: CompressionTypeProto
- data = ... # type: builtin___bytes
+ compressed_data = ... # type: builtin___bytes
+
+ @property
+ def float_data(self) -> ObservationProto.FloatData: ...
def __init__(self,
*,
shape : typing___Optional[typing___Iterable[builtin___int]] = None,
compression_type : typing___Optional[CompressionTypeProto] = None,
- data : typing___Optional[builtin___bytes] = None,
+ compressed_data : typing___Optional[builtin___bytes] = None,
+ float_data : typing___Optional[ObservationProto.FloatData] = None,
) -> None: ...
@classmethod
- def FromString(cls, s: builtin___bytes) -> CompressedObservationProto: ...
+ def FromString(cls, s: builtin___bytes) -> ObservationProto: ...
def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
if sys.version_info >= (3,):
- def ClearField(self, field_name: typing_extensions___Literal[u"compression_type",u"data",u"shape"]) -> None: ...
+ def HasField(self, field_name: typing_extensions___Literal[u"compressed_data",u"float_data",u"observation_data"]) -> builtin___bool: ...
+ def ClearField(self, field_name: typing_extensions___Literal[u"compressed_data",u"compression_type",u"float_data",u"observation_data",u"shape"]) -> None: ...
else:
- def ClearField(self, field_name: typing_extensions___Literal[u"compression_type",b"compression_type",u"data",b"data",u"shape",b"shape"]) -> None: ...
+ def HasField(self, field_name: typing_extensions___Literal[u"compressed_data",b"compressed_data",u"float_data",b"float_data",u"observation_data",b"observation_data"]) -> builtin___bool: ...
+ def ClearField(self, field_name: typing_extensions___Literal[u"compressed_data",b"compressed_data",u"compression_type",b"compression_type",u"float_data",b"float_data",u"observation_data",b"observation_data",u"shape",b"shape"]) -> None: ...
+ def WhichOneof(self, oneof_group: typing_extensions___Literal[u"observation_data",b"observation_data"]) -> typing_extensions___Literal["compressed_data","float_data"]: ...
diff --git a/ml-agents-envs/mlagents/envs/env_manager.py b/ml-agents-envs/mlagents/envs/env_manager.py
index 8f3499104e..fa48a57a3e 100644
--- a/ml-agents-envs/mlagents/envs/env_manager.py
+++ b/ml-agents-envs/mlagents/envs/env_manager.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
-from typing import List, Dict, NamedTuple, Optional
+from typing import Any, List, Dict, NamedTuple, Optional
from mlagents.envs.brain import AllBrainInfo, BrainParameters
from mlagents.envs.policy import Policy
from mlagents.envs.action_info import ActionInfo
@@ -10,6 +10,13 @@ class EnvironmentStep(NamedTuple):
current_all_brain_info: AllBrainInfo
brain_name_to_action_info: Optional[Dict[str, ActionInfo]]
+ def has_actions_for_brain(self, brain_name: str) -> bool:
+ return (
+ self.brain_name_to_action_info is not None
+ and brain_name in self.brain_name_to_action_info
+ and self.brain_name_to_action_info[brain_name].outputs is not None
+ )
+
class EnvManager(ABC):
def __init__(self):
@@ -24,7 +31,10 @@ def step(self) -> List[EnvironmentStep]:
@abstractmethod
def reset(
- self, config: Dict = None, train_mode: bool = True
+ self,
+ config: Dict = None,
+ train_mode: bool = True,
+ custom_reset_parameters: Any = None,
) -> List[EnvironmentStep]:
pass
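The new `EnvironmentStep.has_actions_for_brain` helper centralizes the None and membership checks that callers previously had to spell out. A rough usage sketch for a training loop; the `env_manager` instance, the assumption that `current_all_brain_info` iterates brain names, and the trainer hand-off are illustrative, not taken from this diff:

```python
from mlagents.envs.env_manager import EnvManager


def consume_steps(env_manager: EnvManager) -> None:
    # Advance the environments one step and only read action outputs
    # for brains that actually produced them.
    for step_info in env_manager.step():
        for brain_name in step_info.current_all_brain_info:
            if step_info.has_actions_for_brain(brain_name):
                outputs = step_info.brain_name_to_action_info[brain_name].outputs
                # ... hand `outputs` to the trainer for this brain ...
```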
diff --git a/ml-agents-envs/mlagents/envs/environment.py b/ml-agents-envs/mlagents/envs/environment.py
index 52fd8132cd..57d30956bf 100644
--- a/ml-agents-envs/mlagents/envs/environment.py
+++ b/ml-agents-envs/mlagents/envs/environment.py
@@ -28,7 +28,6 @@
)
from mlagents.envs.communicator_objects.unity_input_pb2 import UnityInputProto
-from mlagents.envs.communicator_objects.custom_action_pb2 import CustomActionProto
from .rpc_communicator import RpcCommunicator
from sys import platform
@@ -41,8 +40,7 @@
class UnityEnvironment(BaseUnityEnvironment):
SCALAR_ACTION_TYPES = (int, np.int32, np.int64, float, np.float32, np.float64)
SINGLE_BRAIN_ACTION_TYPES = SCALAR_ACTION_TYPES + (list, np.ndarray)
- SINGLE_BRAIN_TEXT_TYPES = list
- API_VERSION = "API-11"
+ API_VERSION = "API-12"
def __init__(
self,
@@ -52,7 +50,7 @@ def __init__(
seed: int = 0,
docker_training: bool = False,
no_graphics: bool = False,
- timeout_wait: int = 30,
+ timeout_wait: int = 60,
args: Optional[List[str]] = None,
):
"""
@@ -74,12 +72,10 @@ def __init__(
self.port = base_port + worker_id
self._buffer_size = 12000
self._version_ = UnityEnvironment.API_VERSION
- self._loaded = (
- False
- ) # If true, this means the environment was successfully loaded
- self.proc1 = (
- None
- ) # The process that is started. If None, no process was started
+ # If true, this means the environment was successfully loaded
+ self._loaded = False
+ # The process that is started. If None, no process was started
+ self.proc1 = None
self.timeout_wait: int = timeout_wait
self.communicator = self.get_communicator(worker_id, base_port, timeout_wait)
self.worker_id = worker_id
@@ -249,23 +245,22 @@ def executable_launcher(self, file_name, docker_training, no_graphics, args):
) from perm
else:
- """
- Comments for future maintenance:
- xvfb-run is a wrapper around Xvfb, a virtual xserver where all
- rendering is done to virtual memory. It automatically creates a
- new virtual server automatically picking a server number `auto-servernum`.
- The server is passed the arguments using `server-args`, we are telling
- Xvfb to create Screen number 0 with width 640, height 480 and depth 24 bits.
- Note that 640 X 480 are the default width and height. The main reason for
- us to add this is because we'd like to change the depth from the default
- of 8 bits to 24.
- Unfortunately, this means that we will need to pass the arguments through
- a shell which is why we set `shell=True`. Now, this adds its own
- complications. E.g SIGINT can bounce off the shell and not get propagated
- to the child processes. This is why we add `exec`, so that the shell gets
- launched, the arguments are passed to `xvfb-run`. `exec` replaces the shell
- we created with `xvfb`.
- """
+ # Comments for future maintenance:
+ # xvfb-run is a wrapper around Xvfb, a virtual xserver where all
+ # rendering is done to virtual memory. It automatically creates a
+            # new virtual server, picking a server number via `auto-servernum`.
+ # The server is passed the arguments using `server-args`, we are telling
+ # Xvfb to create Screen number 0 with width 640, height 480 and depth 24 bits.
+ # Note that 640 X 480 are the default width and height. The main reason for
+ # us to add this is because we'd like to change the depth from the default
+ # of 8 bits to 24.
+ # Unfortunately, this means that we will need to pass the arguments through
+ # a shell which is why we set `shell=True`. Now, this adds its own
+            # complications. E.g. SIGINT can bounce off the shell and not get propagated
+ # to the child processes. This is why we add `exec`, so that the shell gets
+ # launched, the arguments are passed to `xvfb-run`. `exec` replaces the shell
+ # we created with `xvfb`.
+ #
docker_ls = (
"exec xvfb-run --auto-servernum"
" --server-args='-screen 0 640x480x24'"
@@ -279,22 +274,18 @@ def executable_launcher(self, file_name, docker_training, no_graphics, args):
)
def __str__(self):
- return (
- """Unity Academy name: {0}
- Number of Training Brains : {1}
- Reset Parameters :\n\t\t{2}""".format(
- self._academy_name,
- str(self._num_external_brains),
- "\n\t\t".join(
- [
- str(k) + " -> " + str(self._resetParameters[k])
- for k in self._resetParameters
- ]
- ),
+ reset_params_str = (
+ "\n\t\t".join(
+ [
+ str(k) + " -> " + str(self._resetParameters[k])
+ for k in self._resetParameters
+ ]
)
- + "\n"
- + "\n".join([str(self._brains[b]) for b in self._brains])
+ if self._resetParameters
+ else "{}"
)
+ return f"""Unity Academy name: {self._academy_name}
+ Reset Parameters : {reset_params_str}"""
def reset(
self,
@@ -348,10 +339,7 @@ def reset(
def step(
self,
vector_action: Dict[str, np.ndarray] = None,
- memory: Optional[Dict[str, np.ndarray]] = None,
- text_action: Optional[Dict[str, List[str]]] = None,
value: Optional[Dict[str, np.ndarray]] = None,
- custom_action: Dict[str, Any] = None,
) -> AllBrainInfo:
"""
Provides the environment with an action, moves the environment dynamics forward accordingly,
@@ -359,17 +347,11 @@ def step(
:param value: Value estimates provided by agents.
:param vector_action: Agent's vector action. Can be a scalar or vector of int/floats.
-        :param memory: Vector corresponding to memory used for recurrent policies.
- :param text_action: Text action to send to environment for.
- :param custom_action: Optional instance of a CustomAction protobuf message.
:return: AllBrainInfo : A Data structure corresponding to the new state of the environment.
"""
if self._is_first_message:
return self.reset()
vector_action = {} if vector_action is None else vector_action
- memory = {} if memory is None else memory
- text_action = {} if text_action is None else text_action
value = {} if value is None else value
- custom_action = {} if custom_action is None else custom_action
# Check that environment is loaded, and episode is currently running.
if not self._loaded:
@@ -389,34 +372,6 @@ def step(
"step cannot take a vector_action input"
)
- if isinstance(memory, self.SINGLE_BRAIN_ACTION_TYPES):
- if self._num_external_brains == 1:
- memory = {self._external_brain_names[0]: memory}
- elif self._num_external_brains > 1:
- raise UnityActionException(
- "You have {0} brains, you need to feed a dictionary of brain names as keys "
- "and memories as values".format(self._num_external_brains)
- )
- else:
- raise UnityActionException(
- "There are no external brains in the environment, "
- "step cannot take a memory input"
- )
-
- if isinstance(text_action, self.SINGLE_BRAIN_TEXT_TYPES):
- if self._num_external_brains == 1:
- text_action = {self._external_brain_names[0]: text_action}
- elif self._num_external_brains > 1:
- raise UnityActionException(
- "You have {0} brains, you need to feed a dictionary of brain names as keys "
- "and text_actions as values".format(self._num_external_brains)
- )
- else:
- raise UnityActionException(
- "There are no external brains in the environment, "
- "step cannot take a value input"
- )
-
if isinstance(value, self.SINGLE_BRAIN_ACTION_TYPES):
if self._num_external_brains == 1:
value = {self._external_brain_names[0]: value}
@@ -433,27 +388,7 @@ def step(
"step cannot take a value input"
)
- if isinstance(custom_action, CustomActionProto):
- if self._num_external_brains == 1:
- custom_action = {self._external_brain_names[0]: custom_action}
- elif self._num_external_brains > 1:
- raise UnityActionException(
- "You have {0} brains, you need to feed a dictionary of brain names as keys "
- "and CustomAction instances as values".format(
- self._num_external_brains
- )
- )
- else:
- raise UnityActionException(
- "There are no external brains in the environment, "
- "step cannot take a custom_action input"
- )
-
- for brain_name in (
- list(vector_action.keys())
- + list(memory.keys())
- + list(text_action.keys())
- ):
+ for brain_name in list(vector_action.keys()):
if brain_name not in self._external_brain_names:
raise UnityActionException(
"The name {0} does not correspond to an external brain "
@@ -477,37 +412,6 @@ def step(
)
else:
vector_action[brain_name] = self._flatten(vector_action[brain_name])
- if brain_name not in memory:
- memory[brain_name] = []
- else:
- if memory[brain_name] is None:
- memory[brain_name] = []
- else:
- memory[brain_name] = self._flatten(memory[brain_name])
- if brain_name not in text_action:
- text_action[brain_name] = [""] * n_agent
- else:
- if text_action[brain_name] is None:
- text_action[brain_name] = [""] * n_agent
- if brain_name not in custom_action:
- custom_action[brain_name] = [None] * n_agent
- else:
- if custom_action[brain_name] is None:
- custom_action[brain_name] = [None] * n_agent
- if isinstance(custom_action[brain_name], CustomActionProto):
- custom_action[brain_name] = [
- custom_action[brain_name]
- ] * n_agent
-
- number_text_actions = len(text_action[brain_name])
- if not ((number_text_actions == n_agent) or number_text_actions == 0):
- raise UnityActionException(
- "There was a mismatch between the provided text_action and "
- "the environment's expectation: "
- "The brain {0} expected {1} text_action but was given {2}".format(
- brain_name, n_agent, number_text_actions
- )
- )
discrete_check = (
self._brains[brain_name].vector_action_space_type == "discrete"
@@ -548,9 +452,7 @@ def step(
)
)
- step_input = self._generate_step_input(
- vector_action, memory, text_action, value, custom_action
- )
+ step_input = self._generate_step_input(vector_action, value)
with hierarchical_timer("communicator.exchange"):
outputs = self.communicator.exchange(step_input)
if outputs is None:
@@ -602,8 +504,10 @@ def _flatten(cls, arr: Any) -> List[float]:
if len(arr) == 0:
return arr
if isinstance(arr[0], np.ndarray):
+ # pylint: disable=no-member
arr = [item for sublist in arr for item in sublist.tolist()]
if isinstance(arr[0], list):
+ # pylint: disable=not-an-iterable
arr = [item for sublist in arr for item in sublist]
arr = [float(x) for x in arr]
return arr
@@ -630,20 +534,15 @@ def _update_brain_parameters(self, output: UnityOutputProto) -> None:
agent_infos = output.rl_output.agentInfos[brain_param.brain_name]
if agent_infos.value:
agent = agent_infos.value[0]
- self._brains[brain_param.brain_name] = BrainParameters.from_proto(
- brain_param, agent
- )
+ new_brain = BrainParameters.from_proto(brain_param, agent)
+ self._brains[brain_param.brain_name] = new_brain
+ logger.info(f"Connected new brain:\n{new_brain}")
self._external_brain_names = list(self._brains.keys())
self._num_external_brains = len(self._external_brain_names)
@timed
def _generate_step_input(
- self,
- vector_action: Dict[str, np.ndarray],
- memory: Dict[str, np.ndarray],
- text_action: Dict[str, list],
- value: Dict[str, np.ndarray],
- custom_action: Dict[str, list],
+ self, vector_action: Dict[str, np.ndarray], value: Dict[str, np.ndarray]
) -> UnityInputProto:
rl_in = UnityRLInputProto()
for b in vector_action:
@@ -651,13 +550,9 @@ def _generate_step_input(
if n_agents == 0:
continue
_a_s = len(vector_action[b]) // n_agents
- _m_s = len(memory[b]) // n_agents
for i in range(n_agents):
action = AgentActionProto(
- vector_actions=vector_action[b][i * _a_s : (i + 1) * _a_s],
- memories=memory[b][i * _m_s : (i + 1) * _m_s],
- text_actions=text_action[b][i],
- custom_action=custom_action[b][i],
+ vector_actions=vector_action[b][i * _a_s : (i + 1) * _a_s]
)
if b in value:
if value[b] is not None:
@@ -702,7 +597,7 @@ def returncode_to_signal_name(returncode: int) -> Optional[str]:
"""
try:
# A negative value -N indicates that the child was terminated by signal N (POSIX only).
- s = signal.Signals(-returncode)
+ s = signal.Signals(-returncode) # pylint: disable=no-member
return s.name
except Exception:
# Should generally be a ValueError, but catch everything just in case.
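With memories, text actions, and custom actions removed, UnityEnvironment.step() now takes only per-brain vector actions and optional value estimates. A minimal sketch of the new call shape (the brain name and action values are illustrative, and file_name=None assumes a Unity Editor instance is waiting to connect):

    from mlagents.envs.environment import UnityEnvironment

    env = UnityEnvironment(file_name=None)  # assumption: connect to a running Editor instance
    env.reset()
    all_brain_info = env.step(
        vector_action={"RealFakeBrain": [0.1, -0.2]},  # one entry per external brain
        value={"RealFakeBrain": [0.0]},                # optional per-brain value estimates
    )
    env.close()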
diff --git a/ml-agents-envs/mlagents/envs/mock_communicator.py b/ml-agents-envs/mlagents/envs/mock_communicator.py
index 398af545db..1f614952df 100755
--- a/ml-agents-envs/mlagents/envs/mock_communicator.py
+++ b/ml-agents-envs/mlagents/envs/mock_communicator.py
@@ -8,9 +8,10 @@
from mlagents.envs.communicator_objects.unity_input_pb2 import UnityInputProto
from mlagents.envs.communicator_objects.unity_output_pb2 import UnityOutputProto
from mlagents.envs.communicator_objects.agent_info_pb2 import AgentInfoProto
-from mlagents.envs.communicator_objects.compressed_observation_pb2 import (
- CompressedObservationProto,
- CompressionTypeProto,
+from mlagents.envs.communicator_objects.observation_pb2 import (
+ ObservationProto,
+ NONE as COMPRESSION_TYPE_NONE,
+ PNG as COMPRESSION_TYPE_PNG,
)
@@ -19,7 +20,6 @@ def __init__(
self,
discrete_action=False,
visual_inputs=0,
- stack=True,
num_agents=3,
brain_name="RealFakeBrain",
vec_obs_size=3,
@@ -30,6 +30,7 @@ def __init__(
:int base_port: Baseline port number to connect to Unity environment over. worker_id increments over this.
:int worker_id: Number to add to communication port (5005) [0]. Used for asynchronous agent scenarios.
"""
+ super().__init__()
self.is_discrete = discrete_action
self.steps = 0
self.visual_inputs = visual_inputs
@@ -37,15 +38,9 @@ def __init__(
self.num_agents = num_agents
self.brain_name = brain_name
self.vec_obs_size = vec_obs_size
- if stack:
- self.num_stacks = 2
- else:
- self.num_stacks = 1
def initialize(self, inputs: UnityInputProto) -> UnityOutputProto:
bp = BrainParametersProto(
- vector_observation_size=self.vec_obs_size,
- num_stacked_vector_observations=self.num_stacks,
vector_action_size=[2],
vector_action_descriptions=["", ""],
vector_action_space_type=int(not self.is_discrete),
@@ -63,36 +58,32 @@ def initialize(self, inputs: UnityInputProto) -> UnityOutputProto:
def _get_agent_infos(self):
dict_agent_info = {}
- if self.is_discrete:
- vector_action = [1]
- else:
- vector_action = [1, 2]
list_agent_info = []
- if self.num_stacks == 1:
- observation = [1, 2, 3]
- else:
- observation = [1, 2, 3, 1, 2, 3]
+ vector_obs = [1, 2, 3]
- compressed_obs = [
- CompressedObservationProto(
- data=None, shape=[30, 40, 3], compression_type=CompressionTypeProto.PNG
+ observations = [
+ ObservationProto(
+ compressed_data=None,
+ shape=[30, 40, 3],
+ compression_type=COMPRESSION_TYPE_PNG,
)
for _ in range(self.visual_inputs)
]
+ vector_obs_proto = ObservationProto(
+ float_data=ObservationProto.FloatData(data=vector_obs),
+ shape=[len(vector_obs)],
+ compression_type=COMPRESSION_TYPE_NONE,
+ )
+ observations.append(vector_obs_proto)
for i in range(self.num_agents):
list_agent_info.append(
AgentInfoProto(
- stacked_vector_observation=observation,
reward=1,
- stored_vector_actions=vector_action,
- stored_text_actions="",
- text_observation="",
- memories=[],
done=(i == 2),
max_step_reached=False,
id=i,
- compressed_observations=compressed_obs,
+ observations=observations,
)
)
dict_agent_info["RealFakeBrain"] = UnityRLOutputProto.ListAgentInfoProto(
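The mock communicator now emits the unified ObservationProto: PNG-compressed visual observations plus one uncompressed float vector, mirroring what a real environment sends. A minimal construction sketch (field values are illustrative):

    from mlagents.envs.communicator_objects.agent_info_pb2 import AgentInfoProto
    from mlagents.envs.communicator_objects.observation_pb2 import (
        ObservationProto,
        NONE as COMPRESSION_TYPE_NONE,
        PNG as COMPRESSION_TYPE_PNG,
    )

    visual_obs = ObservationProto(
        compressed_data=None, shape=[30, 40, 3], compression_type=COMPRESSION_TYPE_PNG
    )
    vector_obs = ObservationProto(
        float_data=ObservationProto.FloatData(data=[1.0, 2.0, 3.0]),
        shape=[3],
        compression_type=COMPRESSION_TYPE_NONE,
    )
    agent_info = AgentInfoProto(
        reward=1.0, done=False, id=0, observations=[visual_obs, vector_obs]
    )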
diff --git a/ml-agents-envs/mlagents/envs/rpc_communicator.py b/ml-agents-envs/mlagents/envs/rpc_communicator.py
index 9795d98e34..1aa8f6c0f1 100644
--- a/ml-agents-envs/mlagents/envs/rpc_communicator.py
+++ b/ml-agents-envs/mlagents/envs/rpc_communicator.py
@@ -41,6 +41,7 @@ def __init__(self, worker_id=0, base_port=5005, timeout_wait=30):
:int base_port: Baseline port number to connect to Unity environment over. worker_id increments over this.
:int worker_id: Number to add to communication port (5005) [0]. Used for asynchronous agent scenarios.
"""
+ super().__init__(worker_id, base_port)
self.port = base_port + worker_id
self.worker_id = worker_id
self.timeout_wait = timeout_wait
@@ -82,7 +83,12 @@ def check_port(self, port):
finally:
s.close()
- def initialize(self, inputs: UnityInputProto) -> UnityOutputProto:
+ def poll_for_timeout(self):
+ """
+ Polls the gRPC parent connection for data, to be used before calling recv. This prevents
+ us from hanging indefinitely in the case where the environment process has died or was not
+ launched.
+ """
if not self.unity_to_external.parent_conn.poll(self.timeout_wait):
raise UnityTimeOutException(
"The Unity environment took too long to respond. Make sure that :\n"
@@ -90,6 +96,9 @@ def initialize(self, inputs: UnityInputProto) -> UnityOutputProto:
"\t The Agents are linked to the appropriate Brains\n"
"\t The environment and the Python interface have compatible versions."
)
+
+ def initialize(self, inputs: UnityInputProto) -> UnityOutputProto:
+ self.poll_for_timeout()
aca_param = self.unity_to_external.parent_conn.recv().unity_output
message = UnityMessageProto()
message.header.status = 200
@@ -103,6 +112,7 @@ def exchange(self, inputs: UnityInputProto) -> Optional[UnityOutputProto]:
message.header.status = 200
message.unity_input.CopyFrom(inputs)
self.unity_to_external.parent_conn.send(message)
+ self.poll_for_timeout()
output = self.unity_to_external.parent_conn.recv()
if output.header.status != 200:
return None
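poll_for_timeout() is now called before every blocking recv, so a crashed or never-launched Unity process surfaces as a UnityTimeOutException instead of an indefinite hang. A hedged sketch of how a caller can react (the close() cleanup call is an assumption about the communicator interface, not part of this patch):

    from mlagents.envs.exception import UnityTimeOutException

    try:
        outputs = communicator.exchange(step_input)
    except UnityTimeOutException:
        communicator.close()  # assumption: release the port before propagating the error
        raise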
diff --git a/ml-agents-envs/mlagents/envs/simple_env_manager.py b/ml-agents-envs/mlagents/envs/simple_env_manager.py
index e2d5feeb3b..1a2a8a4a40 100644
--- a/ml-agents-envs/mlagents/envs/simple_env_manager.py
+++ b/ml-agents-envs/mlagents/envs/simple_env_manager.py
@@ -25,15 +25,11 @@ def step(self) -> List[EnvironmentStep]:
self.previous_all_action_info = all_action_info
actions = {}
- memories = {}
- texts = {}
values = {}
for brain_name, action_info in all_action_info.items():
actions[brain_name] = action_info.action
- memories[brain_name] = action_info.memory
- texts[brain_name] = action_info.text
values[brain_name] = action_info.value
- all_brain_info = self.env.step(actions, memories, texts, values)
+ all_brain_info = self.env.step(vector_action=actions, value=values)
step_brain_info = all_brain_info
step_info = EnvironmentStep(
diff --git a/ml-agents-envs/mlagents/envs/subprocess_env_manager.py b/ml-agents-envs/mlagents/envs/subprocess_env_manager.py
index a91a49650c..c83465f152 100644
--- a/ml-agents-envs/mlagents/envs/subprocess_env_manager.py
+++ b/ml-agents-envs/mlagents/envs/subprocess_env_manager.py
@@ -3,7 +3,7 @@
import cloudpickle
from mlagents.envs.environment import UnityEnvironment
-from mlagents.envs.exception import UnityCommunicationException
+from mlagents.envs.exception import UnityCommunicationException, UnityTimeOutException
from multiprocessing import Process, Pipe, Queue
from multiprocessing.connection import Connection
from queue import Empty as EmptyQueueException
@@ -79,7 +79,7 @@ def worker(
env_factory: Callable[[int], UnityEnvironment] = cloudpickle.loads(
pickled_env_factory
)
- env = env_factory(worker_id)
+ env: BaseUnityEnvironment = env_factory(worker_id)
def _send_response(cmd_name, payload):
parent_conn.send(EnvironmentResponse(cmd_name, worker_id, payload))
@@ -90,15 +90,11 @@ def _send_response(cmd_name, payload):
if cmd.name == "step":
all_action_info = cmd.payload
actions = {}
- memories = {}
- texts = {}
values = {}
for brain_name, action_info in all_action_info.items():
actions[brain_name] = action_info.action
- memories[brain_name] = action_info.memory
- texts[brain_name] = action_info.text
values[brain_name] = action_info.value
- all_brain_info = env.step(actions, memories, texts, values)
+ all_brain_info = env.step(vector_action=actions, value=values)
# The timers in this process are independent from all the processes and the "main" process
# So after we send back the root timer, we can safely clear them.
# Note that we could randomly return timers a fraction of the time if we wanted to reduce
@@ -118,7 +114,7 @@ def _send_response(cmd_name, payload):
_send_response("reset", all_brain_info)
elif cmd.name == "close":
break
- except (KeyboardInterrupt, UnityCommunicationException):
+ except (KeyboardInterrupt, UnityCommunicationException, UnityTimeOutException):
logger.info(f"UnityEnvironment worker {worker_id}: environment stopping.")
step_queue.put(EnvironmentResponse("env_close", worker_id, None))
finally:
diff --git a/ml-agents-envs/mlagents/envs/tests/test_brain.py b/ml-agents-envs/mlagents/envs/tests/test_brain.py
index 73b7b23954..b5d4a0c254 100644
--- a/ml-agents-envs/mlagents/envs/tests/test_brain.py
+++ b/ml-agents-envs/mlagents/envs/tests/test_brain.py
@@ -1,15 +1,19 @@
+from typing import List
import logging
import numpy as np
import sys
from unittest import mock
from mlagents.envs.communicator_objects.agent_info_pb2 import AgentInfoProto
+from mlagents.envs.communicator_objects.observation_pb2 import (
+ ObservationProto,
+ NONE as COMPRESSION_TYPE_NONE,
+)
from mlagents.envs.brain import BrainInfo, BrainParameters
test_brain = BrainParameters(
brain_name="test_brain",
vector_observation_space_size=3,
- num_stacked_vector_observations=1,
camera_resolutions=[],
vector_action_space_size=[],
vector_action_descriptions=[],
@@ -17,11 +21,20 @@
)
+def _make_agent_info_proto(vector_obs: List[float]) -> AgentInfoProto:
+ obs = ObservationProto(
+ float_data=ObservationProto.FloatData(data=vector_obs),
+ shape=[len(vector_obs)],
+ compression_type=COMPRESSION_TYPE_NONE,
+ )
+ agent_info_proto = AgentInfoProto(observations=[obs])
+ return agent_info_proto
+
+
@mock.patch.object(np, "nan_to_num", wraps=np.nan_to_num)
@mock.patch.object(logging.Logger, "warning")
def test_from_agent_proto_nan(mock_warning, mock_nan_to_num):
- agent_info_proto = AgentInfoProto()
- agent_info_proto.stacked_vector_observation.extend([1.0, 2.0, float("nan")])
+ agent_info_proto = _make_agent_info_proto([1.0, 2.0, float("nan")])
brain_info = BrainInfo.from_agent_proto(1, [agent_info_proto], test_brain)
# nan gets set to 0.0
@@ -34,8 +47,7 @@ def test_from_agent_proto_nan(mock_warning, mock_nan_to_num):
@mock.patch.object(np, "nan_to_num", wraps=np.nan_to_num)
@mock.patch.object(logging.Logger, "warning")
def test_from_agent_proto_inf(mock_warning, mock_nan_to_num):
- agent_info_proto = AgentInfoProto()
- agent_info_proto.stacked_vector_observation.extend([1.0, float("inf"), 0.0])
+ agent_info_proto = _make_agent_info_proto([1.0, float("inf"), 0.0])
brain_info = BrainInfo.from_agent_proto(1, [agent_info_proto], test_brain)
# inf should get set to float_max
@@ -52,8 +64,7 @@ def test_from_agent_proto_fast_path(mock_warning, mock_nan_to_num):
"""
Check that all finite values skips the nan_to_num call
"""
- agent_info_proto = AgentInfoProto()
- agent_info_proto.stacked_vector_observation.extend([1.0, 2.0, 3.0])
+ agent_info_proto = _make_agent_info_proto([1.0, 2.0, 3.0])
brain_info = BrainInfo.from_agent_proto(1, [agent_info_proto], test_brain)
expected = [1.0, 2.0, 3.0]
diff --git a/ml-agents-envs/mlagents/envs/tests/test_envs.py b/ml-agents-envs/mlagents/envs/tests/test_envs.py
index 1827937862..7b6c518727 100755
--- a/ml-agents-envs/mlagents/envs/tests/test_envs.py
+++ b/ml-agents-envs/mlagents/envs/tests/test_envs.py
@@ -49,7 +49,7 @@ def test_reset(mock_communicator, mock_launcher):
)
assert (
len(brain_info["RealFakeBrain"].vector_observations[0])
- == brain.vector_observation_space_size * brain.num_stacked_vector_observations
+ == brain.vector_observation_space_size
)
@@ -88,7 +88,7 @@ def test_step(mock_communicator, mock_launcher):
)
assert (
len(brain_info["RealFakeBrain"].vector_observations[0])
- == brain.vector_observation_space_size * brain.num_stacked_vector_observations
+ == brain.vector_observation_space_size
)
print("\n\n\n\n\n\n\n" + str(brain_info["RealFakeBrain"].local_done))
diff --git a/ml-agents-envs/mlagents/envs/timers.py b/ml-agents-envs/mlagents/envs/timers.py
index 88ccd00e2d..8098910fc9 100644
--- a/ml-agents-envs/mlagents/envs/timers.py
+++ b/ml-agents-envs/mlagents/envs/timers.py
@@ -1,10 +1,3 @@
-# # Unity ML-Agents Toolkit
-import math
-from time import perf_counter
-
-from contextlib import contextmanager
-from typing import Any, Callable, Dict, Generator, List, TypeVar
-
"""
Lightweight, hierarchical timers for profiling sections of code.
@@ -35,6 +28,12 @@ def main():
over the timer name, or are splitting up multiple sections of a large function.
"""
+import math
+from time import perf_counter
+
+from contextlib import contextmanager
+from typing import Any, Callable, Dict, Generator, List, TypeVar
+
class TimerNode:
"""
diff --git a/ml-agents-envs/setup.py b/ml-agents-envs/setup.py
index f2021aee65..236616367d 100644
--- a/ml-agents-envs/setup.py
+++ b/ml-agents-envs/setup.py
@@ -2,8 +2,9 @@
import sys
from setuptools import setup
from setuptools.command.install import install
+import mlagents.envs
-VERSION = "0.11.0"
+VERSION = mlagents.envs.__version__
here = os.path.abspath(os.path.dirname(__file__))
diff --git a/ml-agents/mlagents/tf_utils/__init__.py b/ml-agents/mlagents/tf_utils/__init__.py
new file mode 100644
index 0000000000..239f11b423
--- /dev/null
+++ b/ml-agents/mlagents/tf_utils/__init__.py
@@ -0,0 +1,2 @@
+from mlagents.tf_utils.tf import tf as tf # noqa
+from mlagents.tf_utils.tf import set_warnings_enabled # noqa
diff --git a/ml-agents/mlagents/tf_utils/tf.py b/ml-agents/mlagents/tf_utils/tf.py
new file mode 100644
index 0000000000..6a2917da6d
--- /dev/null
+++ b/ml-agents/mlagents/tf_utils/tf.py
@@ -0,0 +1,30 @@
+# This should be the only place that we import tensorflow directly.
+# Everywhere else is caught by the banned-modules setting for flake8
+import tensorflow as tf # noqa I201
+from distutils.version import LooseVersion
+
+
+# LooseVersion handles things like "1.2.3a" or "4.5.6-rc7" fairly sensibly.
+_is_tensorflow2 = LooseVersion(tf.__version__) >= LooseVersion("2.0.0")
+
+if _is_tensorflow2:
+ import tensorflow.compat.v1 as tf
+
+ tf.disable_v2_behavior()
+ tf_logging = tf.logging
+else:
+ try:
+ # Newer versions of tf 1.x will complain that tf.logging is deprecated
+ tf_logging = tf.compat.v1.logging
+ except AttributeError:
+ # Fall back to the safe import, even if it might generate a warning or two.
+ tf_logging = tf.logging
+
+
+def set_warnings_enabled(is_enabled: bool) -> None:
+ """
+ Enable or disable tensorflow warnings (notably, this disables deprecation warnings).
+ :param is_enabled: If True, set the log level to WARN; otherwise set it to ERROR, suppressing warnings.
+ """
+ level = tf_logging.WARN if is_enabled else tf_logging.ERROR
+ tf_logging.set_verbosity(level)
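All other modules are expected to import TensorFlow through this shim (direct imports are flagged by the flake8 banned-modules setting mentioned in the header comment), so TF1 and TF2 installs behave the same. A minimal usage sketch:

    from mlagents.tf_utils import tf, set_warnings_enabled

    set_warnings_enabled(False)  # silence TF deprecation warnings, as learn.py does outside --debug
    with tf.Session() as sess:   # tf is tf.compat.v1 when TensorFlow 2.x is installed
        pass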
diff --git a/ml-agents/mlagents/trainers/__init__.py b/ml-agents/mlagents/trainers/__init__.py
index e69de29bb2..ea370a8e55 100644
--- a/ml-agents/mlagents/trainers/__init__.py
+++ b/ml-agents/mlagents/trainers/__init__.py
@@ -0,0 +1 @@
+__version__ = "0.12.0"
diff --git a/ml-agents/mlagents/trainers/barracuda.py b/ml-agents/mlagents/trainers/barracuda.py
index dba755186d..c2d48ba952 100644
--- a/ml-agents/mlagents/trainers/barracuda.py
+++ b/ml-agents/mlagents/trainers/barracuda.py
@@ -1,3 +1,5 @@
+# pylint: skip-file
+# flake8: noqa
from __future__ import print_function
from collections import defaultdict
import numpy as np
diff --git a/ml-agents/mlagents/trainers/bc/models.py b/ml-agents/mlagents/trainers/bc/models.py
index b6f1f715f6..7619972bac 100644
--- a/ml-agents/mlagents/trainers/bc/models.py
+++ b/ml-agents/mlagents/trainers/bc/models.py
@@ -1,5 +1,5 @@
-import tensorflow as tf
-import tensorflow.contrib.layers as c_layers
+from mlagents.tf_utils import tf
+
from mlagents.trainers.models import LearningModel
@@ -44,9 +44,7 @@ def __init__(
size,
activation=None,
use_bias=False,
- kernel_initializer=c_layers.variance_scaling_initializer(
- factor=0.01
- ),
+ kernel_initializer=tf.initializers.variance_scaling(0.01),
)
)
self.action_probs = tf.concat(
@@ -93,7 +91,7 @@ def __init__(
activation=None,
use_bias=False,
name="pre_action",
- kernel_initializer=c_layers.variance_scaling_initializer(factor=0.01),
+ kernel_initializer=tf.initializers.variance_scaling(0.01),
)
self.clipped_sample_action = tf.clip_by_value(self.policy, -1, 1)
self.sample_action = tf.identity(self.clipped_sample_action, name="action")
diff --git a/ml-agents/mlagents/trainers/bc/policy.py b/ml-agents/mlagents/trainers/bc/policy.py
index 2c31a19dec..4370e84911 100644
--- a/ml-agents/mlagents/trainers/bc/policy.py
+++ b/ml-agents/mlagents/trainers/bc/policy.py
@@ -59,9 +59,7 @@ def evaluate(self, brain_info):
feed_dict = self.fill_eval_dict(feed_dict, brain_info)
if self.use_recurrent:
- if brain_info.memories.shape[1] == 0:
- brain_info.memories = self.make_empty_memory(len(brain_info.agents))
- feed_dict[self.model.memory_in] = brain_info.memories
+ feed_dict[self.model.memory_in] = self.retrieve_memories(brain_info.agents)
run_out = self._execute_model(feed_dict, self.inference_dict)
return run_out
diff --git a/ml-agents/mlagents/trainers/bc/trainer.py b/ml-agents/mlagents/trainers/bc/trainer.py
index 4bce079748..757cdaef2f 100644
--- a/ml-agents/mlagents/trainers/bc/trainer.py
+++ b/ml-agents/mlagents/trainers/bc/trainer.py
@@ -6,7 +6,7 @@
import numpy as np
-from mlagents.envs.brain import AllBrainInfo
+from mlagents.envs.brain import BrainInfo
from mlagents.envs.action_info import ActionInfoOutputs
from mlagents.trainers.bc.policy import BCPolicy
from mlagents.trainers.buffer import Buffer
@@ -45,50 +45,47 @@ def __init__(self, brain, trainer_parameters, training, load, seed, run_id):
def add_experiences(
self,
- curr_info: AllBrainInfo,
- next_info: AllBrainInfo,
+ curr_info: BrainInfo,
+ next_info: BrainInfo,
take_action_outputs: ActionInfoOutputs,
) -> None:
"""
Adds experiences to each agent's experience history.
- :param curr_info: Current AllBrainInfo (Dictionary of all current brains and corresponding BrainInfo).
- :param next_info: Next AllBrainInfo (Dictionary of all current brains and corresponding BrainInfo).
+ :param curr_info: Current BrainInfo
+ :param next_info: Next BrainInfo
:param take_action_outputs: The outputs of the take action method.
"""
# Used to collect information about student performance.
- info_student = curr_info[self.brain_name]
- next_info_student = next_info[self.brain_name]
- for agent_id in info_student.agents:
- self.evaluation_buffer[agent_id].last_brain_info = info_student
-
- for agent_id in next_info_student.agents:
- stored_info_student = self.evaluation_buffer[agent_id].last_brain_info
- if stored_info_student is None:
+ for agent_id in curr_info.agents:
+ self.evaluation_buffer[agent_id].last_brain_info = curr_info
+
+ for agent_id in next_info.agents:
+ stored_next_info = self.evaluation_buffer[agent_id].last_brain_info
+ if stored_next_info is None:
continue
else:
- next_idx = next_info_student.agents.index(agent_id)
+ next_idx = next_info.agents.index(agent_id)
if agent_id not in self.cumulative_rewards:
self.cumulative_rewards[agent_id] = 0
- self.cumulative_rewards[agent_id] += next_info_student.rewards[next_idx]
- if not next_info_student.local_done[next_idx]:
+ self.cumulative_rewards[agent_id] += next_info.rewards[next_idx]
+ if not next_info.local_done[next_idx]:
if agent_id not in self.episode_steps:
self.episode_steps[agent_id] = 0
self.episode_steps[agent_id] += 1
def process_experiences(
- self, current_info: AllBrainInfo, next_info: AllBrainInfo
+ self, current_info: BrainInfo, next_info: BrainInfo
) -> None:
"""
Checks agent histories for processing condition, and processes them as necessary.
Processing involves calculating value and advantage targets for model updating step.
- :param current_info: Current AllBrainInfo
- :param next_info: Next AllBrainInfo
+ :param current_info: Current BrainInfo
+ :param next_info: Next BrainInfo
"""
- info_student = next_info[self.brain_name]
- for l in range(len(info_student.agents)):
- if info_student.local_done[l]:
- agent_id = info_student.agents[l]
+ for l in range(len(next_info.agents)):
+ if next_info.local_done[l]:
+ agent_id = next_info.agents[l]
self.stats["Environment/Cumulative Reward"].append(
self.cumulative_rewards.get(agent_id, 0)
)
@@ -125,13 +122,14 @@ def update_policy(self):
"""
self.demonstration_buffer.update_buffer.shuffle(self.policy.sequence_length)
batch_losses = []
+ batch_size = self.n_sequences * self.policy.sequence_length
+ # We either divide the entire buffer into num_batches batches, or limit the number
+ # of batches to batches_per_epoch.
num_batches = min(
- len(self.demonstration_buffer.update_buffer["actions"]) // self.n_sequences,
+ len(self.demonstration_buffer.update_buffer["actions"]) // batch_size,
self.batches_per_epoch,
)
- batch_size = self.n_sequences * self.policy.sequence_length
-
for i in range(0, num_batches * batch_size, batch_size):
update_buffer = self.demonstration_buffer.update_buffer
mini_batch = update_buffer.make_mini_batch(i, i + batch_size)
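For illustration of the corrected batch math (numbers chosen arbitrarily): with n_sequences = 32 and sequence_length = 16, batch_size is 512, so a buffer holding 4096 actions yields min(4096 // 512, batches_per_epoch) = min(8, batches_per_epoch) mini-batches; the old division by n_sequences alone would have reported 128, far more batches than the buffer can actually supply once sequences are longer than one step.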
diff --git a/ml-agents/mlagents/trainers/components/bc/model.py b/ml-agents/mlagents/trainers/components/bc/model.py
index 4890e65572..7f57a1ec7e 100644
--- a/ml-agents/mlagents/trainers/components/bc/model.py
+++ b/ml-agents/mlagents/trainers/components/bc/model.py
@@ -1,4 +1,5 @@
-import tensorflow as tf
+from mlagents.tf_utils import tf
+
from mlagents.trainers.models import LearningModel
@@ -73,7 +74,7 @@ def create_loss(self, learning_rate: float, anneal_steps: int) -> None:
power=1.0,
)
else:
- self.annealed_learning_rate = learning_rate
+ self.annealed_learning_rate = tf.Variable(learning_rate)
optimizer = tf.train.AdamOptimizer(learning_rate=self.annealed_learning_rate)
self.update_batch = optimizer.minimize(self.loss)
diff --git a/ml-agents/mlagents/trainers/components/reward_signals/__init__.py b/ml-agents/mlagents/trainers/components/reward_signals/__init__.py
index c91dd382fe..a5141f0161 100644
--- a/ml-agents/mlagents/trainers/components/reward_signals/__init__.py
+++ b/ml-agents/mlagents/trainers/components/reward_signals/__init__.py
@@ -4,7 +4,7 @@
import numpy as np
import abc
-import tensorflow as tf
+from mlagents.tf_utils import tf
from mlagents.envs.brain import BrainInfo
from mlagents.trainers.trainer import UnityTrainerException
@@ -49,11 +49,12 @@ def __init__(
self.stats_name_to_update_name: Dict[str, str] = {}
def evaluate(
- self, current_info: BrainInfo, next_info: BrainInfo
+ self, current_info: BrainInfo, action: np.array, next_info: BrainInfo
) -> RewardSignalResult:
"""
Evaluates the reward for the agents present in current_info given the next_info
:param current_info: The current BrainInfo.
+ :param action: The action that was taken between the two infos.
:param next_info: The BrainInfo from the next timestep.
:return: a RewardSignalResult of (scaled intrinsic reward, unscaled intrinsic reward) provided by the generator
"""
diff --git a/ml-agents/mlagents/trainers/components/reward_signals/curiosity/model.py b/ml-agents/mlagents/trainers/components/reward_signals/curiosity/model.py
index 009e2663df..421ba58e55 100644
--- a/ml-agents/mlagents/trainers/components/reward_signals/curiosity/model.py
+++ b/ml-agents/mlagents/trainers/components/reward_signals/curiosity/model.py
@@ -1,5 +1,6 @@
from typing import List, Tuple
-import tensorflow as tf
+from mlagents.tf_utils import tf
+
from mlagents.trainers.models import LearningModel
diff --git a/ml-agents/mlagents/trainers/components/reward_signals/curiosity/signal.py b/ml-agents/mlagents/trainers/components/reward_signals/curiosity/signal.py
index 52542fc0dd..4f643f9d65 100644
--- a/ml-agents/mlagents/trainers/components/reward_signals/curiosity/signal.py
+++ b/ml-agents/mlagents/trainers/components/reward_signals/curiosity/signal.py
@@ -1,6 +1,7 @@
from typing import Any, Dict, List
import numpy as np
-import tensorflow as tf
+from mlagents.tf_utils import tf
+
from mlagents.envs.brain import BrainInfo
from mlagents.trainers.components.reward_signals import RewardSignal, RewardSignalResult
@@ -45,7 +46,7 @@ def __init__(
self.has_updated = False
def evaluate(
- self, current_info: BrainInfo, next_info: BrainInfo
+ self, current_info: BrainInfo, action: np.array, next_info: BrainInfo
) -> RewardSignalResult:
"""
Evaluates the reward for the agents present in current_info given the next_info
@@ -57,7 +58,7 @@ def evaluate(
return RewardSignalResult([], [])
mini_batch: Dict[str, np.array] = {}
# Construct the batch and use evaluate_batch
- mini_batch["actions"] = next_info.previous_vector_actions
+ mini_batch["actions"] = action
mini_batch["done"] = np.reshape(next_info.local_done, [-1, 1])
for i in range(len(current_info.visual_observations)):
mini_batch["visual_obs%d" % i] = current_info.visual_observations[i]
diff --git a/ml-agents/mlagents/trainers/components/reward_signals/extrinsic/signal.py b/ml-agents/mlagents/trainers/components/reward_signals/extrinsic/signal.py
index 150e27b26f..cbcf3c4d09 100644
--- a/ml-agents/mlagents/trainers/components/reward_signals/extrinsic/signal.py
+++ b/ml-agents/mlagents/trainers/components/reward_signals/extrinsic/signal.py
@@ -3,27 +3,9 @@
from mlagents.envs.brain import BrainInfo
from mlagents.trainers.components.reward_signals import RewardSignal, RewardSignalResult
-from mlagents.trainers.tf_policy import TFPolicy
-from mlagents.trainers.models import LearningModel
class ExtrinsicRewardSignal(RewardSignal):
- def __init__(
- self,
- policy: TFPolicy,
- policy_model: LearningModel,
- strength: float,
- gamma: float,
- ):
- """
- The extrinsic reward generator. Returns the reward received by the environment
- :param policy: The Policy object (e.g. PPOPolicy) that this Reward Signal will apply to.
- :param strength: The strength of the reward. The reward's raw value will be multiplied by this value.
- :param gamma: The time discounting factor used for this reward.
- :return: An ExtrinsicRewardSignal object.
- """
- super().__init__(policy, policy_model, strength, gamma)
-
@classmethod
def check_config(
cls, config_dict: Dict[str, Any], param_keys: List[str] = None
@@ -36,7 +18,7 @@ def check_config(
super().check_config(config_dict, param_keys)
def evaluate(
- self, current_info: BrainInfo, next_info: BrainInfo
+ self, current_info: BrainInfo, action: np.array, next_info: BrainInfo
) -> RewardSignalResult:
"""
Evaluates the reward for the agents present in current_info given the next_info
diff --git a/ml-agents/mlagents/trainers/components/reward_signals/gail/model.py b/ml-agents/mlagents/trainers/components/reward_signals/gail/model.py
index 7d106ddd97..791e806480 100644
--- a/ml-agents/mlagents/trainers/components/reward_signals/gail/model.py
+++ b/ml-agents/mlagents/trainers/components/reward_signals/gail/model.py
@@ -1,6 +1,7 @@
-from typing import Tuple, List
+from typing import List, Optional, Tuple
+
+from mlagents.tf_utils import tf
-import tensorflow as tf
from mlagents.trainers.models import LearningModel
EPSILON = 1e-7
@@ -37,6 +38,10 @@ def __init__(
self.gradient_penalty_weight = gradient_penalty_weight
self.use_vail = use_vail
self.use_actions = use_actions # True # Not using actions
+
+ self.noise: Optional[tf.Tensor] = None
+ self.z: Optional[tf.Tensor] = None
+
self.make_inputs()
self.create_network()
self.create_loss(learning_rate)
diff --git a/ml-agents/mlagents/trainers/components/reward_signals/gail/signal.py b/ml-agents/mlagents/trainers/components/reward_signals/gail/signal.py
index 5539805be5..eb77c7edf2 100644
--- a/ml-agents/mlagents/trainers/components/reward_signals/gail/signal.py
+++ b/ml-agents/mlagents/trainers/components/reward_signals/gail/signal.py
@@ -1,7 +1,7 @@
from typing import Any, Dict, List
import logging
import numpy as np
-import tensorflow as tf
+from mlagents.tf_utils import tf
from mlagents.envs.brain import BrainInfo
from mlagents.trainers.components.reward_signals import RewardSignal, RewardSignalResult
@@ -68,13 +68,13 @@ def __init__(
}
def evaluate(
- self, current_info: BrainInfo, next_info: BrainInfo
+ self, current_info: BrainInfo, action: np.array, next_info: BrainInfo
) -> RewardSignalResult:
if len(current_info.agents) == 0:
return RewardSignalResult([], [])
mini_batch: Dict[str, np.array] = {}
# Construct the batch
- mini_batch["actions"] = next_info.previous_vector_actions
+ mini_batch["actions"] = action
mini_batch["done"] = np.reshape(next_info.local_done, [-1, 1])
for i, obs in enumerate(current_info.visual_observations):
mini_batch["visual_obs%d" % i] = obs
@@ -126,7 +126,7 @@ def check_config(
def prepare_update(
self,
policy_model: LearningModel,
- mini_batch_policy: Dict[str, np.ndarray],
+ mini_batch: Dict[str, np.ndarray],
num_sequences: int,
) -> Dict[tf.Tensor, Any]:
"""
@@ -136,21 +136,21 @@ def prepare_update(
:return: Feed_dict for update process.
"""
max_num_experiences = min(
- len(mini_batch_policy["actions"]),
+ len(mini_batch["actions"]),
len(self.demonstration_buffer.update_buffer["actions"]),
)
# If num_sequences is less, we need to shorten the input batch.
- for key, element in mini_batch_policy.items():
- mini_batch_policy[key] = element[:max_num_experiences]
+ for key, element in mini_batch.items():
+ mini_batch[key] = element[:max_num_experiences]
# Get batch from demo buffer
mini_batch_demo = self.demonstration_buffer.update_buffer.sample_mini_batch(
- len(mini_batch_policy["actions"]), 1
+ len(mini_batch["actions"]), 1
)
feed_dict: Dict[tf.Tensor, Any] = {
self.model.done_expert_holder: mini_batch_demo["done"],
- self.model.done_policy_holder: mini_batch_policy["done"],
+ self.model.done_policy_holder: mini_batch["done"],
}
if self.model.use_vail:
@@ -158,20 +158,18 @@ def prepare_update(
feed_dict[self.model.action_in_expert] = np.array(mini_batch_demo["actions"])
if self.policy.use_continuous_act:
- feed_dict[policy_model.selected_actions] = mini_batch_policy["actions"]
+ feed_dict[policy_model.selected_actions] = mini_batch["actions"]
else:
- feed_dict[policy_model.action_holder] = mini_batch_policy["actions"]
+ feed_dict[policy_model.action_holder] = mini_batch["actions"]
if self.policy.use_vis_obs > 0:
for i in range(len(policy_model.visual_in)):
- feed_dict[policy_model.visual_in[i]] = mini_batch_policy[
- "visual_obs%d" % i
- ]
+ feed_dict[policy_model.visual_in[i]] = mini_batch["visual_obs%d" % i]
feed_dict[self.model.expert_visual_in[i]] = mini_batch_demo[
"visual_obs%d" % i
]
if self.policy.use_vec_obs:
- feed_dict[policy_model.vector_in] = mini_batch_policy["vector_obs"]
+ feed_dict[policy_model.vector_in] = mini_batch["vector_obs"]
feed_dict[self.model.obs_in_expert] = mini_batch_demo["vector_obs"]
self.has_updated = True
return feed_dict
diff --git a/ml-agents/mlagents/trainers/demo_loader.py b/ml-agents/mlagents/trainers/demo_loader.py
index afba4941ec..629524aaf8 100644
--- a/ml-agents/mlagents/trainers/demo_loader.py
+++ b/ml-agents/mlagents/trainers/demo_loader.py
@@ -2,9 +2,12 @@
import logging
import os
from typing import List, Tuple
+import numpy as np
from mlagents.trainers.buffer import Buffer
from mlagents.envs.brain import BrainParameters, BrainInfo
-from mlagents.envs.communicator_objects.agent_info_pb2 import AgentInfoProto
+from mlagents.envs.communicator_objects.agent_info_action_pair_pb2 import (
+ AgentInfoActionPairProto,
+)
from mlagents.envs.communicator_objects.brain_parameters_pb2 import BrainParametersProto
from mlagents.envs.communicator_objects.demonstration_meta_pb2 import (
DemonstrationMetaProto,
@@ -16,15 +19,26 @@
def make_demo_buffer(
- brain_infos: List[BrainInfo], brain_params: BrainParameters, sequence_length: int
+ pair_infos: List[AgentInfoActionPairProto],
+ brain_params: BrainParameters,
+ sequence_length: int,
) -> Buffer:
# Create and populate buffer using experiences
demo_buffer = Buffer()
- for idx, experience in enumerate(brain_infos):
- if idx > len(brain_infos) - 2:
+ for idx, experience in enumerate(pair_infos):
+ if idx > len(pair_infos) - 2:
break
- current_brain_info = brain_infos[idx]
- next_brain_info = brain_infos[idx + 1]
+ current_pair_info = pair_infos[idx]
+ next_pair_info = pair_infos[idx + 1]
+ current_brain_info = BrainInfo.from_agent_proto(
+ 0, [current_pair_info.agent_info], brain_params
+ )
+ next_brain_info = BrainInfo.from_agent_proto(
+ 0, [next_pair_info.agent_info], brain_params
+ )
+ previous_action = np.array(pair_infos[idx].action_info.vector_actions) * 0
+ if idx > 0:
+ previous_action = np.array(pair_infos[idx - 1].action_info.vector_actions)
demo_buffer[0].last_brain_info = current_brain_info
demo_buffer[0]["done"].append(next_brain_info.local_done[0])
demo_buffer[0]["rewards"].append(next_brain_info.rewards[0])
@@ -36,10 +50,8 @@ def make_demo_buffer(
demo_buffer[0]["vector_obs"].append(
current_brain_info.vector_observations[0]
)
- demo_buffer[0]["actions"].append(next_brain_info.previous_vector_actions[0])
- demo_buffer[0]["prev_action"].append(
- current_brain_info.previous_vector_actions[0]
- )
+ demo_buffer[0]["actions"].append(current_pair_info.action_info.vector_actions)
+ demo_buffer[0]["prev_action"].append(previous_action)
if next_brain_info.local_done[0]:
demo_buffer.append_update_buffer(
0, batch_size=None, training_length=sequence_length
@@ -60,16 +72,18 @@ def demo_to_buffer(
:param sequence_length: Length of trajectories to fill buffer.
:return:
"""
- brain_params, brain_infos, _ = load_demonstration(file_path)
- demo_buffer = make_demo_buffer(brain_infos, brain_params, sequence_length)
+ brain_params, info_action_pair, _ = load_demonstration(file_path)
+ demo_buffer = make_demo_buffer(info_action_pair, brain_params, sequence_length)
return brain_params, demo_buffer
-def load_demonstration(file_path: str) -> Tuple[BrainParameters, List[BrainInfo], int]:
+def load_demonstration(
+ file_path: str
+) -> Tuple[BrainParameters, List[AgentInfoActionPairProto], int]:
"""
Loads and parses a demonstration file.
:param file_path: Location of demonstration file (.demo).
- :return: BrainParameter and list of BrainInfos containing demonstration data.
+ :return: BrainParameter and list of AgentInfoActionPairProto containing demonstration data.
"""
# First 32 bytes of file dedicated to meta-data.
@@ -97,7 +111,7 @@ def load_demonstration(file_path: str) -> Tuple[BrainParameters, List[BrainInfo]
brain_params = None
brain_param_proto = None
- brain_infos = []
+ info_action_pairs = []
total_expected = 0
for _file_path in file_paths:
data = open(_file_path, "rb").read()
@@ -112,19 +126,17 @@ def load_demonstration(file_path: str) -> Tuple[BrainParameters, List[BrainInfo]
if obs_decoded == 1:
brain_param_proto = BrainParametersProto()
brain_param_proto.ParseFromString(data[pos : pos + next_pos])
-
pos += next_pos
if obs_decoded > 1:
- agent_info = AgentInfoProto()
- agent_info.ParseFromString(data[pos : pos + next_pos])
+ agent_info_action = AgentInfoActionPairProto()
+ agent_info_action.ParseFromString(data[pos : pos + next_pos])
if brain_params is None:
brain_params = BrainParameters.from_proto(
- brain_param_proto, agent_info
+ brain_param_proto, agent_info_action.agent_info
)
- brain_info = BrainInfo.from_agent_proto(0, [agent_info], brain_params)
- brain_infos.append(brain_info)
- if len(brain_infos) == total_expected:
+ info_action_pairs.append(agent_info_action)
+ if len(info_action_pairs) == total_expected:
break
pos += next_pos
obs_decoded += 1
- return brain_params, brain_infos, total_expected
+ return brain_params, info_action_pairs, total_expected
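Demonstrations are now parsed as AgentInfoActionPairProto records, so actions come from the recorded action_info rather than being reconstructed from the next frame's previous_vector_actions. A minimal loading sketch (the .demo path is illustrative):

    from mlagents.trainers.demo_loader import demo_to_buffer, load_demonstration

    brain_params, pair_infos, total_expected = load_demonstration("Demos/Expert.demo")
    brain_params, demo_buffer = demo_to_buffer("Demos/Expert.demo", sequence_length=1)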
diff --git a/ml-agents/mlagents/trainers/learn.py b/ml-agents/mlagents/trainers/learn.py
index 5dc0d709c7..31ad40a43c 100644
--- a/ml-agents/mlagents/trainers/learn.py
+++ b/ml-agents/mlagents/trainers/learn.py
@@ -10,7 +10,9 @@
from typing import Any, Callable, Optional, List, NamedTuple
-
+import mlagents.trainers
+import mlagents.envs
+from mlagents import tf_utils
from mlagents.trainers.trainer_controller import TrainerController
from mlagents.trainers.exception import TrainerError
from mlagents.trainers.meta_curriculum import MetaCurriculum
@@ -54,6 +56,15 @@ def from_argparse(args: Any) -> "CommandLineOptions":
return CommandLineOptions(**vars(args))
+def get_version_string() -> str:
+ return f""" Version information:\n
+ ml-agents: {mlagents.trainers.__version__},
+ ml-agents-envs: {mlagents.envs.__version__},
+ Communicator API: {UnityEnvironment.API_VERSION},
+ TensorFlow: {tf_utils.tf.__version__}
+"""
+
+
def parse_command_line(argv: Optional[List[str]] = None) -> CommandLineOptions:
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -160,6 +171,8 @@ def parse_command_line(argv: Optional[List[str]] = None) -> CommandLineOptions:
"--cpu", default=False, action="store_true", help="Run with CPU only"
)
+ parser.add_argument("--version", action="version", version=get_version_string())
+
args = parser.parse_args(argv)
return CommandLineOptions.from_argparse(args)
@@ -333,13 +346,11 @@ def create_environment_factory(
)
docker_training = docker_target_name is not None
if docker_training and env_path is not None:
- """
- Comments for future maintenance:
- Some OS/VM instances (e.g. COS GCP Image) mount filesystems
- with COS flag which prevents execution of the Unity scene,
- to get around this, we will copy the executable into the
- container.
- """
+ # Comments for future maintenance:
+ # Some OS/VM instances (e.g. COS GCP Image) mount filesystems
+ # with COS flag which prevents execution of the Unity scene,
+ # to get around this, we will copy the executable into the
+ # container.
# Navigate in docker path and find env_path and copy it.
env_path = prepare_for_docker_run(docker_target_name, env_path)
seed_count = 10000
@@ -391,6 +402,9 @@ def main():
if options.debug:
trainer_logger.setLevel("DEBUG")
env_logger.setLevel("DEBUG")
+ else:
+ # disable noisy warnings from tensorflow.
+ tf_utils.set_warnings_enabled(False)
if options.env_path is None and options.num_runs > 1:
raise TrainerError(
"It is not possible to launch more than one concurrent training session "
diff --git a/ml-agents/mlagents/trainers/models.py b/ml-agents/mlagents/trainers/models.py
index 3896fbb567..362878a58d 100644
--- a/ml-agents/mlagents/trainers/models.py
+++ b/ml-agents/mlagents/trainers/models.py
@@ -1,10 +1,9 @@
import logging
from enum import Enum
-from typing import Callable, List
+from typing import Callable, Dict, List, Optional
import numpy as np
-import tensorflow as tf
-import tensorflow.contrib.layers as c_layers
+from mlagents.tf_utils import tf
from mlagents.trainers.trainer import UnityTrainerException
from mlagents.envs.brain import CameraResolution
@@ -54,9 +53,7 @@ def __init__(
self.m_size = 0
self.normalize = normalize
self.act_size = brain.vector_action_space_size
- self.vec_obs_size = (
- brain.vector_observation_space_size * brain.num_stacked_vector_observations
- )
+ self.vec_obs_size = brain.vector_observation_space_size
self.vis_obs_size = brain.number_visual_observations
tf.Variable(
int(brain.vector_action_space_type == "continuous"),
@@ -85,6 +82,12 @@ def __init__(
trainable=False,
dtype=tf.int32,
)
+ self.value_heads: Dict[str, tf.Tensor] = {}
+ self.normalization_steps: Optional[tf.Variable] = None
+ self.running_mean: Optional[tf.Variable] = None
+ self.running_variance: Optional[tf.Variable] = None
+ self.update_normalization: Optional[tf.Operation] = None
+ self.value: Optional[tf.Tensor] = None
@staticmethod
def create_global_steps():
@@ -119,7 +122,7 @@ def create_learning_rate(
@staticmethod
def scaled_init(scale):
- return c_layers.variance_scaling_initializer(scale)
+ return tf.initializers.variance_scaling(scale)
@staticmethod
def swish(input_activation: tf.Tensor) -> tf.Tensor:
@@ -241,7 +244,7 @@ def create_vector_observation_encoder(
activation=activation,
reuse=reuse,
name="hidden_{}".format(i),
- kernel_initializer=c_layers.variance_scaling_initializer(1.0),
+ kernel_initializer=tf.initializers.variance_scaling(1.0),
)
return hidden
@@ -283,7 +286,7 @@ def create_visual_observation_encoder(
reuse=reuse,
name="conv_2",
)
- hidden = c_layers.flatten(conv2)
+ hidden = tf.layers.flatten(conv2)
with tf.variable_scope(scope + "/" + "flat_encoding"):
hidden_flat = LearningModel.create_vector_observation_encoder(
@@ -338,7 +341,7 @@ def create_nature_cnn_visual_observation_encoder(
reuse=reuse,
name="conv_3",
)
- hidden = c_layers.flatten(conv3)
+ hidden = tf.layers.flatten(conv3)
with tf.variable_scope(scope + "/" + "flat_encoding"):
hidden_flat = LearningModel.create_vector_observation_encoder(
@@ -406,7 +409,7 @@ def create_resnet_visual_observation_encoder(
)
hidden = tf.add(block_input, hidden)
hidden = tf.nn.relu(hidden)
- hidden = c_layers.flatten(hidden)
+ hidden = tf.layers.flatten(hidden)
with tf.variable_scope(scope + "/" + "flat_encoding"):
hidden_flat = LearningModel.create_vector_observation_encoder(
@@ -553,8 +556,8 @@ def create_recurrent_encoder(input_state, memory_in, sequence_length, name="lstm
memory_in = tf.reshape(memory_in[:, :], [-1, m_size])
half_point = int(m_size / 2)
with tf.variable_scope(name):
- rnn_cell = tf.contrib.rnn.BasicLSTMCell(half_point)
- lstm_vector_in = tf.contrib.rnn.LSTMStateTuple(
+ rnn_cell = tf.nn.rnn_cell.BasicLSTMCell(half_point)
+ lstm_vector_in = tf.nn.rnn_cell.LSTMStateTuple(
memory_in[:, :half_point], memory_in[:, half_point:]
)
recurrent_output, lstm_state_out = tf.nn.dynamic_rnn(
@@ -573,7 +576,6 @@ def create_value_heads(self, stream_names, hidden_input):
:param hidden_input: The last layer of the Critic. The heads will consist of one dense hidden layer on top
of the hidden input.
"""
- self.value_heads = {}
for name in stream_names:
value = tf.layers.dense(hidden_input, 1, name="{}_value".format(name))
self.value_heads[name] = value
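The tf.contrib dependencies are swapped for core-TF equivalents so the same graph-building code runs under the compat.v1 shim; the main substitutions applied in this patch are:

    # tf.contrib.layers.variance_scaling_initializer(scale)  ->  tf.initializers.variance_scaling(scale)
    # tf.contrib.layers.flatten(x)                           ->  tf.layers.flatten(x)
    # tf.contrib.rnn.BasicLSTMCell / LSTMStateTuple          ->  tf.nn.rnn_cell.BasicLSTMCell / LSTMStateTuple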
diff --git a/ml-agents/mlagents/trainers/ppo/models.py b/ml-agents/mlagents/trainers/ppo/models.py
index deac591129..805831d10e 100644
--- a/ml-agents/mlagents/trainers/ppo/models.py
+++ b/ml-agents/mlagents/trainers/ppo/models.py
@@ -1,7 +1,8 @@
import logging
-import numpy as np
+from typing import Optional
-import tensorflow as tf
+import numpy as np
+from mlagents.tf_utils import tf
from mlagents.trainers.models import LearningModel, EncoderType, LearningRateSchedule
logger = logging.getLogger("mlagents.trainers")
@@ -46,6 +47,11 @@ def __init__(
LearningModel.__init__(
self, m_size, normalize, use_recurrent, brain, seed, stream_names
)
+
+ self.optimizer: Optional[tf.train.AdamOptimizer] = None
+ self.grads = None
+ self.update_batch: Optional[tf.Operation] = None
+
if num_layers < 1:
num_layers = 1
if brain.vector_action_space_type == "continuous":
@@ -203,9 +209,7 @@ def create_dc_actor_critic(
)
)
- self.all_log_probs = tf.concat(
- [branch for branch in policy_branches], axis=1, name="action_probs"
- )
+ self.all_log_probs = tf.concat(policy_branches, axis=1, name="action_probs")
self.action_masks = tf.placeholder(
shape=[None, sum(self.act_size)], dtype=tf.float32, name="action_masks"
diff --git a/ml-agents/mlagents/trainers/ppo/multi_gpu_policy.py b/ml-agents/mlagents/trainers/ppo/multi_gpu_policy.py
index 7230dff452..a13c845c14 100644
--- a/ml-agents/mlagents/trainers/ppo/multi_gpu_policy.py
+++ b/ml-agents/mlagents/trainers/ppo/multi_gpu_policy.py
@@ -1,11 +1,15 @@
import logging
+from typing import Any, Dict, List, Optional
+
+from mlagents.tf_utils import tf
-import tensorflow as tf
from tensorflow.python.client import device_lib
+from mlagents.envs.brain import BrainParameters
from mlagents.envs.timers import timed
from mlagents.trainers.models import EncoderType, LearningRateSchedule
from mlagents.trainers.ppo.policy import PPOPolicy
from mlagents.trainers.ppo.models import PPOModel
+from mlagents.trainers.components.reward_signals import RewardSignal
from mlagents.trainers.components.reward_signals.reward_signal_factory import (
create_reward_signal,
)
@@ -17,15 +21,21 @@
class MultiGpuPPOPolicy(PPOPolicy):
- def __init__(self, seed, brain, trainer_params, is_training, load):
- """
- Policy for Proximal Policy Optimization Networks with multi-GPU training
- :param seed: Random seed.
- :param brain: Assigned Brain object.
- :param trainer_params: Defined training parameters.
- :param is_training: Whether the model should be trained.
- :param load: Whether a pre-trained model will be loaded or a new one created.
- """
+ def __init__(
+ self,
+ seed: int,
+ brain: BrainParameters,
+ trainer_params: Dict[str, Any],
+ is_training: bool,
+ load: bool,
+ ):
+ self.towers: List[PPOModel] = []
+ self.devices: List[str] = []
+ self.model: Optional[PPOModel] = None
+ self.total_policy_loss: Optional[tf.Tensor] = None
+ self.reward_signal_towers: List[Dict[str, RewardSignal]] = []
+ self.reward_signals: Dict[str, RewardSignal] = {}
+
super().__init__(seed, brain, trainer_params, is_training, load)
def create_model(
@@ -39,7 +49,7 @@ def create_model(
:param seed: Random seed.
"""
self.devices = get_devices()
- self.towers = []
+
with self.graph.as_default():
with tf.variable_scope("", reuse=tf.AUTO_REUSE):
for device in self.devices:
@@ -116,7 +126,6 @@ def create_reward_signals(self, reward_signal_configs):
Create reward signals
:param reward_signal_configs: Reward signal config.
"""
- self.reward_signal_towers = []
with self.graph.as_default():
with tf.variable_scope(TOWER_SCOPE_NAME, reuse=tf.AUTO_REUSE):
for device_id, device in enumerate(self.devices):
@@ -201,7 +210,7 @@ def average_gradients(self, tower_grads):
return average_grads
-def get_devices():
+def get_devices() -> List[str]:
"""
Get all available GPU devices
"""
diff --git a/ml-agents/mlagents/trainers/ppo/policy.py b/ml-agents/mlagents/trainers/ppo/policy.py
index a81b97b03d..53c2ad0039 100644
--- a/ml-agents/mlagents/trainers/ppo/policy.py
+++ b/ml-agents/mlagents/trainers/ppo/policy.py
@@ -1,7 +1,8 @@
import logging
import numpy as np
from typing import Any, Dict, Optional
-import tensorflow as tf
+
+from mlagents.tf_utils import tf
from mlagents.envs.timers import timed
from mlagents.envs.brain import BrainInfo, BrainParameters
@@ -151,14 +152,10 @@ def evaluate(self, brain_info):
epsilon = None
if self.use_recurrent:
if not self.use_continuous_act:
- feed_dict[
- self.model.prev_action
- ] = brain_info.previous_vector_actions.reshape(
- [-1, len(self.model.act_size)]
+ feed_dict[self.model.prev_action] = self.retrieve_previous_action(
+ brain_info.agents
)
- if brain_info.memories.shape[1] == 0:
- brain_info.memories = self.make_empty_memory(len(brain_info.agents))
- feed_dict[self.model.memory_in] = brain_info.memories
+ feed_dict[self.model.memory_in] = self.retrieve_memories(brain_info.agents)
if self.use_continuous_act:
epsilon = np.random.normal(
size=(len(brain_info.vector_observations), self.model.act_size[0])
@@ -253,13 +250,9 @@ def get_value_estimates(
if self.use_vec_obs:
feed_dict[self.model.vector_in] = [brain_info.vector_observations[idx]]
if self.use_recurrent:
- if brain_info.memories.shape[1] == 0:
- brain_info.memories = self.make_empty_memory(len(brain_info.agents))
- feed_dict[self.model.memory_in] = [brain_info.memories[idx]]
+ feed_dict[self.model.memory_in] = self.retrieve_memories([idx])
if not self.use_continuous_act and self.use_recurrent:
- feed_dict[self.model.prev_action] = [
- brain_info.previous_vector_actions[idx]
- ]
+ feed_dict[self.model.prev_action] = self.retrieve_previous_action([idx])
value_estimates = self.sess.run(self.model.value_heads, feed_dict)
value_estimates = {k: float(v) for k, v in value_estimates.items()}
diff --git a/ml-agents/mlagents/trainers/ppo/trainer.py b/ml-agents/mlagents/trainers/ppo/trainer.py
index 19d6a17eff..2c27422a27 100644
--- a/ml-agents/mlagents/trainers/ppo/trainer.py
+++ b/ml-agents/mlagents/trainers/ppo/trainer.py
@@ -8,7 +8,7 @@
import numpy as np
-from mlagents.envs.brain import AllBrainInfo
+from mlagents.envs.brain import BrainInfo
from mlagents.trainers.ppo.policy import PPOPolicy
from mlagents.trainers.ppo.multi_gpu_policy import MultiGpuPPOPolicy, get_devices
from mlagents.trainers.rl_trainer import RLTrainer, AllRewardsOutput
@@ -79,34 +79,33 @@ def __init__(
self.collected_rewards[_reward_signal] = {}
def process_experiences(
- self, current_info: AllBrainInfo, new_info: AllBrainInfo
+ self, current_info: BrainInfo, next_info: BrainInfo
) -> None:
"""
Checks agent histories for processing condition, and processes them as necessary.
Processing involves calculating value and advantage targets for model updating step.
- :param current_info: Dictionary of all current brains and corresponding BrainInfo.
- :param new_info: Dictionary of all next brains and corresponding BrainInfo.
+ :param current_info: current BrainInfo.
+ :param next_info: next BrainInfo.
"""
- info = new_info[self.brain_name]
if self.is_training:
- self.policy.update_normalization(info.vector_observations)
- for l in range(len(info.agents)):
- agent_actions = self.training_buffer[info.agents[l]]["actions"]
+ self.policy.update_normalization(next_info.vector_observations)
+ for l in range(len(next_info.agents)):
+ agent_actions = self.training_buffer[next_info.agents[l]]["actions"]
if (
- info.local_done[l]
+ next_info.local_done[l]
or len(agent_actions) > self.trainer_parameters["time_horizon"]
) and len(agent_actions) > 0:
- agent_id = info.agents[l]
- if info.max_reached[l]:
+ agent_id = next_info.agents[l]
+ if next_info.max_reached[l]:
bootstrapping_info = self.training_buffer[agent_id].last_brain_info
idx = bootstrapping_info.agents.index(agent_id)
else:
- bootstrapping_info = info
+ bootstrapping_info = next_info
idx = l
value_next = self.policy.get_value_estimates(
bootstrapping_info,
idx,
- info.local_done[l] and not info.max_reached[l],
+ next_info.local_done[l] and not next_info.max_reached[l],
)
tmp_advantages = []
@@ -150,7 +149,7 @@ def process_experiences(
)
self.training_buffer[agent_id].reset_agent()
- if info.local_done[l]:
+ if next_info.local_done[l]:
self.stats["Environment/Episode Length"].append(
self.episode_steps.get(agent_id, 0)
)
@@ -228,7 +227,7 @@ def update_policy(self):
number_experiences=buffer_length,
mean_return=float(np.mean(self.cumulative_returns_since_policy_update)),
)
- self.cumulative_returns_since_policy_update = []
+ self.cumulative_returns_since_policy_update.clear()
# Make sure batch_size is a multiple of sequence length. During training, we
# will need to reshape the data into a batch_size x sequence_length tensor.
diff --git a/ml-agents/mlagents/trainers/rl_trainer.py b/ml-agents/mlagents/trainers/rl_trainer.py
index ce3bb9e377..bbb892212d 100644
--- a/ml-agents/mlagents/trainers/rl_trainer.py
+++ b/ml-agents/mlagents/trainers/rl_trainer.py
@@ -3,7 +3,7 @@
from typing import Dict, List, Any, NamedTuple
import numpy as np
-from mlagents.envs.brain import AllBrainInfo, BrainInfo
+from mlagents.envs.brain import BrainInfo
from mlagents.envs.action_info import ActionInfoOutputs
from mlagents.trainers.buffer import Buffer
from mlagents.trainers.trainer import Trainer, UnityTrainerException
@@ -57,14 +57,10 @@ def construct_curr_info(self, next_info: BrainInfo) -> BrainInfo:
[] for _ in next_info.visual_observations
] # TODO add types to brain.py methods
vector_observations = []
- text_observations = []
- memories = []
rewards = []
local_dones = []
max_reacheds = []
agents = []
- prev_vector_actions = []
- prev_text_actions = []
action_masks = []
for agent_id in next_info.agents:
agent_brain_info = self.training_buffer[agent_id].last_brain_info
@@ -78,36 +74,17 @@ def construct_curr_info(self, next_info: BrainInfo) -> BrainInfo:
vector_observations.append(
agent_brain_info.vector_observations[agent_index]
)
- text_observations.append(agent_brain_info.text_observations[agent_index])
- if self.policy.use_recurrent:
- if len(agent_brain_info.memories) > 0:
- memories.append(agent_brain_info.memories[agent_index])
- else:
- memories.append(self.policy.make_empty_memory(1))
rewards.append(agent_brain_info.rewards[agent_index])
local_dones.append(agent_brain_info.local_done[agent_index])
max_reacheds.append(agent_brain_info.max_reached[agent_index])
agents.append(agent_brain_info.agents[agent_index])
- prev_vector_actions.append(
- agent_brain_info.previous_vector_actions[agent_index]
- )
- prev_text_actions.append(
- agent_brain_info.previous_text_actions[agent_index]
- )
action_masks.append(agent_brain_info.action_masks[agent_index])
- # Check if memories exists (i.e. next_info is not empty) before attempting vstack
- if self.policy.use_recurrent and memories:
- memories = np.vstack(memories)
curr_info = BrainInfo(
visual_observations,
vector_observations,
- text_observations,
- memories,
rewards,
agents,
local_dones,
- prev_vector_actions,
- prev_text_actions,
max_reacheds,
action_masks,
)
@@ -115,14 +92,14 @@ def construct_curr_info(self, next_info: BrainInfo) -> BrainInfo:
def add_experiences(
self,
- curr_all_info: AllBrainInfo,
- next_all_info: AllBrainInfo,
+ curr_info: BrainInfo,
+ next_info: BrainInfo,
take_action_outputs: ActionInfoOutputs,
) -> None:
"""
Adds experiences to each agent's experience history.
- :param curr_all_info: Dictionary of all current brains and corresponding BrainInfo.
- :param next_all_info: Dictionary of all current brains and corresponding BrainInfo.
+ :param curr_info: current BrainInfo.
+ :param next_info: next BrainInfo.
:param take_action_outputs: The outputs of the Policy's get_action method.
"""
self.trainer_metrics.start_experience_collection_timer()
@@ -136,9 +113,6 @@ def add_experiences(
np.mean(take_action_outputs["value_heads"][name])
)
- curr_info = curr_all_info[self.brain_name]
- next_info = next_all_info[self.brain_name]
-
for agent_id in curr_info.agents:
self.training_buffer[agent_id].last_brain_info = curr_info
self.training_buffer[
@@ -153,7 +127,9 @@ def add_experiences(
# Evaluate and store the reward signals
tmp_reward_signal_outs = {}
for name, signal in self.policy.reward_signals.items():
- tmp_reward_signal_outs[name] = signal.evaluate(curr_to_use, next_info)
+ tmp_reward_signal_outs[name] = signal.evaluate(
+ curr_to_use, take_action_outputs["action"], next_info
+ )
# Store the environment reward
tmp_environment = np.array(next_info.rewards)
@@ -185,12 +161,8 @@ def add_experiences(
next_info.vector_observations[next_idx]
)
if self.policy.use_recurrent:
- if stored_info.memories.shape[1] == 0:
- stored_info.memories = np.zeros(
- (len(stored_info.agents), self.policy.m_size)
- )
self.training_buffer[agent_id]["memory"].append(
- stored_info.memories[idx]
+ self.policy.retrieve_memories([agent_id])[0, :]
)
self.training_buffer[agent_id]["masks"].append(1.0)
@@ -199,13 +171,13 @@ def add_experiences(
)
# Add the outputs of the last eval
self.add_policy_outputs(stored_take_action_outputs, agent_id, idx)
- # Store action masks if neccessary
+ # Store action masks if necessary
if not self.policy.use_continuous_act:
self.training_buffer[agent_id]["action_mask"].append(
stored_info.action_masks[idx], padding_value=1
)
self.training_buffer[agent_id]["prev_action"].append(
- stored_info.previous_vector_actions[idx]
+ self.policy.retrieve_previous_action([agent_id])[0, :]
)
values = stored_take_action_outputs["value_heads"]
@@ -230,6 +202,9 @@ def add_experiences(
if agent_id not in self.episode_steps:
self.episode_steps[agent_id] = 0
self.episode_steps[agent_id] += 1
+ self.policy.save_previous_action(
+ curr_info.agents, take_action_outputs["action"]
+ )
self.trainer_metrics.end_experience_collection_timer()
def end_episode(self) -> None:
@@ -263,7 +238,7 @@ def add_policy_outputs(
:param agent_idx: the index of the Agent agent_id
"""
raise UnityTrainerException(
- "The process_experiences method was not implemented."
+ "The add_policy_outputs method was not implemented."
)
def add_rewards_outputs(
@@ -285,5 +260,5 @@ def add_rewards_outputs(
:param agent_next_idx: the index of the Agent agent_id in the next brain info
"""
raise UnityTrainerException(
- "The process_experiences method was not implemented."
+ "The add_rewards_outputs method was not implemented."
)
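The rl_trainer changes above mean a trainer now consumes a single BrainInfo per brain rather than the AllBrainInfo dictionaries, and reward signals are evaluated against the action that was actually taken. A minimal sketch of the caller-side dispatch this implies (the wrapper function and variable names are illustrative, not part of this diff):

```python
# A minimal sketch (assumed caller code, not part of this diff) of how a step is
# dispatched per brain after this change: the trainer receives BrainInfo objects
# and the policy outputs for its own brain only.
def dispatch_step_to_trainer(trainer, step_info, brain_name):
    curr_info = step_info.previous_all_brain_info[brain_name]
    next_info = step_info.current_all_brain_info[brain_name]
    outputs = step_info.brain_name_to_action_info[brain_name].outputs
    trainer.add_experiences(curr_info, next_info, outputs)
    trainer.process_experiences(curr_info, next_info)
```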
diff --git a/ml-agents/mlagents/trainers/sac/models.py b/ml-agents/mlagents/trainers/sac/models.py
index 33266ae7d9..1e2911d8e8 100644
--- a/ml-agents/mlagents/trainers/sac/models.py
+++ b/ml-agents/mlagents/trainers/sac/models.py
@@ -1,9 +1,10 @@
import logging
import numpy as np
+from typing import Dict, List, Optional
+
+from mlagents.tf_utils import tf
-import tensorflow as tf
from mlagents.trainers.models import LearningModel, LearningRateSchedule, EncoderType
-import tensorflow.contrib.layers as c_layers
LOG_STD_MAX = 2
LOG_STD_MIN = -20
@@ -44,6 +45,43 @@ def __init__(
self.h_size = h_size
self.activ_fn = self.swish
+ self.policy_memory_in: Optional[tf.Tensor] = None
+ self.policy_memory_out: Optional[tf.Tensor] = None
+ self.value_memory_in: Optional[tf.Tensor] = None
+ self.value_memory_out: Optional[tf.Tensor] = None
+ self.q1: Optional[tf.Tensor] = None
+ self.q2: Optional[tf.Tensor] = None
+ self.q1_p: Optional[tf.Tensor] = None
+ self.q2_p: Optional[tf.Tensor] = None
+ self.q1_memory_in: Optional[tf.Tensor] = None
+ self.q2_memory_in: Optional[tf.Tensor] = None
+ self.q1_memory_out: Optional[tf.Tensor] = None
+ self.q2_memory_out: Optional[tf.Tensor] = None
+ self.action_holder: Optional[tf.Tensor] = None
+ self.prev_action: Optional[tf.Tensor] = None
+ self.action_masks: Optional[tf.Tensor] = None
+ self.external_action_in: Optional[tf.Tensor] = None
+ self.log_sigma_sq: Optional[tf.Tensor] = None
+ self.entropy: Optional[tf.Tensor] = None
+ self.deterministic_output: Optional[tf.Tensor] = None
+ self.all_log_probs: Optional[tf.Tensor] = None
+ self.normalized_logprobs: Optional[tf.Tensor] = None
+ self.action_probs: Optional[tf.Tensor] = None
+ self.selected_actions: Optional[tf.Tensor] = None
+ self.output: Optional[tf.Tensor] = None
+ self.output_oh: Optional[tf.Tensor] = None
+ self.output_pre: Optional[tf.Tensor] = None
+
+ self.value_vars = None
+ self.q_vars = None
+ self.critic_vars = None
+ self.policy_vars = None
+
+ self.q1_heads: Optional[Dict[str, tf.Tensor]] = None
+ self.q2_heads: Optional[Dict[str, tf.Tensor]] = None
+ self.q1_pheads: Optional[Dict[str, tf.Tensor]] = None
+ self.q2_pheads: Optional[Dict[str, tf.Tensor]] = None
+
def get_vars(self, scope):
return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope)
@@ -275,14 +313,11 @@ def create_dc_actor(self, hidden_policy, scope):
size,
activation=None,
use_bias=False,
- kernel_initializer=c_layers.variance_scaling_initializer(
- factor=0.01
- ),
+ kernel_initializer=tf.initializers.variance_scaling(0.01),
)
)
- all_logits = tf.concat(
- [branch for branch in policy_branches], axis=1, name="action_probs"
- )
+ all_logits = tf.concat(policy_branches, axis=1, name="action_probs")
+
output, normalized_probs, normalized_logprobs = self.create_discrete_action_masking_layer(
all_logits, self.action_masks, self.act_size
)
@@ -344,7 +379,6 @@ def create_sac_value_head(
:param h_size: size of hidden layers for value network
:param scope: TF scope for value network.
"""
- self.value_heads = {}
with tf.variable_scope(scope):
value_hidden = self.create_vector_observation_encoder(
hidden_input, h_size, self.activ_fn, num_layers, "encoder", False
@@ -670,6 +704,12 @@ def __init__(
if num_layers < 1:
num_layers = 1
+ self.target_init_op: List[tf.Tensor] = []
+ self.target_update_op: List[tf.Tensor] = []
+ self.update_batch_policy: Optional[tf.Operation] = None
+ self.update_batch_value: Optional[tf.Operation] = None
+ self.update_batch_entropy: Optional[tf.Operation] = None
+
self.policy_network = SACPolicyNetwork(
brain=brain,
m_size=m_size,
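The long block of Optional declarations added to the SAC model `__init__` above pre-registers every tensor and variable collection that is only built later during graph construction, so static checkers can see the attributes up front. A minimal, TF-free sketch of the pattern (class and attribute names are made up):

```python
# Minimal, TF-free sketch of the attribute-predeclaration pattern (names are
# illustrative): members are declared Optional in __init__ and only assigned
# real values later, when the graph is actually built.
from typing import Optional


class PolicyNetworkSketch:
    def __init__(self) -> None:
        self.q1: Optional[float] = None  # stands in for a tf.Tensor built later

    def build(self, value: float) -> None:
        self.q1 = value
```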
diff --git a/ml-agents/mlagents/trainers/sac/policy.py b/ml-agents/mlagents/trainers/sac/policy.py
index b8b996a427..54c0d5892a 100644
--- a/ml-agents/mlagents/trainers/sac/policy.py
+++ b/ml-agents/mlagents/trainers/sac/policy.py
@@ -1,7 +1,7 @@
import logging
from typing import Dict, Any, Optional
import numpy as np
-import tensorflow as tf
+from mlagents.tf_utils import tf
from mlagents.envs.timers import timed
from mlagents.envs.brain import BrainInfo, BrainParameters
@@ -175,14 +175,10 @@ def evaluate(self, brain_info: BrainInfo) -> Dict[str, np.ndarray]:
}
if self.use_recurrent:
if not self.use_continuous_act:
- feed_dict[
- self.model.prev_action
- ] = brain_info.previous_vector_actions.reshape(
- [-1, len(self.model.act_size)]
+ feed_dict[self.model.prev_action] = self.retrieve_previous_action(
+ brain_info.agents
)
- if brain_info.memories.shape[1] == 0:
- brain_info.memories = self.make_empty_memory(len(brain_info.agents))
- feed_dict[self.model.memory_in] = brain_info.memories
+ feed_dict[self.model.memory_in] = self.retrieve_memories(brain_info.agents)
feed_dict = self.fill_eval_dict(feed_dict, brain_info)
run_out = self._execute_model(feed_dict, self.inference_dict)
@@ -190,7 +186,7 @@ def evaluate(self, brain_info: BrainInfo) -> Dict[str, np.ndarray]:
@timed
def update(
- self, mini_batch: Dict[str, Any], num_sequences: int, update_target: bool = True
+ self, mini_batch: Dict[str, Any], num_sequences: int
) -> Dict[str, float]:
"""
Updates model using buffer.
@@ -207,8 +203,8 @@ def update(
update_vals = self._execute_model(feed_dict, self.update_dict)
for stat_name, update_name in stats_needed.items():
update_stats[stat_name] = update_vals[update_name]
- if update_target:
- self.sess.run(self.model.target_update_op)
+ # Update target network. By default, target update happens at every policy update.
+ self.sess.run(self.model.target_update_op)
return update_stats
def update_reward_signals(
diff --git a/ml-agents/mlagents/trainers/sac/trainer.py b/ml-agents/mlagents/trainers/sac/trainer.py
index c82a47e521..5bbc077185 100644
--- a/ml-agents/mlagents/trainers/sac/trainer.py
+++ b/ml-agents/mlagents/trainers/sac/trainer.py
@@ -5,12 +5,12 @@
import logging
from collections import defaultdict
-from typing import List, Dict
+from typing import Dict
import os
import numpy as np
-from mlagents.envs.brain import AllBrainInfo
+from mlagents.envs.brain import BrainInfo
from mlagents.envs.action_info import ActionInfoOutputs
from mlagents.envs.timers import timed
from mlagents.trainers.sac.policy import SACPolicy
@@ -159,26 +159,25 @@ def add_rewards_outputs(
)
def process_experiences(
- self, current_info: AllBrainInfo, new_info: AllBrainInfo
+ self, current_info: BrainInfo, next_info: BrainInfo
) -> None:
"""
Checks agent histories for processing condition, and processes them as necessary.
- :param current_info: Dictionary of all current brains and corresponding BrainInfo.
- :param new_info: Dictionary of all next brains and corresponding BrainInfo.
+ :param current_info: current BrainInfo.
+ :param next_info: next BrainInfo.
"""
- info = new_info[self.brain_name]
if self.is_training:
- self.policy.update_normalization(info.vector_observations)
- for l in range(len(info.agents)):
- agent_actions = self.training_buffer[info.agents[l]]["actions"]
+ self.policy.update_normalization(next_info.vector_observations)
+ for l in range(len(next_info.agents)):
+ agent_actions = self.training_buffer[next_info.agents[l]]["actions"]
if (
- info.local_done[l]
+ next_info.local_done[l]
or len(agent_actions) >= self.trainer_parameters["time_horizon"]
) and len(agent_actions) > 0:
- agent_id = info.agents[l]
+ agent_id = next_info.agents[l]
# Bootstrap using last brain info. Set last element to duplicate obs and remove dones.
- if info.max_reached[l]:
+ if next_info.max_reached[l]:
bootstrapping_info = self.training_buffer[agent_id].last_brain_info
idx = bootstrapping_info.agents.index(agent_id)
for i, obs in enumerate(bootstrapping_info.visual_observations):
@@ -198,7 +197,7 @@ def process_experiences(
)
self.training_buffer[agent_id].reset_agent()
- if info.local_done[l]:
+ if next_info.local_done[l]:
self.stats["Environment/Episode Length"].append(
self.episode_steps.get(agent_id, 0)
)
@@ -254,7 +253,7 @@ def update_sac_policy(self) -> None:
is greater than 1 and the reward signals are not updated in parallel.
"""
- self.cumulative_returns_since_policy_update: List[float] = []
+ self.cumulative_returns_since_policy_update.clear()
n_sequences = max(
int(self.trainer_parameters["batch_size"] / self.policy.sequence_length), 1
)
@@ -278,9 +277,7 @@ def update_sac_policy(self) -> None:
"{}_rewards".format(name)
] = signal.evaluate_batch(sampled_minibatch).scaled_reward
- update_stats = self.policy.update(
- sampled_minibatch, n_sequences, update_target=True
- )
+ update_stats = self.policy.update(sampled_minibatch, n_sequences)
for stat_name, value in update_stats.items():
batch_update_stats[stat_name].append(value)
diff --git a/ml-agents/mlagents/trainers/tensorflow_to_barracuda.py b/ml-agents/mlagents/trainers/tensorflow_to_barracuda.py
index 521252c0b0..358a8f0dd0 100644
--- a/ml-agents/mlagents/trainers/tensorflow_to_barracuda.py
+++ b/ml-agents/mlagents/trainers/tensorflow_to_barracuda.py
@@ -1,7 +1,9 @@
+# pylint: skip-file
+# flake8: noqa
from __future__ import print_function
import numpy as np
import struct # convert from Python values and C structs
-import tensorflow as tf
+from mlagents.tf_utils import tf
import re
# import barracuda
diff --git a/ml-agents/mlagents/trainers/tests/mock_brain.py b/ml-agents/mlagents/trainers/tests/mock_brain.py
index aaf03e8f04..0a56eed87f 100644
--- a/ml-agents/mlagents/trainers/tests/mock_brain.py
+++ b/ml-agents/mlagents/trainers/tests/mock_brain.py
@@ -7,7 +7,6 @@
def create_mock_brainparams(
number_visual_observations=0,
- num_stacked_vector_observations=1,
vector_action_space_type="continuous",
vector_observation_space_size=3,
vector_action_space_size=None,
@@ -20,9 +19,6 @@ def create_mock_brainparams(
vector_action_space_size = [2]
mock_brain = mock.Mock()
mock_brain.return_value.number_visual_observations = number_visual_observations
- mock_brain.return_value.num_stacked_vector_observations = (
- num_stacked_vector_observations
- )
mock_brain.return_value.vector_action_space_type = vector_action_space_type
mock_brain.return_value.vector_observation_space_size = (
vector_observation_space_size
@@ -74,8 +70,6 @@ def create_mock_braininfo(
mock_braininfo.return_value.memories = np.ones((num_agents, 8))
mock_braininfo.return_value.rewards = num_agents * [1.0]
mock_braininfo.return_value.local_done = num_agents * [False]
- mock_braininfo.return_value.text_observations = num_agents * [""]
- mock_braininfo.return_value.previous_text_actions = num_agents * [""]
mock_braininfo.return_value.max_reached = num_agents * [100]
mock_braininfo.return_value.action_masks = num_agents * [num_vector_acts * [1.0]]
mock_braininfo.return_value.agents = range(0, num_agents)
@@ -135,8 +129,11 @@ def create_buffer(brain_infos, brain_params, sequence_length, memory_size=8):
buffer[0]["next_vector_in"].append(
current_brain_info.vector_observations[0]
)
- buffer[0]["actions"].append(next_brain_info.previous_vector_actions[0])
- buffer[0]["prev_action"].append(current_brain_info.previous_vector_actions[0])
+ fake_action_size = len(brain_params.vector_action_space_size)
+ if brain_params.vector_action_space_type == "continuous":
+ fake_action_size = brain_params.vector_action_space_size[0]
+ buffer[0]["actions"].append(np.zeros(fake_action_size))
+ buffer[0]["prev_action"].append(np.zeros(fake_action_size))
buffer[0]["masks"].append(1.0)
buffer[0]["advantages"].append(1.0)
if brain_params.vector_action_space_type == "discrete":
@@ -240,9 +237,8 @@ def create_mock_banana_brain():
def make_brain_parameters(
discrete_action: bool = False,
visual_inputs: int = 0,
- stack: bool = True,
brain_name: str = "RealFakeBrain",
- vec_obs_size: int = 3,
+ vec_obs_size: int = 6,
) -> BrainParameters:
resolutions = [
CameraResolution(width=30, height=40, num_channels=3)
@@ -251,7 +247,6 @@ def make_brain_parameters(
return BrainParameters(
vector_observation_space_size=vec_obs_size,
- num_stacked_vector_observations=2 if stack else 1,
camera_resolutions=resolutions,
vector_action_space_size=[2],
vector_action_descriptions=["", ""],
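`create_buffer` above no longer reads `previous_vector_actions` off BrainInfo; it fabricates zero actions whose width depends on the action space type. A small standalone restatement of that sizing rule, assuming the same BrainParameters fields:

```python
# Standalone restatement of the fake-action sizing used above: discrete buffers
# hold one entry per action branch, continuous buffers hold the full action vector.
import numpy as np


def fake_action(vector_action_space_type, vector_action_space_size):
    size = len(vector_action_space_size)
    if vector_action_space_type == "continuous":
        size = vector_action_space_size[0]
    return np.zeros(size)


assert fake_action("discrete", [3, 2]).shape == (2,)
assert fake_action("continuous", [4]).shape == (4,)
```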
diff --git a/ml-agents/mlagents/trainers/tests/test.demo b/ml-agents/mlagents/trainers/tests/test.demo
index 3148108ca0..e5c689aea2 100644
Binary files a/ml-agents/mlagents/trainers/tests/test.demo and b/ml-agents/mlagents/trainers/tests/test.demo differ
diff --git a/ml-agents/mlagents/trainers/tests/test_bc.py b/ml-agents/mlagents/trainers/tests/test_bc.py
index 12acf5c5b9..19c2a73187 100644
--- a/ml-agents/mlagents/trainers/tests/test_bc.py
+++ b/ml-agents/mlagents/trainers/tests/test_bc.py
@@ -3,7 +3,7 @@
import os
import numpy as np
-import tensorflow as tf
+from mlagents.tf_utils import tf
import yaml
from mlagents.trainers.bc.models import BehavioralCloningModel
@@ -25,7 +25,7 @@ def dummy_config():
use_recurrent: false
sequence_length: 32
memory_size: 32
- batches_per_epoch: 1
+ batches_per_epoch: 100 # Force code to use all possible batches
batch_size: 32
summary_freq: 2000
max_steps: 4000
@@ -33,7 +33,7 @@ def dummy_config():
)
-def create_bc_trainer(dummy_config, is_discrete=False):
+def create_bc_trainer(dummy_config, is_discrete=False, use_recurrent=False):
mock_env = mock.Mock()
if is_discrete:
mock_brain = mb.create_mock_pushblock_brain()
@@ -54,6 +54,7 @@ def create_bc_trainer(dummy_config, is_discrete=False):
trainer_parameters["demo_path"] = (
os.path.dirname(os.path.abspath(__file__)) + "/test.demo"
)
+ trainer_parameters["use_recurrent"] = use_recurrent
trainer = BCTrainer(
mock_brain, trainer_parameters, training=True, load=False, seed=0, run_id=0
)
@@ -61,8 +62,9 @@ def create_bc_trainer(dummy_config, is_discrete=False):
return trainer, env
-def test_bc_trainer_step(dummy_config):
- trainer, env = create_bc_trainer(dummy_config)
+@pytest.mark.parametrize("use_recurrent", [True, False])
+def test_bc_trainer_step(dummy_config, use_recurrent):
+ trainer, env = create_bc_trainer(dummy_config, use_recurrent=use_recurrent)
# Test get_step
assert trainer.get_step == 0
# Test update policy
@@ -77,17 +79,20 @@ def test_bc_trainer_add_proc_experiences(dummy_config):
trainer, env = create_bc_trainer(dummy_config)
# Test add_experiences
returned_braininfo = env.step()
+ brain_name = "Ball3DBrain"
trainer.add_experiences(
- returned_braininfo, returned_braininfo, {}
+ returned_braininfo[brain_name], returned_braininfo[brain_name], {}
) # Take action outputs is not used
- for agent_id in returned_braininfo["Ball3DBrain"].agents:
+ for agent_id in returned_braininfo[brain_name].agents:
assert trainer.evaluation_buffer[agent_id].last_brain_info is not None
assert trainer.episode_steps[agent_id] > 0
assert trainer.cumulative_rewards[agent_id] > 0
# Test process_experiences by setting done
- returned_braininfo["Ball3DBrain"].local_done = 12 * [True]
- trainer.process_experiences(returned_braininfo, returned_braininfo)
- for agent_id in returned_braininfo["Ball3DBrain"].agents:
+ returned_braininfo[brain_name].local_done = 12 * [True]
+ trainer.process_experiences(
+ returned_braininfo[brain_name], returned_braininfo[brain_name]
+ )
+ for agent_id in returned_braininfo[brain_name].agents:
assert trainer.episode_steps[agent_id] == 0
assert trainer.cumulative_rewards[agent_id] == 0
@@ -95,13 +100,16 @@ def test_bc_trainer_add_proc_experiences(dummy_config):
def test_bc_trainer_end_episode(dummy_config):
trainer, env = create_bc_trainer(dummy_config)
returned_braininfo = env.step()
+ brain_name = "Ball3DBrain"
trainer.add_experiences(
- returned_braininfo, returned_braininfo, {}
+ returned_braininfo[brain_name], returned_braininfo[brain_name], {}
) # Take action outputs is not used
- trainer.process_experiences(returned_braininfo, returned_braininfo)
+ trainer.process_experiences(
+ returned_braininfo[brain_name], returned_braininfo[brain_name]
+ )
# Should set everything to 0
trainer.end_episode()
- for agent_id in returned_braininfo["Ball3DBrain"].agents:
+ for agent_id in returned_braininfo[brain_name].agents:
assert trainer.episode_steps[agent_id] == 0
assert trainer.cumulative_rewards[agent_id] == 0
diff --git a/ml-agents/mlagents/trainers/tests/test_bcmodule.py b/ml-agents/mlagents/trainers/tests/test_bcmodule.py
index 24f7b7efb7..3a26fd9f56 100644
--- a/ml-agents/mlagents/trainers/tests/test_bcmodule.py
+++ b/ml-agents/mlagents/trainers/tests/test_bcmodule.py
@@ -138,6 +138,26 @@ def test_bcmodule_update(mock_env, trainer_config):
env.close()
+# Test with constant pretraining learning rate
+@pytest.mark.parametrize(
+ "trainer_config", [ppo_dummy_config(), sac_dummy_config()], ids=["ppo", "sac"]
+)
+@mock.patch("mlagents.envs.environment.UnityEnvironment")
+def test_bcmodule_constant_lr_update(mock_env, trainer_config):
+ mock_brain = mb.create_mock_3dball_brain()
+ trainer_config["pretraining"]["steps"] = 0
+ env, policy = create_policy_with_bc_mock(
+ mock_env, mock_brain, trainer_config, False, "test.demo"
+ )
+ stats = policy.bc_module.update()
+ for _, item in stats.items():
+ assert isinstance(item, np.float32)
+ old_learning_rate = policy.bc_module.current_lr
+
+ stats = policy.bc_module.update()
+ assert old_learning_rate == policy.bc_module.current_lr
+
+
# Test with RNN
@pytest.mark.parametrize(
"trainer_config", [ppo_dummy_config(), sac_dummy_config()], ids=["ppo", "sac"]
diff --git a/ml-agents/mlagents/trainers/tests/test_demo_dir/test.demo b/ml-agents/mlagents/trainers/tests/test_demo_dir/test.demo
deleted file mode 100644
index 3148108ca0..0000000000
Binary files a/ml-agents/mlagents/trainers/tests/test_demo_dir/test.demo and /dev/null differ
diff --git a/ml-agents/mlagents/trainers/tests/test_demo_dir/test2.demo b/ml-agents/mlagents/trainers/tests/test_demo_dir/test2.demo
deleted file mode 100644
index 3148108ca0..0000000000
Binary files a/ml-agents/mlagents/trainers/tests/test_demo_dir/test2.demo and /dev/null differ
diff --git a/ml-agents/mlagents/trainers/tests/test_demo_dir/test3.demo b/ml-agents/mlagents/trainers/tests/test_demo_dir/test3.demo
deleted file mode 100644
index 3148108ca0..0000000000
Binary files a/ml-agents/mlagents/trainers/tests/test_demo_dir/test3.demo and /dev/null differ
diff --git a/ml-agents/mlagents/trainers/tests/test_demo_dir/test4.demo b/ml-agents/mlagents/trainers/tests/test_demo_dir/test4.demo
new file mode 100644
index 0000000000..2bb34b31e2
Binary files /dev/null and b/ml-agents/mlagents/trainers/tests/test_demo_dir/test4.demo differ
diff --git a/ml-agents/mlagents/trainers/tests/test_demo_dir/test5.demo b/ml-agents/mlagents/trainers/tests/test_demo_dir/test5.demo
new file mode 100644
index 0000000000..2bb34b31e2
Binary files /dev/null and b/ml-agents/mlagents/trainers/tests/test_demo_dir/test5.demo differ
diff --git a/ml-agents/mlagents/trainers/tests/test_demo_dir/test6.demo b/ml-agents/mlagents/trainers/tests/test_demo_dir/test6.demo
new file mode 100644
index 0000000000..2bb34b31e2
Binary files /dev/null and b/ml-agents/mlagents/trainers/tests/test_demo_dir/test6.demo differ
diff --git a/ml-agents/mlagents/trainers/tests/test_demo_loader.py b/ml-agents/mlagents/trainers/tests/test_demo_loader.py
index aa9fb6f753..bdc7cfb1a5 100644
--- a/ml-agents/mlagents/trainers/tests/test_demo_loader.py
+++ b/ml-agents/mlagents/trainers/tests/test_demo_loader.py
@@ -1,29 +1,29 @@
import os
-from mlagents.trainers.demo_loader import load_demonstration, make_demo_buffer
+from mlagents.trainers.demo_loader import load_demonstration, demo_to_buffer
def test_load_demo():
path_prefix = os.path.dirname(os.path.abspath(__file__))
- brain_parameters, brain_infos, total_expected = load_demonstration(
+ brain_parameters, pair_infos, total_expected = load_demonstration(
path_prefix + "/test.demo"
)
assert brain_parameters.brain_name == "Ball3DBrain"
assert brain_parameters.vector_observation_space_size == 8
- assert len(brain_infos) == total_expected
+ assert len(pair_infos) == total_expected
- demo_buffer = make_demo_buffer(brain_infos, brain_parameters, 1)
+ _, demo_buffer = demo_to_buffer(path_prefix + "/test.demo", 1)
assert len(demo_buffer.update_buffer["actions"]) == total_expected - 1
def test_load_demo_dir():
path_prefix = os.path.dirname(os.path.abspath(__file__))
- brain_parameters, brain_infos, total_expected = load_demonstration(
+ brain_parameters, pair_infos, total_expected = load_demonstration(
path_prefix + "/test_demo_dir"
)
- assert brain_parameters.brain_name == "Ball3DBrain"
+ assert brain_parameters.brain_name == "3DBall"
assert brain_parameters.vector_observation_space_size == 8
- assert len(brain_infos) == total_expected
+ assert len(pair_infos) == total_expected
- demo_buffer = make_demo_buffer(brain_infos, brain_parameters, 1)
+ _, demo_buffer = demo_to_buffer(path_prefix + "/test_demo_dir", 1)
assert len(demo_buffer.update_buffer["actions"]) == total_expected - 1
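For reference, a minimal sketch of the `demo_to_buffer` entry point these tests now exercise; the demo path is a placeholder, and the first element of the returned pair is ignored here, as in the tests above:

```python
# Sketch only: the demo path is a placeholder, and the second argument is the
# sequence length, as in the calls above.
from mlagents.trainers.demo_loader import demo_to_buffer

_, demo_buffer = demo_to_buffer("path/to/test.demo", 1)
print(len(demo_buffer.update_buffer["actions"]))
```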
diff --git a/ml-agents/mlagents/trainers/tests/test_multigpu.py b/ml-agents/mlagents/trainers/tests/test_multigpu.py
index a7723d7f54..d74dfcd3bf 100644
--- a/ml-agents/mlagents/trainers/tests/test_multigpu.py
+++ b/ml-agents/mlagents/trainers/tests/test_multigpu.py
@@ -1,7 +1,7 @@
import unittest.mock as mock
import pytest
-import tensorflow as tf
+from mlagents.tf_utils import tf
import yaml
from mlagents.trainers.ppo.multi_gpu_policy import MultiGpuPPOPolicy
diff --git a/ml-agents/mlagents/trainers/tests/test_policy.py b/ml-agents/mlagents/trainers/tests/test_policy.py
index 34c09ed297..b5a42f9d31 100644
--- a/ml-agents/mlagents/trainers/tests/test_policy.py
+++ b/ml-agents/mlagents/trainers/tests/test_policy.py
@@ -20,16 +20,19 @@ def test_take_action_returns_empty_with_no_agents():
policy = TFPolicy(test_seed, basic_mock_brain(), basic_params())
no_agent_brain_info = BrainInfo([], [], [], agents=[])
result = policy.get_action(no_agent_brain_info)
- assert result == ActionInfo([], [], [], None, None)
+ assert result == ActionInfo([], [], None)
def test_take_action_returns_nones_on_missing_values():
test_seed = 3
policy = TFPolicy(test_seed, basic_mock_brain(), basic_params())
policy.evaluate = MagicMock(return_value={})
- brain_info_with_agents = BrainInfo([], [], [], agents=["an-agent-id"])
+ policy.save_memories = MagicMock()
+ brain_info_with_agents = BrainInfo(
+ [], [], [], agents=["an-agent-id"], local_done=[False]
+ )
result = policy.get_action(brain_info_with_agents)
- assert result == ActionInfo(None, None, None, None, {})
+ assert result == ActionInfo(None, None, {})
def test_take_action_returns_action_info_when_available():
@@ -37,17 +40,15 @@ def test_take_action_returns_action_info_when_available():
policy = TFPolicy(test_seed, basic_mock_brain(), basic_params())
policy_eval_out = {
"action": np.array([1.0]),
- "memory_out": np.array([2.5]),
+ "memory_out": np.array([[2.5]]),
"value": np.array([1.1]),
}
policy.evaluate = MagicMock(return_value=policy_eval_out)
- brain_info_with_agents = BrainInfo([], [], [], agents=["an-agent-id"])
+ brain_info_with_agents = BrainInfo(
+ [], [], [], agents=["an-agent-id"], local_done=[False]
+ )
result = policy.get_action(brain_info_with_agents)
expected = ActionInfo(
- policy_eval_out["action"],
- policy_eval_out["memory_out"],
- None,
- policy_eval_out["value"],
- policy_eval_out,
+ policy_eval_out["action"], policy_eval_out["value"], policy_eval_out
)
assert result == expected
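The assertions above reflect the slimmed-down ActionInfo: it now carries only the action, the value estimate, and the raw policy outputs. A tiny construction sketch with placeholder values:

```python
# Placeholder values; the three remaining fields match the ActionInfo
# construction in tf_policy.py later in this diff.
from mlagents.envs.action_info import ActionInfo

info = ActionInfo(action=[1.0], value=[1.1], outputs={"action": [1.0]})
```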
diff --git a/ml-agents/mlagents/trainers/tests/test_ppo.py b/ml-agents/mlagents/trainers/tests/test_ppo.py
index 441ed6a3a8..78d937ac92 100644
--- a/ml-agents/mlagents/trainers/tests/test_ppo.py
+++ b/ml-agents/mlagents/trainers/tests/test_ppo.py
@@ -2,7 +2,8 @@
import pytest
import numpy as np
-import tensorflow as tf
+from mlagents.tf_utils import tf
+
import yaml
from mlagents.trainers.ppo.models import PPOModel
@@ -298,7 +299,14 @@ def test_rl_functions():
def test_trainer_increment_step(dummy_config):
trainer_params = dummy_config
- brain_params = BrainParameters("test_brain", 1, 1, [], [2], [], 0)
+ brain_params = BrainParameters(
+ brain_name="test_brain",
+ vector_observation_space_size=1,
+ camera_resolutions=[],
+ vector_action_space_size=[2],
+ vector_action_descriptions=[],
+ vector_action_space_type=0,
+ )
trainer = PPOTrainer(brain_params, 0, trainer_params, True, False, 0, "0", False)
policy_mock = mock.Mock()
@@ -355,7 +363,14 @@ def test_trainer_update_policy(mock_env, dummy_config, use_discrete):
def test_add_rewards_output(dummy_config):
- brain_params = BrainParameters("test_brain", 1, 1, [], [2], [], 0)
+ brain_params = BrainParameters(
+ brain_name="test_brain",
+ vector_observation_space_size=1,
+ camera_resolutions=[],
+ vector_action_space_size=[2],
+ vector_action_descriptions=[],
+ vector_action_space_type=0,
+ )
dummy_config["summary_path"] = "./summaries/test_trainer_summary"
dummy_config["model_path"] = "./models/test_trainer_models/TestModel"
trainer = PPOTrainer(brain_params, 0, dummy_config, True, False, 0, "0", False)
diff --git a/ml-agents/mlagents/trainers/tests/test_reward_signals.py b/ml-agents/mlagents/trainers/tests/test_reward_signals.py
index b1d8e9f5e4..b0b3d3c4c3 100644
--- a/ml-agents/mlagents/trainers/tests/test_reward_signals.py
+++ b/ml-agents/mlagents/trainers/tests/test_reward_signals.py
@@ -2,6 +2,7 @@
import pytest
import yaml
import os
+import numpy as np
import mlagents.trainers.tests.mock_brain as mb
from mlagents.trainers.ppo.policy import PPOPolicy
from mlagents.trainers.sac.policy import SACPolicy
@@ -125,8 +126,9 @@ def reward_signal_eval(env, policy, reward_signal_name):
brain_info = brain_infos[env.external_brain_names[0]]
next_brain_info = env.step()[env.external_brain_names[0]]
# Test evaluate
+ action = np.ones((len(brain_info.agents), policy.num_branches))
rsig_result = policy.reward_signals[reward_signal_name].evaluate(
- brain_info, next_brain_info
+ brain_info, action, next_brain_info
)
assert rsig_result.scaled_reward.shape == (NUM_AGENTS,)
assert rsig_result.unscaled_reward.shape == (NUM_AGENTS,)
diff --git a/ml-agents/mlagents/trainers/tests/test_rl_trainer.py b/ml-agents/mlagents/trainers/tests/test_rl_trainer.py
index 36da0ae769..f23fb8bc46 100644
--- a/ml-agents/mlagents/trainers/tests/test_rl_trainer.py
+++ b/ml-agents/mlagents/trainers/tests/test_rl_trainer.py
@@ -43,6 +43,8 @@ def create_mock_all_brain_info(brain_info):
def create_mock_policy():
mock_policy = mock.Mock()
mock_policy.reward_signals = {}
+ mock_policy.retrieve_memories.return_value = np.zeros((1, 1))
+ mock_policy.retrieve_previous_action.return_value = np.zeros((1, 1))
return mock_policy
@@ -64,11 +66,7 @@ def test_rl_trainer(add_policy_outputs, add_rewards_outputs, num_vis_obs):
num_vector_acts=2,
num_vis_observations=num_vis_obs,
)
- trainer.add_experiences(
- create_mock_all_brain_info(mock_braininfo),
- create_mock_all_brain_info(mock_braininfo),
- fake_action_outputs,
- )
+ trainer.add_experiences(mock_braininfo, mock_braininfo, fake_action_outputs)
# Remove one of the agents
next_mock_braininfo = mb.create_mock_braininfo(
@@ -83,7 +81,6 @@ def test_rl_trainer(add_policy_outputs, add_rewards_outputs, num_vis_obs):
assert len(brain_info.agents) == 1
assert len(brain_info.visual_observations) == num_vis_obs
assert len(brain_info.vector_observations) == 1
- assert len(brain_info.previous_vector_actions) == 1
# Test end episode
trainer.end_episode()
diff --git a/ml-agents/mlagents/trainers/tests/test_sac.py b/ml-agents/mlagents/trainers/tests/test_sac.py
index b3ad9a8934..c3e426c283 100644
--- a/ml-agents/mlagents/trainers/tests/test_sac.py
+++ b/ml-agents/mlagents/trainers/tests/test_sac.py
@@ -3,7 +3,8 @@
import yaml
import numpy as np
-import tensorflow as tf
+from mlagents.tf_utils import tf
+
from mlagents.trainers.sac.models import SACModel
from mlagents.trainers.sac.policy import SACPolicy
diff --git a/ml-agents/mlagents/trainers/tests/test_simple_rl.py b/ml-agents/mlagents/trainers/tests/test_simple_rl.py
index 339fc7287b..c0fe1c4679 100644
--- a/ml-agents/mlagents/trainers/tests/test_simple_rl.py
+++ b/ml-agents/mlagents/trainers/tests/test_simple_rl.py
@@ -11,6 +11,10 @@
from mlagents.envs.base_unity_environment import BaseUnityEnvironment
from mlagents.envs.brain import BrainInfo, AllBrainInfo, BrainParameters
from mlagents.envs.communicator_objects.agent_info_pb2 import AgentInfoProto
+from mlagents.envs.communicator_objects.observation_pb2 import (
+ ObservationProto,
+ NONE as COMPRESSION_TYPE_NONE,
+)
from mlagents.envs.simple_env_manager import SimpleEnvManager
from mlagents.envs.sampler_class import SamplerManager
@@ -41,7 +45,6 @@ def __init__(self, use_discrete):
brain_params = BrainParameters(
brain_name=BRAIN_NAME,
vector_observation_space_size=OBS_SIZE,
- num_stacked_vector_observations=1,
camera_resolutions=[],
vector_action_space_size=[2] if use_discrete else [1],
vector_action_descriptions=["moveDirection"],
@@ -59,7 +62,6 @@ def step(
self,
vector_action: Dict[str, Any] = None,
memory: Dict[str, Any] = None,
- text_action: Dict[str, Any] = None,
value: Dict[str, Any] = None,
) -> AllBrainInfo:
assert vector_action is not None
@@ -79,8 +81,14 @@ def step(
else:
reward = -TIME_PENALTY
+ vector_obs = [self.goal] * OBS_SIZE
+ vector_obs_proto = ObservationProto(
+ float_data=ObservationProto.FloatData(data=vector_obs),
+ shape=[len(vector_obs)],
+ compression_type=COMPRESSION_TYPE_NONE,
+ )
agent_info = AgentInfoProto(
- stacked_vector_observation=[self.goal] * OBS_SIZE, reward=reward, done=done
+ reward=reward, done=bool(done), observations=[vector_obs_proto]
)
if done:
@@ -105,11 +113,16 @@ def reset(
) -> AllBrainInfo: # type: ignore
self._reset_agent()
+ vector_obs = [self.goal] * OBS_SIZE
+ vector_obs_proto = ObservationProto(
+ float_data=ObservationProto.FloatData(data=vector_obs),
+ shape=[len(vector_obs)],
+ compression_type=COMPRESSION_TYPE_NONE,
+ )
agent_info = AgentInfoProto(
- stacked_vector_observation=[self.goal] * OBS_SIZE,
- done=False,
- max_step_reached=False,
+ done=False, max_step_reached=False, observations=[vector_obs_proto]
)
+
return {
BRAIN_NAME: BrainInfo.from_agent_proto(
0, [agent_info], self._brains[BRAIN_NAME]
@@ -156,13 +169,13 @@ def close(self):
SAC_CONFIG = """
default:
trainer: sac
- batch_size: 32
- buffer_size: 10240
- buffer_init_steps: 1000
- hidden_units: 64
+ batch_size: 8
+ buffer_size: 500
+ buffer_init_steps: 100
+ hidden_units: 16
init_entcoef: 0.01
learning_rate: 5.0e-3
- max_steps: 2000
+ max_steps: 1000
memory_size: 256
normalize: false
num_update: 1
diff --git a/ml-agents/mlagents/trainers/tests/test_trainer_controller.py b/ml-agents/mlagents/trainers/tests/test_trainer_controller.py
index 3a918afc52..427ade988b 100644
--- a/ml-agents/mlagents/trainers/tests/test_trainer_controller.py
+++ b/ml-agents/mlagents/trainers/tests/test_trainer_controller.py
@@ -1,5 +1,7 @@
from unittest.mock import MagicMock, Mock, patch
+from mlagents.tf_utils import tf
+
import yaml
import pytest
@@ -56,7 +58,7 @@ def basic_trainer_controller():
@patch("numpy.random.seed")
-@patch("tensorflow.set_random_seed")
+@patch.object(tf, "set_random_seed")
def test_initialization_seed(numpy_random_seed, tensorflow_set_seed):
seed = 27
TrainerController(
@@ -102,7 +104,7 @@ def take_step_sideeffect(env):
return tc, trainer_mock
-@patch("tensorflow.reset_default_graph")
+@patch.object(tf, "reset_default_graph")
def test_start_learning_trains_forever_if_no_train_model(tf_reset_graph):
tc, trainer_mock = trainer_controller_with_start_learning_mocks()
tc.train_model = False
@@ -123,7 +125,7 @@ def test_start_learning_trains_forever_if_no_train_model(tf_reset_graph):
env_mock.close.assert_called_once()
-@patch("tensorflow.reset_default_graph")
+@patch.object(tf, "reset_default_graph")
def test_start_learning_trains_until_max_steps_then_saves(tf_reset_graph):
tc, trainer_mock = trainer_controller_with_start_learning_mocks()
tf_reset_graph.return_value = None
@@ -158,10 +160,12 @@ def trainer_controller_with_take_step_mocks():
def test_take_step_adds_experiences_to_trainer_and_trains():
tc, trainer_mock = trainer_controller_with_take_step_mocks()
- action_info_dict = {"testbrain": MagicMock()}
+ brain_name = "testbrain"
+ action_info_dict = {brain_name: MagicMock()}
- old_step_info = EnvironmentStep(Mock(), Mock(), action_info_dict)
- new_step_info = EnvironmentStep(Mock(), Mock(), action_info_dict)
+ brain_info_dict = {brain_name: Mock()}
+ old_step_info = EnvironmentStep(brain_info_dict, brain_info_dict, action_info_dict)
+ new_step_info = EnvironmentStep(brain_info_dict, brain_info_dict, action_info_dict)
trainer_mock.is_ready_update = MagicMock(return_value=True)
env_mock = MagicMock()
@@ -172,12 +176,13 @@ def test_take_step_adds_experiences_to_trainer_and_trains():
env_mock.reset.assert_not_called()
env_mock.step.assert_called_once()
trainer_mock.add_experiences.assert_called_once_with(
- new_step_info.previous_all_brain_info,
- new_step_info.current_all_brain_info,
- new_step_info.brain_name_to_action_info["testbrain"].outputs,
+ new_step_info.previous_all_brain_info[brain_name],
+ new_step_info.current_all_brain_info[brain_name],
+ new_step_info.brain_name_to_action_info[brain_name].outputs,
)
trainer_mock.process_experiences.assert_called_once_with(
- new_step_info.previous_all_brain_info, new_step_info.current_all_brain_info
+ new_step_info.previous_all_brain_info[brain_name],
+ new_step_info.current_all_brain_info[brain_name],
)
trainer_mock.update_policy.assert_called_once()
trainer_mock.increment_step.assert_called_once()
@@ -187,10 +192,13 @@ def test_take_step_if_not_training():
tc, trainer_mock = trainer_controller_with_take_step_mocks()
tc.train_model = False
- action_info_dict = {"testbrain": MagicMock()}
+ brain_name = "testbrain"
+ action_info_dict = {brain_name: MagicMock()}
+
+ brain_info_dict = {brain_name: Mock()}
+ old_step_info = EnvironmentStep(brain_info_dict, brain_info_dict, action_info_dict)
+ new_step_info = EnvironmentStep(brain_info_dict, brain_info_dict, action_info_dict)
- old_step_info = EnvironmentStep(Mock(), Mock(), action_info_dict)
- new_step_info = EnvironmentStep(Mock(), Mock(), action_info_dict)
trainer_mock.is_ready_update = MagicMock(return_value=False)
env_mock = MagicMock()
@@ -201,11 +209,12 @@ def test_take_step_if_not_training():
env_mock.reset.assert_not_called()
env_mock.step.assert_called_once()
trainer_mock.add_experiences.assert_called_once_with(
- new_step_info.previous_all_brain_info,
- new_step_info.current_all_brain_info,
- new_step_info.brain_name_to_action_info["testbrain"].outputs,
+ new_step_info.previous_all_brain_info[brain_name],
+ new_step_info.current_all_brain_info[brain_name],
+ new_step_info.brain_name_to_action_info[brain_name].outputs,
)
trainer_mock.process_experiences.assert_called_once_with(
- new_step_info.previous_all_brain_info, new_step_info.current_all_brain_info
+ new_step_info.previous_all_brain_info[brain_name],
+ new_step_info.current_all_brain_info[brain_name],
)
trainer_mock.clear_update_buffer.assert_called_once()
diff --git a/ml-agents/mlagents/trainers/tests/testdcvis.demo b/ml-agents/mlagents/trainers/tests/testdcvis.demo
index 3933a3920f..bb9c48dfca 100644
Binary files a/ml-agents/mlagents/trainers/tests/testdcvis.demo and b/ml-agents/mlagents/trainers/tests/testdcvis.demo differ
diff --git a/ml-agents/mlagents/trainers/tf_policy.py b/ml-agents/mlagents/trainers/tf_policy.py
index df33746d21..3fa96c6f0f 100644
--- a/ml-agents/mlagents/trainers/tf_policy.py
+++ b/ml-agents/mlagents/trainers/tf_policy.py
@@ -1,8 +1,8 @@
import logging
-from typing import Any, Dict
+from typing import Any, Dict, List, Optional
import numpy as np
-import tensorflow as tf
+from mlagents.tf_utils import tf
from mlagents.envs.exception import UnityException
from mlagents.envs.policy import Policy
@@ -56,8 +56,13 @@ def __init__(self, seed, brain, trainer_parameters):
self.seed = seed
self.brain = brain
self.use_recurrent = trainer_parameters["use_recurrent"]
+ self.memory_dict: Dict[int, np.ndarray] = {}
+ self.num_branches = len(self.brain.vector_action_space_size)
+ self.previous_action_dict: Dict[int, np.array] = {}
self.normalize = trainer_parameters.get("normalize", False)
self.use_continuous_act = brain.vector_action_space_type == "continuous"
+ if self.use_continuous_act:
+ self.num_branches = self.brain.vector_action_space_size[0]
self.model_path = trainer_parameters["model_path"]
self.keep_checkpoints = trainer_parameters.get("keep_checkpoints", 5)
self.graph = tf.Graph()
@@ -121,15 +126,21 @@ def get_action(self, brain_info: BrainInfo) -> ActionInfo:
to be passed to add experiences
"""
if len(brain_info.agents) == 0:
- return ActionInfo([], [], [], None, None)
+ return ActionInfo([], [], None)
- run_out = self.evaluate(brain_info)
+ agents_done = [
+ agent
+ for agent, done in zip(brain_info.agents, brain_info.local_done)
+ if done
+ ]
+
+ self.remove_memories(agents_done)
+ self.remove_previous_action(agents_done)
+
+ run_out = self.evaluate(brain_info) # pylint: disable=assignment-from-no-return
+ self.save_memories(brain_info.agents, run_out.get("memory_out"))
return ActionInfo(
- action=run_out.get("action"),
- memory=run_out.get("memory_out"),
- text=None,
- value=run_out.get("value"),
- outputs=run_out,
+ action=run_out.get("action"), value=run_out.get("value"), outputs=run_out
)
def update(self, mini_batch, num_sequences):
@@ -167,7 +178,55 @@ def make_empty_memory(self, num_agents):
:param num_agents: Number of agents.
:return: Numpy array of zeros.
"""
- return np.zeros((num_agents, self.m_size))
+ return np.zeros((num_agents, self.m_size), dtype=np.float)
+
+ def save_memories(
+ self, agent_ids: List[int], memory_matrix: Optional[np.ndarray]
+ ) -> None:
+ if memory_matrix is None:
+ return
+ for index, agent_id in enumerate(agent_ids):
+ self.memory_dict[agent_id] = memory_matrix[index, :]
+
+ def retrieve_memories(self, agent_ids: List[int]) -> np.ndarray:
+ memory_matrix = np.zeros((len(agent_ids), self.m_size), dtype=np.float)
+ for index, agent_id in enumerate(agent_ids):
+ if agent_id in self.memory_dict:
+ memory_matrix[index, :] = self.memory_dict[agent_id]
+ return memory_matrix
+
+ def remove_memories(self, agent_ids):
+ for agent_id in agent_ids:
+ if agent_id in self.memory_dict:
+ self.memory_dict.pop(agent_id)
+
+ def make_empty_previous_action(self, num_agents):
+ """
+ Creates empty previous action for use with RNNs and discrete control
+ :param num_agents: Number of agents.
+ :return: Numpy array of zeros.
+ """
+ return np.zeros((num_agents, self.num_branches), dtype=np.int)
+
+ def save_previous_action(
+ self, agent_ids: List[int], action_matrix: Optional[np.ndarray]
+ ) -> None:
+ if action_matrix is None:
+ return
+ for index, agent_id in enumerate(agent_ids):
+ self.previous_action_dict[agent_id] = action_matrix[index, :]
+
+ def retrieve_previous_action(self, agent_ids: List[int]) -> np.ndarray:
+ action_matrix = np.zeros((len(agent_ids), self.num_branches), dtype=np.int)
+ for index, agent_id in enumerate(agent_ids):
+ if agent_id in self.previous_action_dict:
+ action_matrix[index, :] = self.previous_action_dict[agent_id]
+ return action_matrix
+
+ def remove_previous_action(self, agent_ids):
+ for agent_id in agent_ids:
+ if agent_id in self.previous_action_dict:
+ self.previous_action_dict.pop(agent_id)
def get_current_step(self):
"""
diff --git a/ml-agents/mlagents/trainers/trainer.py b/ml-agents/mlagents/trainers/trainer.py
index cc239c8668..744b06e4fc 100644
--- a/ml-agents/mlagents/trainers/trainer.py
+++ b/ml-agents/mlagents/trainers/trainer.py
@@ -2,7 +2,9 @@
import logging
from typing import Dict, List, Deque, Any
import os
-import tensorflow as tf
+
+from mlagents.tf_utils import tf
+
import numpy as np
from collections import deque, defaultdict
@@ -11,7 +13,7 @@
from mlagents.envs.timers import set_gauge
from mlagents.trainers.trainer_metrics import TrainerMetrics
from mlagents.trainers.tf_policy import TFPolicy
-from mlagents.envs.brain import BrainParameters, AllBrainInfo
+from mlagents.envs.brain import BrainParameters, BrainInfo
LOGGER = logging.getLogger("mlagents.trainers")
@@ -236,28 +238,26 @@ def write_tensorboard_text(self, key: str, input_dict: Dict[str, Any]) -> None:
def add_experiences(
self,
- curr_all_info: AllBrainInfo,
- next_all_info: AllBrainInfo,
+ curr_info: BrainInfo,
+ next_info: BrainInfo,
take_action_outputs: ActionInfoOutputs,
) -> None:
"""
Adds experiences to each agent's experience history.
- :param curr_all_info: Dictionary of all current brains and corresponding BrainInfo.
- :param next_all_info: Dictionary of all current brains and corresponding BrainInfo.
+ :param curr_info: current BrainInfo.
+ :param next_info: next BrainInfo.
:param take_action_outputs: The outputs of the Policy's get_action method.
"""
- raise UnityTrainerException(
- "The process_experiences method was not implemented."
- )
+ raise UnityTrainerException("The add_experiences method was not implemented.")
def process_experiences(
- self, current_info: AllBrainInfo, next_info: AllBrainInfo
+ self, current_info: BrainInfo, next_info: BrainInfo
) -> None:
"""
Checks agent histories for processing condition, and processes them as necessary.
Processing involves calculating value and advantage targets for model updating step.
- :param current_info: Dictionary of all current-step brains and corresponding BrainInfo.
- :param next_info: Dictionary of all next-step brains and corresponding BrainInfo.
+ :param current_info: current BrainInfo.
+ :param next_info: next BrainInfo.
"""
raise UnityTrainerException(
"The process_experiences method was not implemented."
diff --git a/ml-agents/mlagents/trainers/trainer_controller.py b/ml-agents/mlagents/trainers/trainer_controller.py
index 6007807d59..3edfac1ef2 100644
--- a/ml-agents/mlagents/trainers/trainer_controller.py
+++ b/ml-agents/mlagents/trainers/trainer_controller.py
@@ -8,7 +8,7 @@
from typing import Dict, List, Optional, Set
import numpy as np
-import tensorflow as tf
+from mlagents.tf_utils import tf
from time import time
from mlagents.envs.env_manager import EnvironmentStep
@@ -276,15 +276,15 @@ def advance(self, env: EnvManager) -> int:
for brain_name, trainer in self.trainers.items():
if brain_name in self.trainer_metrics:
self.trainer_metrics[brain_name].add_delta_step(delta_time_step)
- if brain_name in step_info.brain_name_to_action_info:
+ if step_info.has_actions_for_brain(brain_name):
trainer.add_experiences(
- step_info.previous_all_brain_info,
- step_info.current_all_brain_info,
+ step_info.previous_all_brain_info[brain_name],
+ step_info.current_all_brain_info[brain_name],
step_info.brain_name_to_action_info[brain_name].outputs,
)
trainer.process_experiences(
- step_info.previous_all_brain_info,
- step_info.current_all_brain_info,
+ step_info.previous_all_brain_info[brain_name],
+ step_info.current_all_brain_info[brain_name],
)
for brain_name, trainer in self.trainers.items():
if brain_name in self.trainer_metrics:
diff --git a/ml-agents/setup.py b/ml-agents/setup.py
index 5211612ee4..296e77bcc0 100644
--- a/ml-agents/setup.py
+++ b/ml-agents/setup.py
@@ -4,8 +4,9 @@
from setuptools import setup, find_namespace_packages
from setuptools.command.install import install
+import mlagents.trainers
-VERSION = "0.11.0"
+VERSION = mlagents.trainers.__version__
here = os.path.abspath(os.path.dirname(__file__))
@@ -64,7 +65,7 @@ def run(self):
"Pillow>=4.2.1",
"protobuf>=3.6",
"pyyaml",
- "tensorflow>=1.7,<2.0",
+ "tensorflow>=1.7,<2.1",
'pypiwin32==223;platform_system=="Windows"',
],
python_requires=">=3.6.1",
diff --git a/protobuf-definitions/README.md b/protobuf-definitions/README.md
index eb2ef57db7..2ee8320bb9 100644
--- a/protobuf-definitions/README.md
+++ b/protobuf-definitions/README.md
@@ -40,7 +40,7 @@ Navigate to your installation of nuget and run the following:
## Running
-Whenever you change the fields of a custom message, you must follow the steps below to create C# and Python files corresponding to the new message.
+Whenever you change the fields of a message, you must follow the steps below to create C# and Python files corresponding to the new message.
1. Open a terminal. **Note:** If you're using Anaconda, don't forget to activate the ml-agents environment first.
2. Un-comment line 7 in `make.sh` (for Windows, use `make_for_win.bat`), and set to correct Grpc.Tools sub-directory.
diff --git a/protobuf-definitions/proto/mlagents/envs/communicator_objects/agent_action.proto b/protobuf-definitions/proto/mlagents/envs/communicator_objects/agent_action.proto
index 05a00a7f78..cd0545c415 100644
--- a/protobuf-definitions/proto/mlagents/envs/communicator_objects/agent_action.proto
+++ b/protobuf-definitions/proto/mlagents/envs/communicator_objects/agent_action.proto
@@ -1,14 +1,12 @@
syntax = "proto3";
-import "mlagents/envs/communicator_objects/custom_action.proto";
-
option csharp_namespace = "MLAgents.CommunicatorObjects";
package communicator_objects;
message AgentActionProto {
repeated float vector_actions = 1;
- string text_actions = 2;
- repeated float memories = 3;
+ reserved 2; // deprecated string text_actions = 2;
+ reserved 3; // deprecated repeated float memories = 3;
float value = 4;
- CustomActionProto custom_action = 5;
+ reserved 5; // deprecated CustomActionProto custom_action = 5;
}
diff --git a/protobuf-definitions/proto/mlagents/envs/communicator_objects/agent_info.proto b/protobuf-definitions/proto/mlagents/envs/communicator_objects/agent_info.proto
index f48130eb63..022e5e48c5 100644
--- a/protobuf-definitions/proto/mlagents/envs/communicator_objects/agent_info.proto
+++ b/protobuf-definitions/proto/mlagents/envs/communicator_objects/agent_info.proto
@@ -1,23 +1,22 @@
syntax = "proto3";
-import "mlagents/envs/communicator_objects/compressed_observation.proto";
-import "mlagents/envs/communicator_objects/custom_observation.proto";
+import "mlagents/envs/communicator_objects/observation.proto";
option csharp_namespace = "MLAgents.CommunicatorObjects";
package communicator_objects;
message AgentInfoProto {
- repeated float stacked_vector_observation = 1;
+ reserved 1; // deprecated repeated float stacked_vector_observation = 1;
reserved 2; // deprecated repeated bytes visual_observations = 2;
- string text_observation = 3;
- repeated float stored_vector_actions = 4;
- string stored_text_actions = 5;
- repeated float memories = 6;
+ reserved 3; // deprecated string text_observation = 3;
+ reserved 4; // deprecated repeated float stored_vector_actions = 4;
+ reserved 5; // deprecated string stored_text_actions = 5;
+ reserved 6; // deprecated repeated float memories = 6;
float reward = 7;
bool done = 8;
bool max_step_reached = 9;
int32 id = 10;
repeated bool action_mask = 11;
- CustomObservationProto custom_observation = 12;
- repeated CompressedObservationProto compressed_observations = 13;
+ reserved 12; // deprecated CustomObservationProto custom_observation = 12;
+ repeated ObservationProto observations = 13;
}
diff --git a/protobuf-definitions/proto/mlagents/envs/communicator_objects/agent_info_action_pair.proto b/protobuf-definitions/proto/mlagents/envs/communicator_objects/agent_info_action_pair.proto
new file mode 100644
index 0000000000..64ee306b31
--- /dev/null
+++ b/protobuf-definitions/proto/mlagents/envs/communicator_objects/agent_info_action_pair.proto
@@ -0,0 +1,12 @@
+syntax = "proto3";
+
+import "mlagents/envs/communicator_objects/agent_info.proto";
+import "mlagents/envs/communicator_objects/agent_action.proto";
+
+option csharp_namespace = "MLAgents.CommunicatorObjects";
+package communicator_objects;
+
+message AgentInfoActionPairProto {
+ AgentInfoProto agent_info = 1;
+ AgentActionProto action_info = 2;
+}
diff --git a/protobuf-definitions/proto/mlagents/envs/communicator_objects/brain_parameters.proto b/protobuf-definitions/proto/mlagents/envs/communicator_objects/brain_parameters.proto
index 1b97bba002..002e2122ce 100644
--- a/protobuf-definitions/proto/mlagents/envs/communicator_objects/brain_parameters.proto
+++ b/protobuf-definitions/proto/mlagents/envs/communicator_objects/brain_parameters.proto
@@ -6,8 +6,8 @@ option csharp_namespace = "MLAgents.CommunicatorObjects";
package communicator_objects;
message BrainParametersProto {
- int32 vector_observation_size = 1;
- int32 num_stacked_vector_observations = 2;
+ reserved 1; // deprecated int32 vector_observation_size = 1;
+ reserved 2; // deprecated int32 num_stacked_vector_observations = 2;
repeated int32 vector_action_size = 3;
reserved 4; // deprecated repeated ResolutionProto camera_resolutions
repeated string vector_action_descriptions = 5;
diff --git a/protobuf-definitions/proto/mlagents/envs/communicator_objects/custom_action.proto b/protobuf-definitions/proto/mlagents/envs/communicator_objects/custom_action.proto
deleted file mode 100644
index 257adb46a0..0000000000
--- a/protobuf-definitions/proto/mlagents/envs/communicator_objects/custom_action.proto
+++ /dev/null
@@ -1,7 +0,0 @@
-syntax = "proto3";
-
-option csharp_namespace = "MLAgents.CommunicatorObjects";
-package communicator_objects;
-
-message CustomActionProto {
-}
diff --git a/protobuf-definitions/proto/mlagents/envs/communicator_objects/custom_observation.proto b/protobuf-definitions/proto/mlagents/envs/communicator_objects/custom_observation.proto
deleted file mode 100644
index 37203b66cb..0000000000
--- a/protobuf-definitions/proto/mlagents/envs/communicator_objects/custom_observation.proto
+++ /dev/null
@@ -1,7 +0,0 @@
-syntax = "proto3";
-
-option csharp_namespace = "MLAgents.CommunicatorObjects";
-package communicator_objects;
-
-message CustomObservationProto {
-}
diff --git a/protobuf-definitions/proto/mlagents/envs/communicator_objects/compressed_observation.proto b/protobuf-definitions/proto/mlagents/envs/communicator_objects/observation.proto
similarity index 55%
rename from protobuf-definitions/proto/mlagents/envs/communicator_objects/compressed_observation.proto
rename to protobuf-definitions/proto/mlagents/envs/communicator_objects/observation.proto
index 0a6798cd2d..07d2e8df1c 100644
--- a/protobuf-definitions/proto/mlagents/envs/communicator_objects/compressed_observation.proto
+++ b/protobuf-definitions/proto/mlagents/envs/communicator_objects/observation.proto
@@ -8,8 +8,15 @@ enum CompressionTypeProto {
PNG = 1;
}
-message CompressedObservationProto {
+message ObservationProto {
+ message FloatData {
+ repeated float data = 1;
+ }
+
repeated int32 shape = 1;
CompressionTypeProto compression_type = 2;
- bytes data = 3;
+ oneof observation_data {
+ bytes compressed_data = 3;
+ FloatData float_data = 4;
+ }
}
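A short Python-side sketch of filling the new ObservationProto, mirroring the test environment earlier in this diff; a compressed visual observation would instead set compressed_data and a non-NONE compression_type:

```python
# Mirrors the test environment earlier in this diff; values are placeholders.
from mlagents.envs.communicator_objects.observation_pb2 import (
    ObservationProto,
    NONE as COMPRESSION_TYPE_NONE,
)

values = [0.1, 0.2, 0.3]
obs = ObservationProto(
    float_data=ObservationProto.FloatData(data=values),
    shape=[len(values)],
    compression_type=COMPRESSION_TYPE_NONE,
)
```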
diff --git a/setup.cfg b/setup.cfg
index f9e51ef20c..770fbc0e14 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -16,3 +16,7 @@ ignore =
# Black tends to introduce things flake8 doesn't like, such as "line break before binary operator"
# or whitespace before ':'. Rather than fight with black, just ignore these for now.
W503, E203,
+ # flake8-tidy-imports adds this warning, which we don't really care about for now
+ I200
+
+banned-modules = tensorflow = use mlagents.tf_utils instead (it handles tf2 compat).
diff --git a/test_constraints_max_tf1_version.txt b/test_constraints_max_tf1_version.txt
new file mode 100644
index 0000000000..8495128e4a
--- /dev/null
+++ b/test_constraints_max_tf1_version.txt
@@ -0,0 +1,6 @@
+# pip constraints to use the *highest* versions allowed in ml-agents/setup.py
+# with the exception of tensorflow, which is constrained to <2
+# For projects with upper bounds, we should periodically update this list to the latest release version
+grpcio>=1.23.0
+numpy>=1.17.2
+tensorflow>=1.14.0,<2.0
diff --git a/test_constraints_max_version.txt b/test_constraints_max_tf2_version.txt
similarity index 89%
rename from test_constraints_max_version.txt
rename to test_constraints_max_tf2_version.txt
index 9d8f6832b9..57f6043e42 100644
--- a/test_constraints_max_version.txt
+++ b/test_constraints_max_tf2_version.txt
@@ -2,4 +2,4 @@
# For projects with upper bounds, we should periodically update this list to the latest release version
grpcio>=1.23.0
numpy>=1.17.2
-tensorflow>=1.14.0,<2.0
+tensorflow>=2.0.0,<2.1.0
diff --git a/utils/validate_versions.py b/utils/validate_versions.py
index 118779260f..77787d60e5 100755
--- a/utils/validate_versions.py
+++ b/utils/validate_versions.py
@@ -3,10 +3,15 @@
import os
import sys
from typing import Dict
+import argparse
-VERSION_LINE_START = "VERSION = "
+VERSION_LINE_START = "__version__ = "
-DIRECTORIES = ["ml-agents", "ml-agents-envs", "gym-unity"]
+DIRECTORIES = [
+ "ml-agents/mlagents/trainers",
+ "ml-agents-envs/mlagents/envs",
+ "gym-unity/gym_unity",
+]
def extract_version_string(filename):
@@ -20,7 +25,7 @@ def extract_version_string(filename):
def check_versions() -> bool:
version_by_dir: Dict[str, str] = {}
for directory in DIRECTORIES:
- path = os.path.join(directory, "setup.py")
+ path = os.path.join(directory, "__init__.py")
version = extract_version_string(path)
print(f"Found version {version} for {directory}")
version_by_dir[directory] = version
@@ -33,7 +38,25 @@ def check_versions() -> bool:
return True
+def set_version(new_version: str) -> None:
+ new_contents = f'{VERSION_LINE_START}"{new_version}"\n'
+ for directory in DIRECTORIES:
+ path = os.path.join(directory, "__init__.py")
+ print(f"Setting {path} to version {new_version}")
+ with open(path, "w") as f:
+ f.write(new_contents)
+
+
if __name__ == "__main__":
- ok = check_versions()
- return_code = 0 if ok else 1
- sys.exit(return_code)
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--new-version", default=None)
+ # unused, but allows pre-commit to pass filenames
+ parser.add_argument("files", nargs="*")
+ args = parser.parse_args()
+ if args.new_version:
+ print(f"Updating to verison {args.new_version}")
+ set_version(args.new_version)
+ else:
+ ok = check_versions()
+ return_code = 0 if ok else 1
+ sys.exit(return_code)
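For context, this is the version line each listed `__init__.py` is expected to contain so that `extract_version_string` can parse it; the version number shown is an assumption, and the bump invocation appears as a comment:

```python
# Assumed contents of e.g. ml-agents/mlagents/trainers/__init__.py (the version
# number is illustrative); VERSION_LINE_START above matches this prefix exactly.
__version__ = "0.11.0"

# Bumping every package in lockstep would then be:
#   python utils/validate_versions.py --new-version 0.12.0
```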