Skip to content

Commit

Permalink
[FRONTEND][TENSORFLOW] bug fix for tensorflow official slim models. (a…
Browse files Browse the repository at this point in the history
…pache#2864)

* [FRONTEND][TENSORFLOW] bug fix for tensorflow official slim models.

* 	* review comments
  • Loading branch information
srkreddy1238 authored and MarisaKirisame committed Apr 9, 2019
1 parent 4aea5e6 commit d581963
Showing 1 changed file with 15 additions and 17 deletions.
32 changes: 15 additions & 17 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,25 +543,23 @@ def _impl(inputs, attr, params):
op_name="reshape",
extras={'newshape':tuple(shape_arg.asnumpy())},
ignores=['Tshape'])(inputs, attr)
except KeyError:
except AttributeError:
# Shape operator is already pruned, hence
# try to infer shape by precompute prune if possible.
if all(in_node in params for in_node in inputs[1].list_input_names()):
func = _expr.Function(ir_pass.free_vars(inputs[1]), inputs[1])
with tvm.relay.build_config(opt_level=0):
graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
ctx = tvm.context("llvm", 0)
from tvm.contrib import graph_runtime
m = graph_runtime.create(graph, lib, ctx)
m.set_input(**params)
m.run()
params_new = m.get_output(0)
inputs.pop(1)
return AttrCvt(
op_name="reshape",
extras={'newshape':tuple(params_new.asnumpy().flatten())},
ignores=['Tshape'])(inputs, attr)
raise RuntimeError("Reshape with dynamic shape input not supported yet.")
func = _expr.Function(ir_pass.free_vars(inputs[1]), inputs[1])
with tvm.relay.build_config(opt_level=0):
graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
ctx = tvm.context("llvm", 0)
from tvm.contrib import graph_runtime
m = graph_runtime.create(graph, lib, ctx)
m.set_input(**params)
m.run()
params_new = m.get_output(0)
inputs.pop(1)
return AttrCvt(
op_name="reshape",
extras={'newshape':tuple(params_new.asnumpy().astype('int64').flatten())},
ignores=['Tshape'])(inputs, attr)
return _impl

def _bias_add():
Expand Down

0 comments on commit d581963

Please sign in to comment.