Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array #5243

Merged
merged 6 commits into from
Apr 11, 2020
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 25 additions & 16 deletions python/tvm/relay/frontend/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,22 +456,20 @@ def get_name(node):

def infer_type(node, mod=None):
"""A method to infer the type of an intermediate node in the relay graph."""
new_mod = IRModule.from_expr(node)
if mod is not None:
new_mod.update(mod)
new_mod = _transform.InferType()(new_mod)
entry = new_mod["main"]
return entry if isinstance(node, _function.Function) else entry.body
if isinstance(mod, IRModule):
mod["main"] = _function.Function([], node)
mod = _transform.InferType()(mod)
entry = mod["main"]
ret = entry.body
else:
new_mod = IRModule.from_expr(node)
if mod is not None:
new_mod.update(mod)
new_mod = _transform.InferType()(new_mod)
entry = new_mod["main"]
ret = entry if isinstance(node, _function.Function) else entry.body

def infer_shape(inputs, mod=None):
"""A method to get the output type of an intermediate node in the graph."""
out_type = infer_type(inputs, mod=mod)
checked_type = out_type.checked_type
if hasattr(checked_type, 'shape'):
# Regular operator that outputs tensors
return get_const_tuple(out_type.checked_type.shape)
# The return type is not a tensor, for example List
return checked_type
return ret

def infer_channels(inputs, transpose=False):
"""A hack for getting 'channels' or 'units' since caffe2 does not provide
Expand All @@ -483,6 +481,17 @@ def infer_channels(inputs, transpose=False):
return channels


def infer_shape(inputs, mod=None):
"""A method to get the output type of an intermediate node in the graph."""
out_type = infer_type(inputs, mod=mod)
checked_type = out_type.checked_type
if hasattr(checked_type, 'shape'):
# Regular operator that outputs tensors
return get_const_tuple(out_type.checked_type.shape)
# The return type is not a tensor, for example List
kevinthesun marked this conversation as resolved.
Show resolved Hide resolved
return checked_type


def infer_value(input_val, params, mod=None):
"""A hack for getting the value of an expression by evaluating a
portion of the relay graph. This is often needed for functions that
Expand All @@ -505,7 +514,7 @@ def infer_value(input_val, params, mod=None):
return m.get_output(0)
except Exception:
if isinstance(mod, IRModule):
mod["main"] = _expr.Function(analysis.free_vars(input_val), input_val)
mod["main"] = _function.Function(analysis.free_vars(input_val), input_val)
else:
mod = IRModule.from_expr(input_val)
exc = tvm.relay.create_executor("debug", mod=mod, ctx=tvm.cpu(), target="llvm")
Expand Down
Loading