Skip to content

Commit

Permalink
[TF][Op] Add TF op Where
Browse files Browse the repository at this point in the history
  • Loading branch information
wweic committed Oct 1, 2019
1 parent 2f1edb9 commit 75dad48
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 22 deletions.
6 changes: 5 additions & 1 deletion python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -937,7 +937,10 @@ def _impl(inputs, attr, params):

def _where():
def _impl(inputs, attr, params):
return AttrCvt(op_name="where")(inputs, attr)
if len(inputs) == 1:
return AttrCvt(op_name="argwhere")(inputs, attr)
else:
return AttrCvt(op_name="where")(inputs, attr)
return _impl

def _clip_by_value():
Expand Down Expand Up @@ -1354,6 +1357,7 @@ def _impl(inputs, attr, params):
'Transpose' : _transpose(),
'TruncateMod' : _elemwise('mod'),
'Unpack' : _unpack(),
'Where' : _where(),
'ZerosLike' : AttrCvt('zeros_like'),

}
Expand Down
89 changes: 68 additions & 21 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,34 @@ def convert_to_list(x):
x = [x]
return x

def vmobj_to_list(o):
if isinstance(o, tvm.relay.backend.vmobj.TensorObject):
return [o.asnumpy().tolist()]
elif isinstance(o, tvm.relay.backend.vmobj.DatatypeObject):
result = []
for f in o:
result.extend(vmobj_to_list(f))
return result
elif isinstance(o, tvm.relay.backend.interpreter.TupleValue):
result = []
for f in o.fields:
result.append(vmobj_to_list(f))
return result
elif isinstance(o, tvm.relay.backend.interpreter.ConstructorValue):
if o.constructor.name_hint == 'cons':
tl = vmobj_to_list(o.fields[1])
hd = vmobj_to_list(o.fields[0])
hd.extend(tl)
return hd
elif o.constructor.name_hint == 'nil':
return []
elif isinstance(o, tvm.relay.backend.interpreter.TensorValue):
return [o.data.asnumpy()]
else:
raise RuntimeError("Unknown object type: %s" % type(o))

def run_tvm_graph(graph_def, input_data, input_node, num_output=1,
target='llvm', out_names=None, opt_level=3):
target='llvm', out_names=None, opt_level=3, mode='graph_runtime'):
""" Generic function to compile on relay and execute on tvm """
input_data = convert_to_list(input_data)
input_node = convert_to_list(input_node)
Expand All @@ -63,24 +89,32 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1,
layout=layout,
shape=shape_dict,
outputs=out_names)
with relay.build_config(opt_level=opt_level):
graph, lib, params = relay.build(mod, target, target_host, params)

ctx = tvm.context(target, 0)
from tvm.contrib import graph_runtime
m = graph_runtime.create(graph, lib, ctx)
# set inputs
for e, i in zip(input_node, input_data):
m.set_input(e, tvm.nd.array(i))

m.set_input(**params)
# execute
m.run()
# get outputs
assert out_names is None or num_output == len(out_names), (
"out_names: {} num_output: {}".format(out_names, num_output))
tvm_output_list = [m.get_output(i).asnumpy() for i in range(num_output)]
return tvm_output_list
if mode == 'interp':
ex = relay.create_executor("debug", mod=mod, ctx=tvm.cpu(), target="llvm")
inputs = []
for param in mod['main'].params:
inputs.append(tvm.nd.array(params[param.name_hint]))
result = ex.evaluate()(*inputs)
return vmobj_to_list(result)
else:
with relay.build_config(opt_level=opt_level):
graph, lib, params = relay.build(mod, target, target_host, params)

ctx = tvm.context(target, 0)
from tvm.contrib import graph_runtime
m = graph_runtime.create(graph, lib, ctx)
# set inputs
for e, i in zip(input_node, input_data):
m.set_input(e, tvm.nd.array(i))

m.set_input(**params)
# execute
m.run()
# get outputs
assert out_names is None or num_output == len(out_names), (
"out_names: {} num_output: {}".format(out_names, num_output))
tvm_output_list = [m.get_output(i).asnumpy() for i in range(num_output)]
return tvm_output_list

def run_tf_graph(sess, input_data, input_node, output_node):
""" Generic function to execute tensorflow """
Expand All @@ -97,7 +131,7 @@ def run_tf_graph(sess, input_data, input_node, output_node):


def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False,
no_gpu=False, opt_level=3):
no_gpu=False, opt_level=3, mode='graph_runtime'):
"""Generic function to generate and compare tensorflow and TVM output"""
def name_without_num(name):
return name.split(':')[0] if ":" in name else name
Expand Down Expand Up @@ -128,7 +162,7 @@ def name_without_num(name):

tvm_output = run_tvm_graph(final_graph_def, in_data, in_node,
target=device, out_names=out_name,
num_output=len(out_name), opt_level=opt_level)
num_output=len(out_name), opt_level=opt_level, mode=mode)
# since the names from tensorflow and relay runs are not exactly same,
# first len(tf_output) will be compared
for i in range(len(tf_output)):
Expand Down Expand Up @@ -325,6 +359,19 @@ def test_forward_biasadd():
_test_biasadd([4, 17, 17, 19], 'NHWC')
_test_biasadd([4, 3, 3, 124], 'NHWC')

def _test_forward_where(input_shape):
with tf.Graph().as_default():
dtype = tf.float32
t = tf.constant(np.random.choice([0, 1, 2, 3], size=input_shape).astype(dtype.name))
out = tf.where(t)
compare_tf_with_tvm([], [], out.name, mode='interp')

def test_forward_argwhere():
_test_forward_where((5, 5))
_test_forward_where((5, 5, 5))
_test_forward_where((5, 5, 5, 5))
_test_forward_where((5, 5, 5, 5, 5))

#######################################################################
# SpaceToBatchND
# --------------
Expand Down

0 comments on commit 75dad48

Please sign in to comment.