Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TF][Relay][Op] Pass module when infer shape #4287

Merged
merged 4 commits into from
Nov 11, 2019
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 15 additions & 11 deletions python/tvm/relay/frontend/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,20 +451,24 @@ def get_name(node):
return name


def infer_type(node, mod=None):
    """Infer the checked type of an intermediate node in the Relay graph.

    Parameters
    ----------
    node : tvm.relay.Expr
        The expression (or Function) whose type should be inferred.
    mod : tvm.relay.Module, optional
        An existing module whose global definitions *node* may reference.
        Its contents are merged into the temporary module before running
        type inference, so cross-references resolve.

    Returns
    -------
    The type-annotated function when *node* is a Function, otherwise the
    type-annotated expression body.
    """
    new_mod = _module.Module.from_expr(node)
    if mod is not None:
        # Bring in the caller's global definitions so that references to
        # them inside `node` type-check.
        new_mod.update(mod)
    new_mod = _transform.InferType()(new_mod)
    entry = new_mod["main"]
    return entry if isinstance(node, _expr.Function) else entry.body


def infer_shape(inputs, mod=None):
    """Get the output shape of an intermediate node in the graph.

    Parameters
    ----------
    inputs : tvm.relay.Expr
        The node whose output shape is wanted.
    mod : tvm.relay.Module, optional
        Forwarded to ``infer_type`` so references to module-level
        definitions resolve during inference.

    Returns
    -------
    A tuple of constant dimensions when the node produces a tensor.
    When the checked type has no ``shape`` attribute (e.g. a List ADT),
    the checked type itself is returned instead.
    """
    out_type = infer_type(inputs, mod=mod)
    checked_type = out_type.checked_type
    if hasattr(checked_type, 'shape'):
        # Regular operator that outputs tensors.
        return get_const_tuple(checked_type.shape)
    # The return type is not a tensor, for example List.
    return checked_type

def infer_channels(inputs, transpose=False):
"""A hack for getting 'channels' or 'units' since caffe2 does not provide
Expand Down
19 changes: 14 additions & 5 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,12 @@ def _get_list_param(params, input_node):
def _get_tuple_param(params, input_node):
return tuple(_get_param(params, input_node))

def _need_module_for_shape_inference(op):
return op in ['StridedSlice']
wweic marked this conversation as resolved.
Show resolved Hide resolved

def _need_prelude_for_shape_inference(op):
return "TensorArray" in op

def _rsqrt():
def _impl(inputs, attr, params):
inputs.append(tvm.relay.const(-0.5, attr['T'].name))
Expand Down Expand Up @@ -891,7 +897,7 @@ def _impl(inputs, attr, params):
return _impl

def _stridedSlice():
def _impl(inputs, attr, params):
def _impl(inputs, attr, params, mod):
"""Strided Slice.
Operator description: https://www.tensorflow.org/api_docs/python/tf/strided_slice
Tensorflow mask validation: https://github.com/tensorflow/tensorflow/blob/master/
Expand Down Expand Up @@ -974,7 +980,7 @@ def _transform_mask(stride_dim, ellipsis_mask):
if begin_mask or end_mask or ellipsis_mask or new_axis_mask or shrink_axis_mask:
begin, end, stride, fshape_indices = _transform_mask(stride_dim, ellipsis_mask)
out = _op.strided_slice(inputs[0], begin=begin, end=end, strides=stride)
out_shape = _infer_shape(out)
out_shape = _infer_shape(out, mod=mod)
if not fshape_indices:
fshape_indices = range(len(out_shape))

Expand Down Expand Up @@ -2167,7 +2173,8 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):

# Infer shapes even without specifying "add_shapes=True"
if output_shapes == [None]:
out_shapes = [_infer_shape(node_item) for node_item in self._nodes[node.name]]
out_shapes = [_infer_shape(node_item, self._mod)
for node_item in self._nodes[node.name]]
self._output_shapes[node.name] = out_shapes

if self._output_shapes[node.name] and shape and node.name in shape:
Expand All @@ -2177,7 +2184,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
node_output = self._nodes[node.name]
if shape and (not self._output_shapes[node.name][0]
or -1 in self._output_shapes[node.name][0]):
out_shapes = [_infer_shape(node_item) for node_item in node_output]
out_shapes = [_infer_shape(node_item, self._mod) for node_item in node_output]
self._output_shapes[node.name] = out_shapes

out = []
Expand Down Expand Up @@ -2468,8 +2475,10 @@ def _convert_operator(self, op_name, inputs, attrs,
if op_name in identity_list:
sym = get_relay_op(op_name)(*inputs, **attrs)
elif op_name in convert_map:
if 'TensorArray' in op_name:
if _need_prelude_for_shape_inference(op_name):
sym = convert_map[op_name](inputs, attrs, self._params, self._prelude)
elif _need_module_for_shape_inference(op_name):
sym = convert_map[op_name](inputs, attrs, self._params, self._mod)
else:
sym = convert_map[op_name](inputs, attrs, self._params)

Expand Down