From e00c3d9f6f4b6cac6b1688ef378d06e91c2b0c97 Mon Sep 17 00:00:00 2001 From: Unknown Date: Mon, 5 Aug 2019 06:58:23 -0700 Subject: [PATCH 1/7] Support BatchMatMul with shapes greater than length 3 --- python/tvm/relay/frontend/tensorflow.py | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 756022ba663e..908071d8adf9 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -448,11 +448,29 @@ def _impl(inputs, attr, params): def _batch_matmul(): def _impl(inputs, attr, params): + input_x = inputs[0] + input_y = inputs[1] + orig_shape_x = attr['_input_shapes'][inputs[0]] + + # reshape n-dimensional batch matmul into 3d + if len(orig_shape_x > 3): + outer_dims = [input_x.shape[i] for i in range(0, len(orig_shape_x) - 2)] + num_outer_elts = reduce((lambda x, y: x * y), outer_dims) + new_shape = (num_outer_elts, orig_shape_x[:-2], orig_shape_x[:-1]) + input_x = _op.reshape(input_x, newshape=new_shape) + input_x = _op.reshape(input_y, newshape=new_shape) + adj_x = attr['adj_x'] adj_y = attr['adj_y'] - input_x = _op.transpose(inputs[0], axes=[0, 2, 1]) if adj_x else inputs[0] - input_y = _op.transpose(inputs[1], axes=[0, 2, 1]) if not adj_y else inputs[1] - ret = get_relay_op('batch_matmul')(input_x, input_y) + input_x = _op.transpose(input_x, axes=[0, 2, 1]) if adj_x else input_x + input_y = _op.transpose(input_y, axes=[0, 2, 1]) if not adj_y else input_y + ret = _get_relay_op('batch_matmul')(input_x, input_y) + + # reshape result back to n-dimensional + if len(orig_shape_x > 3): + final_shape = attr['_output_shapes'][0] + ret = _op.reshape(ret, final_shape) + return ret return _impl From 4684c291e3e765ef3e0d0454a2ba5af7f6bc3a89 Mon Sep 17 00:00:00 2001 From: Unknown Date: Tue, 6 Aug 2019 22:37:25 -0700 Subject: [PATCH 2/7] Fixes --- python/tvm/relay/frontend/tensorflow.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 908071d8adf9..51ce4b3bc608 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -21,6 +21,7 @@ import warnings from collections import defaultdict +from functools import reduce # Numpy support import numpy as np @@ -451,14 +452,16 @@ def _impl(inputs, attr, params): input_x = inputs[0] input_y = inputs[1] orig_shape_x = attr['_input_shapes'][inputs[0]] + orig_shape_y = attr['_input_shapes'][inputs[1]] # reshape n-dimensional batch matmul into 3d - if len(orig_shape_x > 3): - outer_dims = [input_x.shape[i] for i in range(0, len(orig_shape_x) - 2)] + if len(orig_shape_x) > 3: + outer_dims = [orig_shape_x[i] for i in range(0, len(orig_shape_x) - 2)] num_outer_elts = reduce((lambda x, y: x * y), outer_dims) - new_shape = (num_outer_elts, orig_shape_x[:-2], orig_shape_x[:-1]) - input_x = _op.reshape(input_x, newshape=new_shape) - input_x = _op.reshape(input_y, newshape=new_shape) + new_shape_x = (num_outer_elts, orig_shape_x[-2], orig_shape_x[-1]) + new_shape_y = (num_outer_elts, orig_shape_y[-2], orig_shape_y[-1]) + input_x = _op.reshape(input_x, newshape=new_shape_x) + input_y = _op.reshape(input_y, newshape=new_shape_y) adj_x = attr['adj_x'] adj_y = attr['adj_y'] @@ -467,7 +470,7 @@ def _impl(inputs, attr, params): ret = _get_relay_op('batch_matmul')(input_x, input_y) # reshape result back to n-dimensional - if len(orig_shape_x > 3): + if len(orig_shape_x) > 3: final_shape = attr['_output_shapes'][0] ret = _op.reshape(ret, final_shape) From 630af552a4b77e0668efbff68f45061690cc5b14 Mon Sep 17 00:00:00 2001 From: Unknown Date: Wed, 7 Aug 2019 22:34:25 -0700 Subject: [PATCH 3/7] Add tests --- python/tvm/relay/frontend/tensorflow.py | 4 ++-- tests/python/frontend/tensorflow/test_forward.py | 4 ++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 51ce4b3bc608..85eb0fd77dcc 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -451,8 +451,8 @@ def _batch_matmul(): def _impl(inputs, attr, params): input_x = inputs[0] input_y = inputs[1] - orig_shape_x = attr['_input_shapes'][inputs[0]] - orig_shape_y = attr['_input_shapes'][inputs[1]] + orig_shape_x = attr['_input_shapes'][input_x] + orig_shape_y = attr['_input_shapes'][input_y] # reshape n-dimensional batch matmul into 3d if len(orig_shape_x) > 3: diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 576e3d9f71df..161326964e85 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -669,6 +669,10 @@ def test_forward_batch_matmul(): _test_batch_matmul((3, 5, 4), (3, 4, 5), 'float32', True, True) _test_batch_matmul((3, 5, 4), (3, 5, 4), 'int32', True, False) _test_batch_matmul((3, 5, 4), (3, 5, 4), 'float32', False, True) + _test_batch_matmul((2, 3, 4, 5, 6), (2, 3, 4, 6, 5), 'int32') + _test_batch_matmul((2, 3, 4, 5, 6), (2, 3, 4, 6, 5), 'float32', True, True) + _test_batch_matmul((2, 3, 4, 5, 6), (2, 3, 4, 5, 6), 'int32', True, False) + _test_batch_matmul((2, 3, 4, 5, 6), (2, 3, 4, 5, 6), 'float32', False, True) ####################################################################### From d0a4cd4fe73ef26eeec6b7b3c99e797528c6a742 Mon Sep 17 00:00:00 2001 From: Unknown Date: Wed, 7 Aug 2019 22:39:48 -0700 Subject: [PATCH 4/7] Remove dependency on Python3 --- python/tvm/relay/frontend/tensorflow.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 85eb0fd77dcc..c3ea1a662735 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -21,7 +21,6 @@ import warnings from collections import defaultdict -from functools import reduce # Numpy support import numpy as np @@ -457,7 +456,10 @@ def _impl(inputs, attr, params): # reshape n-dimensional batch matmul into 3d if len(orig_shape_x) > 3: outer_dims = [orig_shape_x[i] for i in range(0, len(orig_shape_x) - 2)] - num_outer_elts = reduce((lambda x, y: x * y), outer_dims) + num_outer_elts = 1 + for outer_dim in outer_dims: + num_outer_elts *= outer_dim + new_shape_x = (num_outer_elts, orig_shape_x[-2], orig_shape_x[-1]) new_shape_y = (num_outer_elts, orig_shape_y[-2], orig_shape_y[-1]) input_x = _op.reshape(input_x, newshape=new_shape_x) From e122e219b562b396546c31a060de852b7f5ad248 Mon Sep 17 00:00:00 2001 From: Unknown Date: Wed, 7 Aug 2019 22:40:45 -0700 Subject: [PATCH 5/7] Clean up --- python/tvm/relay/frontend/tensorflow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index c3ea1a662735..64c1dde0a631 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -474,7 +474,7 @@ def _impl(inputs, attr, params): # reshape result back to n-dimensional if len(orig_shape_x) > 3: final_shape = attr['_output_shapes'][0] - ret = _op.reshape(ret, final_shape) + ret = _op.reshape(ret, newshape=final_shape) return ret return _impl From 89a9a7b62215f46f88f4a6c340c44d325098f17b Mon Sep 17 00:00:00 2001 From: Unknown Date: Wed, 7 Aug 2019 22:55:24 -0700 Subject: [PATCH 6/7] Merge with master --- python/tvm/relay/frontend/tensorflow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 64c1dde0a631..afc8243e0de4 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -469,7 +469,7 @@ def _impl(inputs, attr, params): adj_y = attr['adj_y'] input_x = _op.transpose(input_x, axes=[0, 2, 1]) if adj_x else input_x input_y = _op.transpose(input_y, axes=[0, 2, 1]) if not adj_y else input_y - ret = _get_relay_op('batch_matmul')(input_x, input_y) + ret = get_relay_op('batch_matmul')(input_x, input_y) # reshape result back to n-dimensional if len(orig_shape_x) > 3: From 5e845007ccb237a7d919d646aee24aea3ce08c08 Mon Sep 17 00:00:00 2001 From: Unknown Date: Fri, 9 Aug 2019 21:25:02 -0700 Subject: [PATCH 7/7] Resolve comments --- python/tvm/relay/frontend/tensorflow.py | 5 +---- tests/python/frontend/tensorflow/test_forward.py | 6 +++--- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index afc8243e0de4..48f000e6e101 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -456,10 +456,7 @@ def _impl(inputs, attr, params): # reshape n-dimensional batch matmul into 3d if len(orig_shape_x) > 3: outer_dims = [orig_shape_x[i] for i in range(0, len(orig_shape_x) - 2)] - num_outer_elts = 1 - for outer_dim in outer_dims: - num_outer_elts *= outer_dim - + num_outer_elts = np.prod(outer_dims) new_shape_x = (num_outer_elts, orig_shape_x[-2], orig_shape_x[-1]) new_shape_y = (num_outer_elts, orig_shape_y[-2], orig_shape_y[-1]) input_x = _op.reshape(input_x, newshape=new_shape_x) diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 161326964e85..6c309cdf7292 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -670,9 +670,9 @@ def test_forward_batch_matmul(): _test_batch_matmul((3, 5, 4), (3, 5, 4), 'int32', True, False) _test_batch_matmul((3, 5, 4), (3, 5, 4), 'float32', False, True) _test_batch_matmul((2, 3, 4, 5, 6), (2, 3, 4, 6, 5), 'int32') - _test_batch_matmul((2, 3, 4, 5, 6), (2, 3, 4, 6, 5), 'float32', True, True) - _test_batch_matmul((2, 3, 4, 5, 6), (2, 3, 4, 5, 6), 'int32', True, False) - _test_batch_matmul((2, 3, 4, 5, 6), (2, 3, 4, 5, 6), 'float32', False, True) + _test_batch_matmul((1, 2, 3, 4, 5, 6), (1, 2, 3, 4, 6, 5), 'float32', True, True) + _test_batch_matmul((3, 4, 5, 6), (3, 4, 5, 6), 'int32', True, False) + _test_batch_matmul((2, 3, 4, 2, 3, 4, 5, 6), (2, 3, 4, 2, 3, 4, 5, 6), 'float32', False, True) #######################################################################