Skip to content

Commit

Permalink
[Frontend][TensorFlow] Improve Control Flow and TensorArray (#5699)
Browse files Browse the repository at this point in the history
* Improve TF parser control flow and tensor array

* Fix tf tensor array scatter

* Add ssd test

* Add back static ta test

* Minor fix for frontend and test_forward

* SplitRel for dynamic shape

* Fix test ssd

* Fix loop var naming issue

* Minor improve

* Fix format

* Fix clang format

* Fix tensor array in pytorch frontend

* Fix stack size issue for ssd test

* Address comments

* Fix slice size

* Fix build

* Rebase
  • Loading branch information
kevinthesun authored Jun 12, 2020
1 parent 54bde85 commit ae119f8
Show file tree
Hide file tree
Showing 8 changed files with 602 additions and 381 deletions.
8 changes: 4 additions & 4 deletions python/tvm/relay/frontend/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,13 +497,13 @@ def infer_value(input_val, params, mod=None):
portion of the relay graph. This is often needed for functions that
whose output shape depends on the value of a tensor.
"""
# Check that all free variables have associated parameters.
assert all(var.name_hint in params.keys() for var in analysis.free_vars(
input_val)), "All inputs to infer must be available in params."
try:
# TODO(kevinthesun): Use VM for all cases.
# pylint: disable=import-outside-toplevel
from tvm.contrib import graph_runtime
# Check that all free variables have associated parameters.
assert all(var.name_hint in params.keys() for var in analysis.free_vars(
input_val)), "All inputs to infer must be available in params."
func = _function.Function(analysis.free_vars(input_val), input_val)
with tvm.transform.PassContext(opt_level=0):
graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
Expand All @@ -520,7 +520,7 @@ def infer_value(input_val, params, mod=None):
exc = tvm.relay.create_executor("debug", mod=mod, ctx=tvm.cpu(), target="llvm")
inputs = []
for param in mod['main'].params:
inputs.append(tvm.nd.array(params[param.name_hint]))
inputs.append(params[param.name_hint])
result = exc.evaluate()(*inputs)
return result

Expand Down
18 changes: 9 additions & 9 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,12 +211,12 @@ def tensor_array_concat(lst, axis):
assert axis == 0, "Tensor array concat supported only for axis 0"
tensor_array, shape = _convert_to_tensor_array(lst, prelude)
concat_shape = (Any(),) + shape[1:]
static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", shape)
static_tensor_array_ops.define_tensor_get_data(concat_shape)

concat = prelude.get_var_static('tensor_array_concat', "float32", shape)
concatenated = concat(tensor_array)
get_tensor = prelude.get_var_static('tensor_get_data', "float32", shape)

static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", concat_shape)
static_tensor_array_ops.register()
get_tensor = prelude.get_var_static('tensor_get_data', "float32", concat_shape)
return get_tensor(concatenated)

def _impl(inputs, input_types):
Expand Down Expand Up @@ -1619,14 +1619,14 @@ def _impl(inputs, input_types):
def _tensor_array_stack(prelude):
def _impl(inputs, input_types):
tensor_array, shape = _convert_to_tensor_array(inputs[0], prelude)

stacked_shape = (Any(),) + shape
stack = prelude.get_var_static('tensor_array_stack', "float32", shape)
stacked = stack(tensor_array)

stacked_shape = (Any(),) + shape
static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", shape)
static_tensor_array_ops.define_tensor_get_data(stacked_shape)
# passing stacked_shape below gives "'Prelude' object has no attribute" error
get_tensor = prelude.get_var_static('tensor_get_data', "float32", shape)
static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", stacked_shape)
static_tensor_array_ops.register()
get_tensor = prelude.get_var_static('tensor_get_data', "float32", stacked_shape)
return get_tensor(stacked)
return _impl

Expand Down
814 changes: 469 additions & 345 deletions python/tvm/relay/frontend/tensorflow.py

Large diffs are not rendered by default.

9 changes: 5 additions & 4 deletions python/tvm/relay/prelude.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,21 +555,21 @@ def define_tensor_array_gather(self):
self.prelude.mod[gather_var] = \
Function([tensor_array, indices], body, output_tensor_type_var(), [])

def define_tensor_get_data(self, data_shape):
def define_tensor_get_data(self):
"""Defines a function to get a Tensor from tensor_t with given shape.
"""
tensor_get_data_name = self.get_name("tensor_get_data")
tensor_get_data_var = self._create_global_var(tensor_get_data_name)
setattr(self.prelude, tensor_get_data_name, tensor_get_data_var)

tensor_type_var, tensor_constructor = self._get_adt_by_shape(data_shape)
tensor_type_var = self.get_var('tensor_t')
tensor_constructor = self.get_var('tensor_constructor')
t = Var('tensor', tensor_type_var())
tvar = Var('t')
case =\
Clause(PatternConstructor(tensor_constructor, [PatternVar(tvar)]), tvar)
self.prelude.mod[tensor_get_data_var] = \
Function([t], Match(t, [case], False),
TensorType(data_shape, self.dtype), [])
TensorType(self.shape, self.dtype), [])

def register(self):
"""Register all tensor array ops in Prelude"""
Expand All @@ -586,6 +586,7 @@ def register(self):
self.define_tensor_array_concat()
self.define_tensor_array_stack()
self.define_tensor_array_gather()
self.define_tensor_get_data()

def _get_adt_by_shape(self, shape):
"""Get ADT type and constructor with given shape."""
Expand Down
26 changes: 19 additions & 7 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2088,13 +2088,19 @@ bool SplitRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
CHECK_GE(axis, 0) << "axis should be within the input dimension range.";

if (const IntImmNode* sections = param->indices_or_sections.as<IntImmNode>()) {
CHECK(reporter->Assert(indexmod(data->shape[axis], sections->value) ==
tir::make_zero(DataType::Int(64))))
<< "indices_or_sections need to be able to divide input.shape[axis]";
if (!data->shape[axis].as<AnyNode>()) {
CHECK(reporter->Assert(indexmod(data->shape[axis], sections->value) ==
tir::make_zero(DataType::Int(64))))
<< "indices_or_sections need to be able to divide input.shape[axis]";
}
std::vector<Type> fields;
for (int i = 0; i < sections->value; ++i) {
std::vector<IndexExpr> oshape(data->shape.begin(), data->shape.end());
oshape[axis] = indexdiv(oshape[axis], sections->value);
if (data->shape[axis].as<AnyNode>()) {
oshape[axis] = Any();
} else {
oshape[axis] = indexdiv(oshape[axis], sections->value);
}
auto vec_type = TensorType(oshape, data->dtype);
fields.push_back(vec_type);
}
Expand All @@ -2112,10 +2118,16 @@ bool SplitRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
auto vec_type = TensorType(oshape, data->dtype);
fields.push_back(vec_type);
}
CHECK(reporter->Assert(begin < data->shape[axis]))
<< "The sum of sections must match the input.shape[axis]";
if (!data->shape[axis].as<AnyNode>()) {
CHECK(reporter->Assert(begin < data->shape[axis]))
<< "The sum of sections must match the input.shape[axis]";
}
std::vector<IndexExpr> oshape(data->shape.begin(), data->shape.end());
oshape[axis] = data->shape[axis] - begin;
if (data->shape[axis].as<AnyNode>()) {
oshape[axis] = Any();
} else {
oshape[axis] = data->shape[axis] - begin;
}
auto vec_type = TensorType(oshape, data->dtype);
fields.push_back(vec_type);
reporter->Assign(types[1], TupleType(Array<Type>(fields)));
Expand Down
26 changes: 24 additions & 2 deletions tests/python/frontend/tensorflow/test_control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def check_equal(graph, tf_out, input_map=None):
def test_vanilla_loop():
graph = tf.Graph()
with graph.as_default():
i = tf.constant(0)
i = tf.constant(0, name="while/constant")

def c(i): return tf.less(i, 10)

Expand Down Expand Up @@ -368,7 +368,6 @@ def condition(x, y):

check_equal(graph, tf_out, {dname: np_data})


def test_switch():
graph = tf.Graph()

Expand All @@ -385,6 +384,28 @@ def test_switch():

check_equal(graph, tf_out, {dname: data_np, flag_name: False})

def test_loop_tuple_input():
graph = tf.Graph()

with graph.as_default():
data_np = np.random.uniform(0, 5, size=(2, 4, 5, 1)).astype('float32')
dname = 'data'
data = tf.placeholder(shape=data_np.shape, dtype=data_np.dtype, name=dname)
split = tf.split(data, 2, axis=0)

def body(x, y):
return x + 2, y + 1

start = tf.constant(0)
def condition(x, y):
return tf.less(y, 20)

r = tf.while_loop(condition, body, loop_vars=[split[1], start])
with tf.Session() as sess:
tf_out = sess.run(r, feed_dict={data.name: data_np})

check_equal(graph, tf_out, {dname: data_np})


if __name__ == "__main__":
# tf.while_loop
Expand All @@ -410,3 +431,4 @@ def test_switch():
test_nested_loop_bound()

test_switch()
test_loop_tuple_input()
81 changes: 72 additions & 9 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
This article is a test script to test tensorflow operator with Relay.
"""
from __future__ import print_function
import threading
import numpy as np
import pytest
try:
Expand All @@ -45,6 +46,7 @@
from tvm import te
from tvm import relay
import tvm.relay.testing.tf as tf_testing
from tvm.runtime.vm import VirtualMachine
from packaging import version as package_version

#######################################################################
Expand Down Expand Up @@ -98,11 +100,10 @@ def vmobj_to_list(o):

def run_tvm_graph(graph_def, input_data, input_node, num_output=1,
target='llvm', out_names=None, opt_level=3, mode='graph_runtime',
cuda_layout="NCHW"):
cuda_layout="NCHW", layout=None, disabled_pass=None):
""" Generic function to compile on relay and execute on tvm """
input_data = convert_to_list(input_data)
input_node = convert_to_list(input_node)
layout = None
if target == "cuda":
layout = cuda_layout
target_host = None
Expand All @@ -111,7 +112,8 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1,
layout=layout,
shape=shape_dict,
outputs=out_names)
if mode in ['debug', 'vm']:
ctx = tvm.context(target, 0)
if mode == 'debug':
ex = relay.create_executor(mode, mod=mod, ctx=tvm.cpu(), target="llvm")
inputs = []
for param in mod['main'].params:
Expand All @@ -126,11 +128,19 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1,
inputs.append(tvm.nd.array(params[param.name_hint]))
result = ex.evaluate()(*inputs)
return vmobj_to_list(result)
elif mode == 'vm':
with tvm.transform.PassContext(opt_level=opt_level, disabled_pass=disabled_pass):
vm_exec = relay.vm.compile(mod, target="llvm", params=params)
vm = VirtualMachine(vm_exec)
vm.init(tvm.cpu())
inputs = {}
for e, i in zip(input_node, input_data):
inputs[e] = i
result = vm.invoke("main", **inputs)
return vmobj_to_list(result)
else:
with tvm.transform.PassContext(opt_level=opt_level):
with tvm.transform.PassContext(opt_level=opt_level, disabled_pass=disabled_pass):
graph, lib, params = relay.build(mod, target, target_host, params)

