Commit 7f20b9d: Address comments

kevinthesun committed Jun 11, 2020
1 parent 21b662e, commit 7f20b9d
Showing 2 changed files with 12 additions and 20 deletions.
python/tvm/relay/frontend/tensorflow.py (26 changes: 9 additions & 17 deletions)
@@ -1436,7 +1436,7 @@ def _impl(inputs, attr, params, mod):
             # a partial symbolic shape, such as (1, ?), and get a static shape
             # (1,). Directly slice on shape_of will result in fully dynamic shape.
             # TODO(kevinthesun): Can we generalize this process with partial eval?
-            if isinstance(inputs[0], _expr.Call) and "shape_of" in str(inputs[0].op):
+            if isinstance(inputs[0], _expr.Call) and inputs[0].op == _op.get("shape_of"):
                 bg = begin[0]
                 ed = end[0]
                 st = stride[0]
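The new check compares against the op registry instead of substring-matching the op's printed form, which could also fire for unrelated ops. A minimal sketch of the difference, hypothetical and not part of the commit (assumes a local TVM build of this era):

    import tvm
    from tvm import relay
    from tvm.relay import op as _op

    x = relay.var("x", shape=(1, 4))
    call = relay.shape_of(x)  # a Call whose op is the registered "shape_of"
    # Registered ops are singletons, so comparing against the registry entry
    # is exact; a substring test on str(call.op) could also accept any op
    # whose printed form happens to contain "shape_of".
    assert isinstance(call, relay.Call)
    assert call.op == _op.get("shape_of")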
@@ -1448,16 +1448,15 @@ def _impl(inputs, attr, params, mod):
                 dtype = in_type.checked_type.dtype
                 out_data = []
                 idx = bg
-                is_constant = True
                 while idx < ed:
                     if isinstance(in_shape[idx], int):
                         out_data.append(in_shape[idx])
                     else:
-                        is_constant = False
                         break
                     idx += st
 
-                if is_constant:
+                # Only return when in_shape is fully static in the range from begin to end.
+                if idx >= ed:
                     ret = _expr.const(out_data, dtype)
                     if shrink_axis_mask:
                         ret = _op.squeeze(ret)
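With the is_constant flag removed, completion is read off idx itself: the loop only drives idx past ed when no symbolic dimension triggered the break. A standalone plain-Python sketch of that invariant (hypothetical values):

    in_shape = (1, None, 3)  # None stands in for a symbolic dimension
    bg, ed, st = 0, 1, 1
    out_data = []
    idx = bg
    while idx < ed:
        if isinstance(in_shape[idx], int):
            out_data.append(in_shape[idx])
        else:
            break  # a symbolic dimension makes the slice non-constant
        idx += st
    # idx >= ed holds exactly when every visited dimension was static
    print(idx >= ed, out_data)  # True [1]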
@@ -2423,11 +2422,7 @@ def is_tensor_array_constuctor(tf_node):
     is_ta = False
     ta_start = "TensorArrayV"
     if tf_node.op.startswith(ta_start):
-        try:
-            int(tf_node.op[len(ta_start)])
-            is_ta = True
-        except ValueError:
-            pass
+        is_ta = tf_node.op[len(ta_start)].isnumeric()
     return is_ta
 
 def find_parent_loop_name(node_name, while_loop_name_set):
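The one-line replacement preserves the old try/int behavior: a node counts as a TensorArray constructor only when the character right after the "TensorArrayV" prefix is a digit, which keeps ops such as TensorArrayGatherV3 excluded. A quick standalone check (hypothetical op names):

    ta_start = "TensorArrayV"
    for name in ("TensorArrayV3", "TensorArrayV2", "TensorArrayGatherV3"):
        is_ta = name.startswith(ta_start) and name[len(ta_start)].isnumeric()
        print(name, is_ta)  # True, True, False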
@@ -2472,7 +2467,8 @@ def _in_while_loop(control_flow_node_map, op_name):
 
 class RewriteSubgraph(ExprMutator):
     """
-    A helper class to rewrite expr in while loop function to variable
+    A helper class to rewrite expr in while loop function to variable.
 
+    Parameters
     ----------
     rewrite_map : Dict[expr, expr]
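For context, the class under this docstring is a small ExprMutator that swaps expressions according to rewrite_map. A minimal analogue, sketched with a hypothetical name (SubstituteExpr) rather than claiming to be the class's actual body:

    from tvm.relay.expr_functor import ExprMutator

    class SubstituteExpr(ExprMutator):  # hypothetical analogue, not the commit's code
        def __init__(self, rewrite_map):
            ExprMutator.__init__(self)
            self.rewrite_map = rewrite_map  # Dict[expr, expr]

        def visit(self, expr):
            # Replace the whole sub-expression when it appears in the map,
            # otherwise recurse as usual.
            if expr in self.rewrite_map:
                return self.rewrite_map[expr]
            return super().visit(expr)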
@@ -2687,18 +2683,14 @@ def _while_loop(self):
         for lv, exp in self._lvar2expr[self._loop_name].items():
             if lv not in self.loop_vars:
                 var_checker = VarChecker(lv)
-                used = False
                 for bd in self.body + [cond]:
                     var_checker.visit(bd)
                     if var_checker.used:
-                        used = True
+                        lv_list.append(lv)
+                        expr_list.append(exp)
+                        extra_vars.append(lv)
                         break
 
-                if used:
-                    lv_list.append(lv)
-                    expr_list.append(exp)
-                    extra_vars.append(lv)
-
         with sb.if_scope(cond):
             sb.ret(wl(*list(self.body + extra_vars)))
         with sb.else_scope():
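The rewrite drops the used flag by recording the variable at the first loop body (or condition) that references it. The same control-flow pattern, distilled into a standalone hypothetical helper:

    def collect_extra_vars(candidates, loop_vars, bodies, uses):
        # candidates: mapping var -> bound expr; uses(bd, lv) stands in
        # for the VarChecker visit-and-test in the real code.
        lv_list, expr_list, extra_vars = [], [], []
        for lv, exp in candidates.items():
            if lv not in loop_vars:
                for bd in bodies:
                    if uses(bd, lv):
                        lv_list.append(lv)
                        expr_list.append(exp)
                        extra_vars.append(lv)
                        break  # record once, then stop scanning
        return lv_list, expr_list, extra_vars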
tests/python/frontend/tensorflow/test_forward.py (6 changes: 3 additions & 3 deletions)
@@ -114,7 +114,7 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1,
                                                  outputs=out_names)
     ctx = tvm.context(target, 0)
     if mode == 'debug':
-        ex = relay.create_executor(mode, mod=mod, ctx=ctx, target="llvm")
+        ex = relay.create_executor(mode, mod=mod, ctx=tvm.cpu(), target="llvm")
         inputs = []
         for param in mod['main'].params:
             found = False
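The debug executor is compiled for target "llvm" regardless of the requested target, so its context has to be the CPU rather than ctx (which may point at a GPU). A self-contained sketch against the 2020-era API, in which create_executor still took a ctx argument (toy module, hypothetical):

    import numpy as np
    import tvm
    from tvm import relay

    x = relay.var("x", shape=(2,), dtype="float32")
    mod = tvm.IRModule.from_expr(relay.Function([x], x + x))
    # The context must match the compilation target: "llvm" runs on tvm.cpu().
    ex = relay.create_executor("debug", mod=mod, ctx=tvm.cpu(), target="llvm")
    print(ex.evaluate()(np.ones(2, dtype="float32")))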
@@ -130,9 +130,9 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1,
         return vmobj_to_list(result)
     elif mode == 'vm':
         with tvm.transform.PassContext(opt_level=opt_level, disabled_pass=disabled_pass):
-            vm_exec = relay.vm.compile(mod, target=target, params=params)
+            vm_exec = relay.vm.compile(mod, target="llvm", params=params)
         vm = VirtualMachine(vm_exec)
-        vm.init(ctx)
+        vm.init(tvm.cpu())
         inputs = {}
         for e, i in zip(input_node, input_data):
             inputs[e] = i
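The VM path gets the same treatment: the executable is compiled for "llvm", so the VirtualMachine is initialized on tvm.cpu() to keep code and context in agreement. A hypothetical minimal run using that era's init/invoke API:

    import numpy as np
    import tvm
    from tvm import relay
    from tvm.runtime.vm import VirtualMachine

    x = relay.var("x", shape=(2,), dtype="float32")
    mod = tvm.IRModule.from_expr(relay.Function([x], x * x))
    vm_exec = relay.vm.compile(mod, target="llvm")
    vm = VirtualMachine(vm_exec)
    vm.init(tvm.cpu())  # context must match the "llvm" target
    res = vm.invoke("main", tvm.nd.array(np.full((2,), 3.0, dtype="float32")))
    print(res)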
