Skip to content

Commit

Permalink
Support for ONNX export (#3101)
Browse files Browse the repository at this point in the history
  • Loading branch information
Chris Elion authored Feb 18, 2020
1 parent 2c2f930 commit a5626aa
Show file tree
Hide file tree
Showing 8 changed files with 227 additions and 48 deletions.
2 changes: 2 additions & 0 deletions com.unity.ml-agents/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
## [Unreleased]
### Major Changes
- Agent.CollectObservations now takes a VectorSensor argument. It was also overloaded to optionally take an ActionMasker argument. (#3352, #3389)
- Beta support for ONNX export was added. If the `tf2onnx` python package is installed, models will be saved to `.onnx` as well as `.nn` format.
Note that Barracuda 0.6.0 or later is required to import the `.onnx` files properly.

### Minor Changes
- Monitor.cs was moved to Examples. (#3372)
Expand Down
11 changes: 9 additions & 2 deletions docs/Unity-Inference-Engine.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,17 @@ but we only tested for the following platforms :
* iOS
* Android

## Supported formats
There are currently two supported model formats:
* Barracuda (`.nn`) files use a proprietary format produced by the `tensorflow_to_barracuda.py` script.
* ONNX (`.onnx`) files use an [industry-standard open format](https://onnx.ai/about.html) produced by the [tf2onnx package](https://github.com/onnx/tensorflow-onnx).

Export to ONNX is currently considered beta. To enable it, make sure `tf2onnx>=1.5.5` is installed via pip.
Note that tf2onnx does not currently support TensorFlow 2.0.0 or later.

## Using the Unity Inference Engine

When using a model, drag the `.nn` file into the **Model** field
in the Inspector of the Agent.
When using a model, drag the model file into the **Model** field in the Inspector of the Agent.
Select the **Inference Device** : CPU or GPU you want to use for Inference.

**Note:** For most of the models generated with the ML-Agents toolkit, CPU will be faster than GPU.
Expand Down
2 changes: 1 addition & 1 deletion ml-agents-envs/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def run(self):
install_requires=[
"cloudpickle",
"grpcio>=1.11.0",
"numpy>=1.13.3,<2.0",
"numpy>=1.14.1,<2.0",
"Pillow>=4.2.1",
"protobuf>=3.6",
],
Expand Down
205 changes: 205 additions & 0 deletions ml-agents/mlagents/model_serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
import logging
from typing import Any, List, Set, NamedTuple

try:
import onnx
from tf2onnx.tfonnx import process_tf_graph, tf_optimize
from tf2onnx import optimizer

ONNX_EXPORT_ENABLED = True
except ImportError:
# Either onnx or tf2onnx is not installed, or they're not compatible with the installed version of tensorflow
ONNX_EXPORT_ENABLED = False
pass

from mlagents.tf_utils import tf

from tensorflow.python.platform import gfile
from tensorflow.python.framework import graph_util
from mlagents.trainers import tensorflow_to_barracuda as tf2bc

logger = logging.getLogger("mlagents.trainers")

# Names of graph nodes that are treated as model inputs when they are present
# in the frozen graph.
POSSIBLE_INPUT_NODES = frozenset(
    {
        "action_masks",
        "epsilon",
        "prev_action",
        "recurrent_in",
        "sequence_length",
        "vector_observation",
    }
)

# Names of graph nodes that are treated as model outputs when they are present
# in the frozen graph.
POSSIBLE_OUTPUT_NODES = frozenset(
    {"action", "action_probs", "recurrent_out", "value_estimate"}
)

# Names of constant nodes whose values are kept with the exported model so the
# inference side can read them.
MODEL_CONSTANTS = frozenset(
    {"action_output_shape", "is_continuous_control", "memory_size", "version_number"}
)

# Visual observation inputs are named with this prefix followed by an index
# that starts at 0.
VISUAL_OBSERVATION_PREFIX = "visual_observation_"


class SerializationSettings(NamedTuple):
    """
    Settings that control how a trained policy is serialized to disk.
    """

    # Path prefix of the export: the frozen graph is written to
    # "<model_path>/frozen_graph_def.pb" and the converted models to
    # "<model_path>.nn" and "<model_path>.onnx".
    model_path: str
    # Brain name; used in log output and as the name given to the ONNX model.
    brain_name: str
    # Whether a Barracuda (.nn) file is written on export.
    convert_to_barracuda: bool = True
    # Whether an ONNX (.onnx) file is written on export. Only takes effect
    # when the onnx / tf2onnx imports succeeded (ONNX_EXPORT_ENABLED).
    convert_to_onnx: bool = True
    # ONNX opset passed to tf2onnx when converting the graph.
    onnx_opset: int = 9


def export_policy_model(
    settings: SerializationSettings, graph: tf.Graph, sess: tf.Session
) -> None:
    """
    Freezes the policy graph and exports it for Unity embedding.

    The frozen GraphDef is always written to
    "<model_path>/frozen_graph_def.pb". Depending on the settings it is then
    converted to Barracuda ("<model_path>.nn") and, when the ONNX tooling could
    be imported, to ONNX ("<model_path>.onnx").

    ONNX export is best-effort: any exception raised while converting or
    writing the ONNX file is logged and swallowed so that it never breaks the
    rest of the export.

    :param settings: Output paths and conversion switches.
    :param graph: TensorFlow graph containing the policy.
    :param sess: Session holding the variable values to freeze into the graph.
    """
    frozen_graph_def = _make_frozen_graph(settings, graph, sess)
    # Save frozen graph
    frozen_graph_def_path = settings.model_path + "/frozen_graph_def.pb"
    with gfile.GFile(frozen_graph_def_path, "wb") as f:
        f.write(frozen_graph_def.SerializeToString())

    # Convert to barracuda
    if settings.convert_to_barracuda:
        tf2bc.convert(frozen_graph_def_path, settings.model_path + ".nn")
        logger.info(f"Exported {settings.model_path}.nn file")

    # Save to onnx too (if we were able to import it)
    if ONNX_EXPORT_ENABLED and settings.convert_to_onnx:
        onnx_output_path = settings.model_path + ".onnx"
        # Log before the conversion starts, so the message is accurate and is
        # still emitted when the conversion fails.
        logger.info(f"Converting to {onnx_output_path}")
        try:
            onnx_graph = convert_frozen_to_onnx(settings, frozen_graph_def)
            with open(onnx_output_path, "wb") as f:
                f.write(onnx_graph.SerializeToString())
        except Exception:
            # Deliberately broad: ONNX export is beta and must not abort training.
            logger.exception(
                "Exception trying to save ONNX graph. Please report this error on "
                "https://github.com/Unity-Technologies/ml-agents/issues and "
                "attach a copy of frozen_graph_def.pb"
            )
        else:
            # Mirror the confirmation message of the .nn export above.
            logger.info(f"Exported {onnx_output_path} file")


def _make_frozen_graph(
    settings: SerializationSettings, graph: tf.Graph, sess: tf.Session
) -> tf.GraphDef:
    """
    Returns a copy of the graph in which the variables have been replaced by
    constants holding their current values in the session.

    Only the nodes reported by _process_graph() are used as conversion targets.
    """
    with graph.as_default():
        joined_names = ",".join(_process_graph(settings, graph))
        # Normalise the joined names: drop spaces, then split back into a list.
        output_node_names = joined_names.replace(" ", "").split(",")
        return graph_util.convert_variables_to_constants(
            sess, graph.as_graph_def(), output_node_names
        )


def convert_frozen_to_onnx(
    settings: SerializationSettings, frozen_graph_def: tf.GraphDef
) -> Any:
    """
    Converts a frozen TensorFlow GraphDef into an ONNX model proto.

    :param settings: Supplies the ONNX opset and the name given to the model.
    :param frozen_graph_def: Frozen graph to convert.
    :return: The ONNX model proto, with the model constants added to the
        graph initializer.
    """
    # This is basically https://github.com/onnx/tensorflow-onnx/blob/master/tf2onnx/convert.py

    # The inference system needs to read a few constants from the graph. The
    # model itself never uses them, so keeping them alive through conversion
    # and import is a losing battle. Capture their values up front and add
    # them back once the conversion is done.
    saved_constants = {
        node.name: node.attr["value"].tensor.int_val[0]
        for node in frozen_graph_def.node
        if node.name in MODEL_CONSTANTS
    }

    input_names = _get_input_node_names(frozen_graph_def)
    output_names = _get_output_node_names(frozen_graph_def)
    logger.info(f"onnx export - inputs:{input_names} outputs:{output_names}")

    optimized_graph_def = tf_optimize(
        input_names, output_names, frozen_graph_def, fold_constant=True
    )

    with tf.Graph().as_default() as tf_graph:
        tf.import_graph_def(optimized_graph_def, name="")
    with tf.Session(graph=tf_graph):
        converted_graph = process_tf_graph(
            tf_graph,
            input_names=input_names,
            output_names=output_names,
            opset=settings.onnx_opset,
        )

    model_proto = optimizer.optimize_graph(converted_graph).make_model(
        settings.brain_name
    )

    # Put the saved constants back into the graph initializer.
    # This ensures the importer sees them as global constants.
    model_proto.graph.initializer.extend(
        [
            _make_onnx_node_for_constant(const_name, const_value)
            for const_name, const_value in saved_constants.items()
        ]
    )
    return model_proto


def _make_onnx_node_for_constant(name: str, value: int) -> Any:
    """
    Builds an INT32 ONNX TensorProto called `name` that holds the single
    integer `value`, with dims [1, 1, 1, 1].
    """
    return onnx.TensorProto(
        name=name,
        data_type=onnx.TensorProto.INT32,
        dims=[1, 1, 1, 1],
        int32_data=[value],
    )


def _get_input_node_names(frozen_graph_def: Any) -> List[str]:
    """
    Get the list of input node names from the graph.
    Names are suffixed with ":0"

    The result is ordered deterministically: the known input nodes in sorted
    order, followed by the visual observation nodes in index order. Iterating
    a set of strings directly would give an order that can change from one
    process to the next.

    :param frozen_graph_def: Graph whose nodes are inspected.
    :return: Names of the input nodes present in the graph, with the output
        port appended.
    """
    node_names = _get_frozen_graph_node_names(frozen_graph_def)
    input_names = sorted(node_names & POSSIBLE_INPUT_NODES)

    # Check visual inputs sequentially, and stop as soon as we don't find one
    vis_index = 0
    while f"{VISUAL_OBSERVATION_PREFIX}{vis_index}" in node_names:
        input_names.append(f"{VISUAL_OBSERVATION_PREFIX}{vis_index}")
        vis_index += 1
    # Append the port
    return [f"{n}:0" for n in input_names]


def _get_output_node_names(frozen_graph_def: Any) -> List[str]:
    """
    Get the list of output node names from the graph.
    Names are suffixed with ":0"

    The names are returned in sorted order so that the result is the same on
    every run; iterating a set of strings directly would give an order that
    can change from one process to the next.

    :param frozen_graph_def: Graph whose nodes are inspected.
    :return: Names of the output nodes present in the graph, with the output
        port appended.
    """
    node_names = _get_frozen_graph_node_names(frozen_graph_def)
    output_names = sorted(node_names & POSSIBLE_OUTPUT_NODES)
    # Append the port
    return [f"{n}:0" for n in output_names]


def _get_frozen_graph_node_names(frozen_graph_def: Any) -> Set[str]:
"""
Get all the node names from the graph.
"""
names = set()
for node in frozen_graph_def.node:
names.add(node.name)
return names


def _process_graph(settings: SerializationSettings, graph: tf.Graph) -> List[str]:
    """
    Gets the list of the output nodes present in the graph for inference.
    Both the possible output nodes and the model constants are included, and
    each selected name is logged.

    :param settings: Supplies the brain name used in the log output.
    :param graph: Graph whose nodes are inspected.
    :return: list of node names, in graph order
    """
    # Build the lookup set once instead of once per graph node.
    exportable_names = POSSIBLE_OUTPUT_NODES | MODEL_CONSTANTS
    nodes = [x.name for x in graph.as_graph_def().node if x.name in exportable_names]
    logger.info("List of nodes to export for brain :" + settings.brain_name)
    for n in nodes:
        logger.info("\t" + n)
    return nodes
44 changes: 1 addition & 43 deletions ml-agents/mlagents/trainers/tf_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,13 @@
from typing import Any, Dict, List, Optional

import numpy as np

from mlagents.tf_utils import tf
from mlagents import tf_utils

from mlagents_envs.exception import UnityException
from mlagents.trainers.policy import Policy
from mlagents.trainers.action_info import ActionInfo
from tensorflow.python.platform import gfile
from tensorflow.python.framework import graph_util
from mlagents.trainers import tensorflow_to_barracuda as tf2bc
from mlagents.trainers.trajectory import SplitObservations
from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.brain_conversion_utils import get_global_agent_id
Expand All @@ -34,17 +32,6 @@ class TFPolicy(Policy):
functions to interact with it to perform evaluate and updating.
"""

possible_output_nodes = [
"action",
"value_estimate",
"action_probs",
"recurrent_out",
"memory_size",
"version_number",
"is_continuous_control",
"action_output_shape",
]

def __init__(self, seed, brain, trainer_parameters):
"""
Initialized the policy.
Expand Down Expand Up @@ -328,35 +315,6 @@ def save_model(self, steps):
self.graph, self.model_path, "raw_graph_def.pb", as_text=False
)

def export_model(self):
    """
    Exports latest saved model to .nn format for Unity embedding.

    The graph is frozen using the variable values held by self.sess, written
    to "<model_path>/frozen_graph_def.pb", and then converted to
    "<model_path>.nn".
    """

    with self.graph.as_default():
        # Names of the nodes to keep, joined and re-split into a list.
        target_nodes = ",".join(self._process_graph())
        graph_def = self.graph.as_graph_def()
        output_graph_def = graph_util.convert_variables_to_constants(
            self.sess, graph_def, target_nodes.replace(" ", "").split(",")
        )
        # Save the frozen graph; the converter below reads it back from disk.
        frozen_graph_def_path = self.model_path + "/frozen_graph_def.pb"
        with gfile.GFile(frozen_graph_def_path, "wb") as f:
            f.write(output_graph_def.SerializeToString())
        tf2bc.convert(frozen_graph_def_path, self.model_path + ".nn")
        logger.info("Exported " + self.model_path + ".nn file")

def _process_graph(self):
    """
    Collects the names of the graph nodes that should be exported for
    inference, logging each one.
    :return: list of node names
    """
    graph_node_names = [node.name for node in self.graph.as_graph_def().node]
    nodes = []
    for node_name in graph_node_names:
        if node_name in self.possible_output_nodes:
            nodes.append(node_name)
    logger.info("List of nodes to export for brain :" + self.brain.brain_name)
    for node_name in nodes:
        logger.info("\t" + node_name)
    return nodes

def update_normalization(self, vector_obs: np.ndarray) -> None:
"""
If this policy normalizes vector observations, this will update the norm values in the graph.
Expand Down
5 changes: 4 additions & 1 deletion ml-agents/mlagents/trainers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from mlagents_envs.exception import UnityException
from mlagents_envs.timers import set_gauge
from mlagents.model_serialization import export_policy_model, SerializationSettings
from mlagents.trainers.tf_policy import TFPolicy
from mlagents.trainers.stats import StatsReporter
from mlagents.trainers.trajectory import Trajectory
Expand Down Expand Up @@ -192,7 +193,9 @@ def export_model(self, name_behavior_id: str) -> None:
"""
Exports the model
"""
self.get_policy(name_behavior_id).export_model()
policy = self.get_policy(name_behavior_id)
settings = SerializationSettings(policy.model_path, policy.brain.brain_name)
export_policy_model(settings, policy.graph, policy.sess)

def _write_summary(self, step: int) -> None:
"""
Expand Down
2 changes: 1 addition & 1 deletion test_constraints_min_version.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# pip constraints to use the *lowest* versions allowed in ml-agents/setup.py
grpcio==1.11.0
numpy==1.13.3
numpy==1.14.1
Pillow==4.2.1
protobuf==3.6
tensorflow==1.7
Expand Down
4 changes: 4 additions & 0 deletions test_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,7 @@
pytest>4.0.0,<6.0.0
pytest-cov==2.6.1
pytest-xdist

# Tests install onnx and tf2onnx, but this doesn't support tensorflow>=2.0.0
# Since we test tensorflow2.0 with python3.7, exclude it based on the python version
tf2onnx>=1.5.5; python_version < '3.7'

0 comments on commit a5626aa

Please sign in to comment.