ctx = tvm.context(target, 0)
from tvm.contrib import graph_runtime
m = graph_runtime.create(graph, lib, ctx)
# set inputs
Expand Down Expand Up @@ -888,10 +898,15 @@ def test_tensor_array_scatter():
def run(dtype_str, infer_shape):
with tf.Graph().as_default():
dtype = tf_dtypes[dtype_str]
if infer_shape:
element_shape = tf.TensorShape([tf.Dimension(None)])
else:
element_shape = None
t = tf.constant(np.array([[1.0], [2.0], [3.0]]).astype(dtype_str), dtype=dtype)
indices = tf.constant([2, 1, 0])
ta1 = tf.TensorArray(dtype=dtype, size=3,
infer_shape=infer_shape)
infer_shape=infer_shape,
element_shape=element_shape)
ta2 = ta1.scatter(indices, t)
out0 = ta2.read(0)
out1 = ta2.read(1)
Expand Down Expand Up @@ -967,8 +982,14 @@ def test_tensor_array_size():
def run(dtype_str, infer_shape):
with tf.Graph().as_default():
dtype = tf_dtypes[dtype_str]
np_data = np.array([[1.0, 2.0], [3.0, 4.0]]).astype(dtype_str)
in_data = [np_data, np_data]
t1 = tf.constant(np_data, dtype=dtype)
t2 = tf.constant(np_data, dtype=dtype)
ta1 = tf.TensorArray(dtype=dtype, size=2, infer_shape=infer_shape)
out = ta1.size()
ta2 = ta1.write(0, t1)
ta3 = ta2.write(1, t2)
out = ta3.size()
g = tf.get_default_graph()
compare_tf_with_tvm([], [], 'TensorArraySizeV3:0', mode='debug')
for dtype in ["float32", "int8"]:
Expand Down Expand Up @@ -2267,6 +2288,48 @@ def test_forward_resnetv2():
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]),
rtol=1e-5, atol=1e-5)

