Skip to content

Commit

Permalink
[Frontend][TensorFlow]TensorFlow Parser Control Flow Enhancement (apa…
Browse files Browse the repository at this point in the history
…che#5020)

* Improve TF control flow major logic

* Pass mod into operator convert function

* Fix LoopBound

* Add more control flow tests

* Add two test cases for stridedslice

* Fix docstring

* Fix lint

* Fix import

* Fix test assert

* Minor fix conv3d

* Add more comments

* Fix for dilation2d

* Change newly added atan

* Change newly added unravel
  • Loading branch information
kevinthesun authored and zhiics committed Apr 17, 2020
1 parent fb5fbc2 commit 1a90cd1
Show file tree
Hide file tree
Showing 5 changed files with 641 additions and 325 deletions.
42 changes: 28 additions & 14 deletions python/tvm/relay/frontend/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=broad-except
"""Common utilities"""
from __future__ import absolute_import as _abs
import logging
Expand Down Expand Up @@ -482,24 +483,37 @@ def infer_channels(inputs, transpose=False):
return channels


def infer_value(input_val, params):
def infer_value(input_val, params, mod=None):
"""A hack for getting the value of an expression by evaluating a
portion of the relay graph. This is often needed for functions that
whose output shape depends on the value of a tensor.
"""
# pylint: disable=import-outside-toplevel
from tvm.contrib import graph_runtime
# Check that all free variables have associated parameters.
assert all(var.name_hint in params.keys() for var in analysis.free_vars(
input_val)), "All inputs to infer must be available in params."
func = _function.Function(analysis.free_vars(input_val), input_val)
with tvm.relay.build_config(opt_level=0):
graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
ctx = tvm.cpu(0)
m = graph_runtime.create(graph, lib, ctx)
m.set_input(**params)
m.run()
return m.get_output(0)
try:
# TODO(kevinthesun): Use VM for all cases.
# pylint: disable=import-outside-toplevel
from tvm.contrib import graph_runtime
# Check that all free variables have associated parameters.
assert all(var.name_hint in params.keys() for var in analysis.free_vars(
input_val)), "All inputs to infer must be available in params."
func = _function.Function(analysis.free_vars(input_val), input_val)
with tvm.relay.build_config(opt_level=0):
graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
ctx = tvm.cpu(0)
m = graph_runtime.create(graph, lib, ctx)
m.set_input(**params)
m.run()
return m.get_output(0)
except Exception:
if isinstance(mod, IRModule):
mod["main"] = _expr.Function(analysis.free_vars(input_val), input_val)
else:
mod = IRModule.from_expr(input_val)
exc = tvm.relay.create_executor("debug", mod=mod, ctx=tvm.cpu(), target="llvm")
inputs = []
for param in mod['main'].params:
inputs.append(tvm.nd.array(params[param.name_hint]))
result = exc.evaluate()(*inputs)
return result


def infer_value_simulated(input_val, params):
Expand Down
Loading

0 comments on commit 1a90cd1

Please sign in to comment.