From 4fae9f3f56cc6573afe61b548e312707ce803c3b Mon Sep 17 00:00:00 2001
From: vincentpierre
Date: Thu, 21 Sep 2017 19:18:46 -0700
Subject: [PATCH 1/2] Added a clear error when a placeholder is missing or not
 found in the graph

---
 .../ML-Agents/Scripts/CoreBrainInternal.cs | 47 +++++++++++++++----
 1 file changed, 38 insertions(+), 9 deletions(-)

diff --git a/unity-environment/Assets/ML-Agents/Scripts/CoreBrainInternal.cs b/unity-environment/Assets/ML-Agents/Scripts/CoreBrainInternal.cs
index 6fd62ceb3f..a83123bcd0 100755
--- a/unity-environment/Assets/ML-Agents/Scripts/CoreBrainInternal.cs
+++ b/unity-environment/Assets/ML-Agents/Scripts/CoreBrainInternal.cs
@@ -51,7 +51,7 @@ public enum tensorType
     public string[] ObservationPlaceholderName;
     /// Modify only in inspector : Name of the action node
     public string ActionPlaceholderName = "action";
-#if ENABLE_TENSORFLOW
+    #if ENABLE_TENSORFLOW
     TFGraph graph;
     TFSession session;
     bool hasRecurrent;
@@ -62,7 +62,7 @@ public enum tensorType
     float[,] inputState;
     List observationMatrixList;
     float[,] inputOldMemories;
-#endif
+    #endif
 
     /// Reference to the brain that uses this CoreBrainInternal
     public Brain brain;
@@ -190,13 +190,22 @@ public void DecideAction()
 
         foreach (TensorFlowAgentPlaceholder placeholder in graphPlaceholders)
         {
-            if (placeholder.valueType == TensorFlowAgentPlaceholder.tensorType.FloatingPoint)
+            try
             {
-                runner.AddInput(graph[graphScope + placeholder.name][0], new float[] { Random.Range(placeholder.minValue, placeholder.maxValue) });
+                if (placeholder.valueType == TensorFlowAgentPlaceholder.tensorType.FloatingPoint)
+                {
+                    runner.AddInput(graph[graphScope + placeholder.name][0], new float[] { Random.Range(placeholder.minValue, placeholder.maxValue) });
+                }
+                else if (placeholder.valueType == TensorFlowAgentPlaceholder.tensorType.Integer)
+                {
+                    runner.AddInput(graph[graphScope + placeholder.name][0], new int[] { Random.Range((int)placeholder.minValue, (int)placeholder.maxValue + 1) });
+                }
             }
-            else if (placeholder.valueType == TensorFlowAgentPlaceholder.tensorType.Integer)
+            catch
             {
-                runner.AddInput(graph[graphScope + placeholder.name][0], new int[] { Random.Range((int)placeholder.minValue, (int)placeholder.maxValue + 1) });
+                throw new UnityAgentsException(string.Format(@"One of the TensorFlow placeholders could not be found.
+In brain {0}, there is no {1} placeholder named {2}.",
+                    brain.gameObject.name, placeholder.valueType.ToString(), graphScope + placeholder.name));
             }
         }
 
@@ -212,6 +221,26 @@ public void DecideAction()
             runner.AddInput(graph[graphScope + ObservationPlaceholderName[obs_number]][0], observationMatrixList[obs_number]);
         }
 
+        TFTensor[] runned;
+        try
+        {
+            runned = runner.Run();
+        }
+        catch (TFException e)
+        {
+            string errorMessage = e.Message;
+            try
+            {
+                errorMessage = string.Format(@"The TensorFlow graph needs an input for {0} of type {1}",
+                    e.Message.Split(new string[]{ "Node: " }, 0)[1].Split('=')[0],
+                    e.Message.Split(new string[]{ "dtype=" }, 0)[1].Split(',')[0]);
+            }
+            finally
+            {
+                throw new UnityAgentsException(errorMessage);
+            }
+
+        }
 
         // Create the recurrent tensor
         if (hasRecurrent)
@@ -220,7 +249,7 @@ public void DecideAction()
             runner.AddInput(graph[graphScope + RecurrentInPlaceholderName][0], inputOldMemories);
             runner.Fetch(graph[graphScope + RecurrentOutPlaceholderName][0]);
 
-            float[,] recurrent_tensor = runner.Run()[1].GetValue() as float[,];
+            float[,] recurrent_tensor = runned[1].GetValue() as float[,];
 
             int i = 0;
             foreach (int k in agentKeys)
@@ -241,7 +270,7 @@ public void DecideAction()
 
         if (brain.brainParameters.actionSpaceType == StateType.continuous)
         {
-            float[,] output = runner.Run()[0].GetValue() as float[,];
+            float[,] output = runned[0].GetValue() as float[,];
             int i = 0;
             foreach (int k in agentKeys)
             {
@@ -256,7 +285,7 @@ public void DecideAction()
         }
         else if (brain.brainParameters.actionSpaceType == StateType.discrete)
        {
-            long[,] output = runner.Run()[0].GetValue() as long[,];
+            long[,] output = runned[0].GetValue() as long[,];
             int i = 0;
             foreach (int k in agentKeys)
             {

From caa5ba97817b76fb13a402275aaac006b13d9218 Mon Sep 17 00:00:00 2001
From: vincentpierre
Date: Fri, 22 Sep 2017 10:17:02 -0700
Subject: [PATCH 2/2] Renamed runned to networkOutput

---
 .../Assets/ML-Agents/Scripts/CoreBrainInternal.cs | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/unity-environment/Assets/ML-Agents/Scripts/CoreBrainInternal.cs b/unity-environment/Assets/ML-Agents/Scripts/CoreBrainInternal.cs
index a83123bcd0..8362284f69 100755
--- a/unity-environment/Assets/ML-Agents/Scripts/CoreBrainInternal.cs
+++ b/unity-environment/Assets/ML-Agents/Scripts/CoreBrainInternal.cs
@@ -221,10 +221,10 @@ public void DecideAction()
             runner.AddInput(graph[graphScope + ObservationPlaceholderName[obs_number]][0], observationMatrixList[obs_number]);
         }
 
-        TFTensor[] runned;
+        TFTensor[] networkOutput;
         try
         {
-            runned = runner.Run();
+            networkOutput = runner.Run();
         }
         catch (TFException e)
         {
@@ -249,7 +249,7 @@ public void DecideAction()
             runner.AddInput(graph[graphScope + RecurrentInPlaceholderName][0], inputOldMemories);
             runner.Fetch(graph[graphScope + RecurrentOutPlaceholderName][0]);
 
-            float[,] recurrent_tensor = runned[1].GetValue() as float[,];
+            float[,] recurrent_tensor = networkOutput[1].GetValue() as float[,];
 
             int i = 0;
             foreach (int k in agentKeys)
@@ -270,7 +270,7 @@ public void DecideAction()
 
         if (brain.brainParameters.actionSpaceType == StateType.continuous)
         {
-            float[,] output = runned[0].GetValue() as float[,];
+            float[,] output = networkOutput[0].GetValue() as float[,];
             int i = 0;
             foreach (int k in agentKeys)
             {
@@ -285,7 +285,7 @@ public void DecideAction()
         }
         else if (brain.brainParameters.actionSpaceType == StateType.discrete)
         {
-            long[,] output = runned[0].GetValue() as long[,];
+            long[,] output = networkOutput[0].GetValue() as long[,];
             int i = 0;
             foreach (int k in agentKeys)
             {
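
The second hunk of the first commit turns an opaque TFException into a readable UnityAgentsException by slicing the node name and dtype out of the exception message with Split. Below is a minimal standalone sketch of that parsing, runnable with any C# compiler and without Unity or TensorFlowSharp. The sample message, the Trim calls, and the explicit IndexOutOfRangeException catch are illustrative additions and not part of the patch.

    using System;

    // Standalone sketch of the TFException message parsing introduced in the
    // first commit above. The sample message in Main is hypothetical, modeled
    // on the "uninitialized placeholder" errors TensorFlow raises.
    static class PlaceholderErrorDemo
    {
        // Pulls the node name and dtype out of a TensorFlow error message of
        // the shape "... [[Node: <name> = Placeholder[dtype=<type>, ...]]]".
        // Falls back to the raw message when the expected markers are absent,
        // which is the same fallback the patch gets by pre-assigning
        // errorMessage to e.Message.
        static string FriendlyMessage(string tfMessage)
        {
            try
            {
                string node = tfMessage
                    .Split(new string[] { "Node: " }, StringSplitOptions.None)[1]
                    .Split('=')[0].Trim();
                string dtype = tfMessage
                    .Split(new string[] { "dtype=" }, StringSplitOptions.None)[1]
                    .Split(',')[0].Trim();
                return string.Format(
                    "The TensorFlow graph needs an input for {0} of type {1}",
                    node, dtype);
            }
            catch (IndexOutOfRangeException)
            {
                // The message did not match the expected shape; surface it
                // unchanged rather than masking it.
                return tfMessage;
            }
        }

        static void Main()
        {
            string sample =
                "You must feed a value for placeholder tensor 'epsilon' with " +
                "dtype float [[Node: epsilon = Placeholder[dtype=DT_FLOAT, " +
                "shape=[], _device=\"/cpu:0\"]]]";

            // Prints: The TensorFlow graph needs an input for epsilon of type DT_FLOAT
            Console.WriteLine(FriendlyMessage(sample));
        }
    }

The patch itself reaches the same fallback differently: it initializes errorMessage to e.Message and throws from a finally block, so a failed parse still surfaces the original text. Catching the parse failure explicitly, as in the sketch, expresses that intent more directly and avoids throwing out of a finally block.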