#######################################################################
# SSD
# ---


def _test_ssd_impl():
'''Test SSD with backbone MobileNet V1'''
with tf.Graph().as_default():
graph_def = tf_testing.get_workload(
"object_detection/ssd_mobilenet_v1_ppn_shared_"
"box_predictor_300x300_coco14_sync_2018_07_03.pb")
# Call the utility to import the graph definition into default graph.
graph_def = tf_testing.ProcessGraphDefParam(graph_def)

data = np.random.uniform(0.0, 255.0, size=(1, 512, 512, 3)).astype('uint8')
in_node = "image_tensor"
out_node = ['detection_boxes', "detection_scores", "detection_classes"]

with tf.Session() as sess:
tf_output = run_tf_graph(
sess, data, '{}:0'.format(in_node), ["{}:0".format(oname) for oname in out_node])
# TODO(kevinthesun): enable gpu test when VM heterogeneous execution is ready.
for device in ["llvm"]:
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
continue
tvm_output = run_tvm_graph(graph_def, data, in_node, len(out_node),
target=device, layout="NCHW", out_names=out_node,
mode="vm", disabled_pass=["FoldScaleAxis"])
for i in range(len(out_node)):
tvm.testing.assert_allclose(tvm_output[i], tf_output[i],
rtol=1e-3, atol=1e-3)

def test_forward_ssd():
run_thread = threading.Thread(target=_test_ssd_impl, args=())
old_stack_size = threading.stack_size(100 * 1024 * 1024)
run_thread.start()
run_thread.join()
threading.stack_size(old_stack_size)


#######################################################################
# Placeholder
# -----------
Expand Down Expand Up @@ -3559,7 +3622,6 @@ def test_forward_spop():
# Main
# ----
if __name__ == '__main__':

# Transforms
test_forward_slice()
test_forward_transpose()
Expand Down Expand Up @@ -3664,6 +3726,7 @@ def test_forward_spop():
test_forward_inception_v1()
test_forward_mobilenet()
test_forward_resnetv2()
test_forward_ssd()
test_forward_placeholder()
test_forward_ptb()

Expand Down
1 change: 0 additions & 1 deletion tests/python/relay/test_adt.py
Original file line number Diff line number Diff line change
Expand Up @@ -1336,7 +1336,6 @@ def run(dtype, shape):
p = Prelude(mod)
static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape)
static_tensor_array_ops.register()
static_tensor_array_ops.define_tensor_get_data(shape)

np_data_list = []
ta_length = 3
Expand Down

0 comments on commit ae119f8

Please sign in to comment.