Skip to content

Commit

Permalink
Merge pull request #907 from jignparm/jignparm/fix_transpose_pad
Browse files Browse the repository at this point in the history
Fix Transpose + Pad handler, for Keras app MobilenetV2 model
  • Loading branch information
jignparm authored May 5, 2020
2 parents 59fed17 + b9ba4e1 commit d3dd7f0
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 5 deletions.
12 changes: 12 additions & 0 deletions tests/run_pretrained_models.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,18 @@ keras_resnet50:
model: ResNet50
model_type: keras
input_get: get_ramp
inputs:
"input_1:0": [1, 224, 224, 3]
outputs:
- Identity:0

keras_mobilenet_v2:
tf_min_version: 2.1
disabled: false
url: module://tensorflow.keras.applications.mobilenet_v2/MobileNetV2
model: MobileNetV2
model_type: keras
input_get: get_ramp
inputs:
"input_1:0": [1, 224, 224, 3]
outputs:
Expand Down
7 changes: 7 additions & 0 deletions tf2onnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,13 @@ def make_const(self, name, np_val, skip_conversion=False, raw=True):
self.set_dtype(name, utils.map_numpy_to_onnx_dtype(np_val.dtype))
return node

def copy_const(self, node, name=None):
"""Copy a const node, using name if specified"""
# TODO: support attr copy starting at opset 12
if name is None:
name = utils.make_name(node.name)
return self.make_const(name, node.get_tensor_value(as_list=False))

def make_node(self, op_type, inputs, attr=None, output_count=1, outputs=None, skip_conversion=True,
op_name_scope=None, name=None, shapes=None, dtypes=None, domain=constants.ONNX_DOMAIN,
infer_shape_dtype=True):
Expand Down
15 changes: 10 additions & 5 deletions tf2onnx/optimizer/transpose_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,14 +602,19 @@ def _pad_handler(self, trans, node):
new_pads = [pads[0], pads[3], pads[1], pads[2], pads[4], pads[7], pads[5], pads[6]]
node.set_attr("pads", new_pads)
return self._switch_transpose_and_node(node, trans)
if node.inputs[1].is_const():
if node.inputs[1].data_format in ["NHWC", "unkown"]:
pads = node.inputs[1].get_tensor_value()

input1 = node.inputs[1]
if input1.is_const():
if input1.data_format in ["NHWC", "unkown"]:
if not self._nodes_has_single_consumer_node([input1]):
input1 = self._g.copy_const(input1)
node.input[1] = input1.output[0]
pads = input1.get_tensor_value()
# NHWC->NCHW
new_pads = np.array([pads[0], pads[3], pads[1], pads[2], pads[4], pads[7], pads[5], pads[6]],
dtype=np.int64)
node.inputs[1].set_tensor_value(new_pads)
node.inputs[1].data_format = "NCHW"
input1.set_tensor_value(new_pads)
input1.data_format = "NCHW"
return self._switch_transpose_and_node(node, trans)
return False

Expand Down
1 change: 1 addition & 0 deletions tf2onnx/tf_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ def from_keras(model_path, input_names, output_names):
# Handles Keras when Eager mode is enabled.
custom_objects = None
if context.executing_eagerly():
_keras.backend.clear_session()
_keras.backend.set_learning_phase(False)
keras_model = _keras.models.load_model(model_path, custom_objects)

Expand Down

0 comments on commit d3dd7f0

Please sign in to comment.