diff --git a/unity-environment/Assets/ML-Agents/Scripts/CoreBrainInternal.cs b/unity-environment/Assets/ML-Agents/Scripts/CoreBrainInternal.cs
index 6fd62ceb3f..8362284f69 100755
--- a/unity-environment/Assets/ML-Agents/Scripts/CoreBrainInternal.cs
+++ b/unity-environment/Assets/ML-Agents/Scripts/CoreBrainInternal.cs
@@ -51,7 +51,7 @@ public enum tensorType
     public string[] ObservationPlaceholderName;
     /// Modify only in inspector : Name of the action node
     public string ActionPlaceholderName = "action";
-#if ENABLE_TENSORFLOW
+    #if ENABLE_TENSORFLOW
     TFGraph graph;
     TFSession session;
     bool hasRecurrent;
@@ -62,7 +62,7 @@ public enum tensorType
     float[,] inputState;
    List<float[,,,]> observationMatrixList;
     float[,] inputOldMemories;
-#endif
+    #endif
 
     /// Reference to the brain that uses this CoreBrainInternal
     public Brain brain;
@@ -190,13 +190,22 @@ public void DecideAction()
 
         foreach (TensorFlowAgentPlaceholder placeholder in graphPlaceholders)
         {
-            if (placeholder.valueType == TensorFlowAgentPlaceholder.tensorType.FloatingPoint)
+            try
             {
-                runner.AddInput(graph[graphScope + placeholder.name][0], new float[] { Random.Range(placeholder.minValue, placeholder.maxValue) });
+                if (placeholder.valueType == TensorFlowAgentPlaceholder.tensorType.FloatingPoint)
+                {
+                    runner.AddInput(graph[graphScope + placeholder.name][0], new float[] { Random.Range(placeholder.minValue, placeholder.maxValue) });
+                }
+                else if (placeholder.valueType == TensorFlowAgentPlaceholder.tensorType.Integer)
+                {
+                    runner.AddInput(graph[graphScope + placeholder.name][0], new int[] { Random.Range((int)placeholder.minValue, (int)placeholder.maxValue + 1) });
+                }
             }
-            else if (placeholder.valueType == TensorFlowAgentPlaceholder.tensorType.Integer)
+            catch
             {
-                runner.AddInput(graph[graphScope + placeholder.name][0], new int[] { Random.Range((int)placeholder.minValue, (int)placeholder.maxValue + 1) });
+                throw new UnityAgentsException(string.Format(@"One of the TensorFlow placeholders could not be found.
+In brain {0}, there is no {1} placeholder named {2}.",
+                    brain.gameObject.name, placeholder.valueType.ToString(), graphScope + placeholder.name));
             }
         }
 
@@ -212,6 +221,26 @@ public void DecideAction()
             runner.AddInput(graph[graphScope + ObservationPlaceholderName[obs_number]][0], observationMatrixList[obs_number]);
         }
 
+        TFTensor[] networkOutput;
+        try
+        {
+            networkOutput = runner.Run();
+        }
+        catch (TFException e)
+        {
+            string errorMessage = e.Message;
+            try
+            {
+                errorMessage = string.Format(@"The TensorFlow graph needs an input for {0} of type {1}",
+                    e.Message.Split(new string[]{ "Node: " }, 0)[1].Split('=')[0],
+                    e.Message.Split(new string[]{ "dtype=" }, 0)[1].Split(',')[0]);
+            }
+            finally
+            {
+                throw new UnityAgentsException(errorMessage);
+            }
+
+        }
         // Create the recurrent tensor
         if (hasRecurrent)
         {
@@ -220,7 +249,7 @@ public void DecideAction()
 
             runner.AddInput(graph[graphScope + RecurrentInPlaceholderName][0], inputOldMemories);
             runner.Fetch(graph[graphScope + RecurrentOutPlaceholderName][0]);
-            float[,] recurrent_tensor = runner.Run()[1].GetValue() as float[,];
+            float[,] recurrent_tensor = networkOutput[1].GetValue() as float[,];
 
             int i = 0;
             foreach (int k in agentKeys)
@@ -241,7 +270,7 @@ public void DecideAction()
 
         if (brain.brainParameters.actionSpaceType == StateType.continuous)
         {
-            float[,] output = runner.Run()[0].GetValue() as float[,];
+            float[,] output = networkOutput[0].GetValue() as float[,];
             int i = 0;
             foreach (int k in agentKeys)
             {
@@ -256,7 +285,7 @@ public void DecideAction()
         }
         else if (brain.brainParameters.actionSpaceType == StateType.discrete)
         {
-            long[,] output = runner.Run()[0].GetValue() as long[,];
+            long[,] output = networkOutput[0].GetValue() as long[,];
             int i = 0;
             foreach (int k in agentKeys)
             {
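
For context when reviewing: the substantive change above is that DecideAction() now calls runner.Run() exactly once, caches the resulting TFTensor[], and indexes into it (networkOutput[0] for the action node, networkOutput[1] for the recurrent-out node) instead of re-running the session for each output. Below is a minimal sketch of that pattern in isolation, not part of the commit. It assumes TensorFlowSharp's TFSession.Runner, TFTensor, and TFException, and that every Fetch() was registered before Run() is called; System.InvalidOperationException stands in for ML-Agents' UnityAgentsException so the sketch compiles on its own.

    using TensorFlow;

    public static class SingleRunPattern
    {
        // Sketch only: runner must already have all inputs added and its
        // fetches registered, with the action node as fetch index 0.
        public static float[,] RunOnceAndReadActions(TFSession.Runner runner)
        {
            TFTensor[] networkOutput;
            try
            {
                // Run the graph once; all registered fetches come back together,
                // in the order they were registered.
                networkOutput = runner.Run();
            }
            catch (TFException e)
            {
                // Re-surface TensorFlow's error as a project-level exception,
                // mirroring the UnityAgentsException wrapping in the diff.
                throw new System.InvalidOperationException(e.Message);
            }

            // Reuse the cached tensors rather than calling runner.Run() per output.
            return networkOutput[0].GetValue() as float[,];
        }
    }

The design point is that each Run() call executes the whole graph, so fetching actions and recurrent memories from one cached result both avoids redundant computation and keeps the two outputs consistent with the same forward pass.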