[Frontend][TensorFlow]TensorFlow Parser Control Flow Enhancement #5020

Merged: 14 commits, Mar 23, 2020
42 changes: 28 additions & 14 deletions python/tvm/relay/frontend/common.py
@@ -14,6 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+# pylint: disable=broad-except
 """Common utilities"""
 from __future__ import absolute_import as _abs
 import logging
@@ -482,24 +483,37 @@ def infer_channels(inputs, transpose=False):
     return channels


-def infer_value(input_val, params):
+def infer_value(input_val, params, mod=None):
     """A hack for getting the value of an expression by evaluating a
     portion of the relay graph. This is often needed for functions whose
     output shape depends on the value of a tensor.
     """
-    # pylint: disable=import-outside-toplevel
-    from tvm.contrib import graph_runtime
-    # Check that all free variables have associated parameters.
-    assert all(var.name_hint in params.keys() for var in analysis.free_vars(
-        input_val)), "All inputs to infer must be available in params."
-    func = _function.Function(analysis.free_vars(input_val), input_val)
-    with tvm.relay.build_config(opt_level=0):
-        graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
-    ctx = tvm.cpu(0)
-    m = graph_runtime.create(graph, lib, ctx)
-    m.set_input(**params)
-    m.run()
-    return m.get_output(0)
+    try:
+        # TODO(kevinthesun): Use VM for all cases.
+        # pylint: disable=import-outside-toplevel
+        from tvm.contrib import graph_runtime
+        # Check that all free variables have associated parameters.
+        assert all(var.name_hint in params.keys() for var in analysis.free_vars(
+            input_val)), "All inputs to infer must be available in params."
+        func = _function.Function(analysis.free_vars(input_val), input_val)
+        with tvm.relay.build_config(opt_level=0):
+            graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
+        ctx = tvm.cpu(0)
+        m = graph_runtime.create(graph, lib, ctx)
+        m.set_input(**params)
+        m.run()
+        return m.get_output(0)
+    except Exception:
+        if isinstance(mod, IRModule):
+            mod["main"] = _expr.Function(analysis.free_vars(input_val), input_val)
+        else:
+            mod = IRModule.from_expr(input_val)
+        exc = tvm.relay.create_executor("debug", mod=mod, ctx=tvm.cpu(), target="llvm")
+        inputs = []
+        for param in mod['main'].params:
+            inputs.append(tvm.nd.array(params[param.name_hint]))
+        result = exc.evaluate()(*inputs)
+        return result


 def infer_value_simulated(input_val, params):
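For context, here is a minimal usage sketch of the updated signature; it is not part of this diff. The variable name "data", the shape, and the constant are illustrative assumptions; only infer_value(input_val, params, mod=None) itself comes from the change above.

import numpy as np
import tvm
from tvm import relay
from tvm.relay.frontend.common import infer_value

# Build a tiny expression whose concrete value is needed at conversion time,
# e.g. a shape-like tensor feeding a dynamic op in the TensorFlow frontend.
data = relay.var("data", shape=(2, 2), dtype="float32")
expr = relay.add(data, relay.const(1.0, "float32"))

# Every free variable of the expression must have a value in params.
params = {"data": tvm.nd.array(np.ones((2, 2), dtype="float32"))}

# With the default mod=None the graph-runtime path is tried first; if it
# throws, the except branch builds an IRModule from the expression and falls
# back to the debug executor. Passing an existing module lets the fallback
# reuse it instead.
val = infer_value(expr, params)
print(val.asnumpy())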