Skip to content

Commit

Permalink
[FRONTEND][TENSORFLOW] Bug fix for TensorFlow official slim models.
Browse files Browse the repository at this point in the history
  • Loading branch information
srkreddy1238 committed Mar 20, 2019
1 parent 89acfeb commit de6deb0
Showing 1 changed file with 15 additions and 17 deletions.
32 changes: 15 additions & 17 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,25 +530,23 @@ def _impl(inputs, attr, params):
op_name="reshape",
extras={'newshape':tuple(shape_arg.asnumpy())},
ignores=['Tshape'])(inputs, attr)
except KeyError:
except AttributeError:
# Shape operator is already pruned, hence
# try to infer shape by precompute prune if possible.
if all(in_node in params for in_node in inputs[1].list_input_names()):
func = _expr.Function(ir_pass.free_vars(inputs[1]), inputs[1])
with tvm.relay.build_config(opt_level=0):
graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
ctx = tvm.context("llvm", 0)
from tvm.contrib import graph_runtime
m = graph_runtime.create(graph, lib, ctx)
m.set_input(**params)
m.run()
params_new = m.get_output(0)
inputs.pop(1)
return AttrCvt(
op_name="reshape",
extras={'newshape':tuple(params_new.asnumpy().flatten())},
ignores=['Tshape'])(inputs, attr)
raise RuntimeError("Reshape with dynamic shape input not supported yet.")
func = _expr.Function(ir_pass.free_vars(inputs[1]), inputs[1])
with tvm.relay.build_config(opt_level=0):
graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
ctx = tvm.context("llvm", 0)
from tvm.contrib import graph_runtime
m = graph_runtime.create(graph, lib, ctx)
m.set_input(**params)
m.run()
params_new = m.get_output(0)
inputs.pop(1)
return AttrCvt(
op_name="reshape",
extras={'newshape':tuple(params_new.asnumpy().astype('int32').flatten())},
ignores=['Tshape'])(inputs, attr)
return _impl

def _bias_add():
Expand Down

0 comments on commit de6deb0

Please sign in to comment.