Skip to content

Commit

Permalink
[Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array (a…
Browse files Browse the repository at this point in the history
…pache#5243)

* Support TF Frontend Static TensorArray

* Fix pylint

* Fix lint

* Move get_tensor_array_shape into prelude

* Fix lint

* Fix common
  • Loading branch information
kevinthesun authored and dpankratz committed Apr 24, 2020
1 parent 18e0eb2 commit c6461fa
Show file tree
Hide file tree
Showing 5 changed files with 450 additions and 99 deletions.
41 changes: 25 additions & 16 deletions python/tvm/relay/frontend/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,22 +456,20 @@ def get_name(node):

def infer_type(node, mod=None):
"""A method to infer the type of an intermediate node in the relay graph."""
new_mod = IRModule.from_expr(node)
if mod is not None:
new_mod.update(mod)
new_mod = _transform.InferType()(new_mod)
entry = new_mod["main"]
return entry if isinstance(node, _function.Function) else entry.body
if isinstance(mod, IRModule):
mod["main"] = _function.Function([], node)
mod = _transform.InferType()(mod)
entry = mod["main"]
ret = entry.body
else:
new_mod = IRModule.from_expr(node)
if mod is not None:
new_mod.update(mod)
new_mod = _transform.InferType()(new_mod)
entry = new_mod["main"]
ret = entry if isinstance(node, _function.Function) else entry.body

def infer_shape(inputs, mod=None):
"""A method to get the output type of an intermediate node in the graph."""
out_type = infer_type(inputs, mod=mod)
checked_type = out_type.checked_type
if hasattr(checked_type, 'shape'):
# Regular operator that outputs tensors
return get_const_tuple(out_type.checked_type.shape)
# The return type is not a tensor, for example List
return checked_type
return ret

def infer_channels(inputs, transpose=False):
"""A hack for getting 'channels' or 'units' since caffe2 does not provide
Expand All @@ -483,6 +481,17 @@ def infer_channels(inputs, transpose=False):
return channels


def infer_shape(inputs, mod=None):
"""A method to get the output type of an intermediate node in the graph."""
out_type = infer_type(inputs, mod=mod)
checked_type = out_type.checked_type
if hasattr(checked_type, 'shape'):
# Regular operator that outputs tensors
return get_const_tuple(checked_type.shape)
# The return type is not a tensor, for example List
return checked_type


def infer_value(input_val, params, mod=None):
"""A hack for getting the value of an expression by evaluating a
portion of the relay graph. This is often needed for functions that
Expand All @@ -505,7 +514,7 @@ def infer_value(input_val, params, mod=None):
return m.get_output(0)
except Exception:
if isinstance(mod, IRModule):
mod["main"] = _expr.Function(analysis.free_vars(input_val), input_val)
mod["main"] = _function.Function(analysis.free_vars(input_val), input_val)
else:
mod = IRModule.from_expr(input_val)
exc = tvm.relay.create_executor("debug", mod=mod, ctx=tvm.cpu(), target="llvm")
Expand Down
Loading

0 comments on commit c6461fa

Please sign in to comment.