From 083ba340a4c513a7ac58647d22341accc07323d1 Mon Sep 17 00:00:00 2001
From: DarrenZhang01 <18633059886@163.com>
Date: Thu, 27 Aug 2020 18:09:47 -0400
Subject: [PATCH 1/8] Resolve the conflicts.

---
 trax/tf_numpy/extensions/extensions.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/trax/tf_numpy/extensions/extensions.py b/trax/tf_numpy/extensions/extensions.py
index ed154c2b2..fb6422759 100644
--- a/trax/tf_numpy/extensions/extensions.py
+++ b/trax/tf_numpy/extensions/extensions.py
@@ -25,6 +25,7 @@ import threading
 import numpy as np
 import six
+from more_itertools import sort_together
 import tensorflow.compat.v2 as tf
 

From acd76cc1e57bbbd36fad55cd134e8acc4d6e2637 Mon Sep 17 00:00:00 2001
From: DarrenZhang01 <18633059886@163.com>
Date: Thu, 27 Aug 2020 18:21:40 -0400
Subject: [PATCH 2/8] Add the latest general convolution operation to Trax extensions.

---
 trax/tf_numpy/extensions/extensions.py      | 87 +++++++++++++++++++++
 trax/tf_numpy/extensions/extensions_test.py | 48 ++++++++++++
 2 files changed, 135 insertions(+)

diff --git a/trax/tf_numpy/extensions/extensions.py b/trax/tf_numpy/extensions/extensions.py
index fb6422759..153ccc689 100644
--- a/trax/tf_numpy/extensions/extensions.py
+++ b/trax/tf_numpy/extensions/extensions.py
@@ -565,6 +565,93 @@ def tf_dot_general(lhs, rhs, dimension_numbers):
   return tf.einsum(equation, lhs, rhs)
 
 
+# TODO (Zhibo Zhang): Run pylint and complement the docstring.
+def _conv_general_param_type_converter(window_strides, lhs_dilation, rhs_dilation):
+  """ Convert the inputs strides, lhs_dilation, rhs_dilation to the standard
+  TF conv inputs.
+  For example,
+  in the 3D case, if lhs_dilation = 2, then convert it to [2, 2, 2]
+  if lhs_dilation = (2, 2, 2), convert it also to [2, 2, 2]
+  """
+  strides = [window_strides] * dim if isinstance(window_strides, int) else \
+      list(window_strides)
+  if lhs_dilation:
+    lhs_dilation = [lhs_dilation] * dim if isinstance(lhs_dilation, int) else \
+        list(lhs_dilation)
+  if rhs_dilation:
+    rhs_dilation = [rhs_dilation] * dim if isinstance(rhs_dilation, int) else \
+        list(rhs_dilation)
+  return (strides, lhs_dilation, rhs_dilation)
+
+
+# TODO (Zhibo Zhang): Run pylint and complement the docstring.
+# TODO (Zhibo Zhang): Expand the test cases of general convolution and revise
+# the corresponding bugs.
+# TODO (Zhibo Zhang): Support feature_group_count, batch_group_count and precision, and
+# allow lhs_dilation and rhs_dilation to happen at the same time.
+def conv_general_dilated(lhs, rhs, window_strides, padding, output_shape,
+                         lhs_dilation=None, rhs_dilation=None,
+                         dimension_numbers=None, feature_group_count=1,
+                         batch_group_count=1, precision=None):
+  """ A general conv API that integrates normal conv, deconvolution,
+  dilated convolution, etc."""
+  dim = None
+  lhs_spec, rhs_spec, out_spec = dimension_numbers
+  if lhs_spec != out_spec:
+    raise TypeError("Current implementation requires the `data_format` of the "
+                    "inputs and outputs to be the same.")
+  if len(lhs_spec) >= 6:
+    raise TypeError("Current implementation does not support 4 or higher "
+                    "dimensional convolution, but got: ", len(lhs_spec) - 2)
+  dim = len(lhs_spec) - 2
+  if lhs_dilation and rhs_dilation:
+    if lhs_dilation == (1,) * dim and rhs_dilation == (1,) * dim:
+      lhs_dilation, rhs_dilation = None, None
+    else:
+      raise TypeError("Current implementation does not support deconvolution "
+                      "and dilation performed at the same time, but got"
+                      " lhs_dilation: {}, rhs_dilation: {}".format(lhs_dilation,
+                                                                   rhs_dilation))
+  if padding not in ["SAME", "VALID"]:
+    raise TypeError("Current implementation requires the padding parameter "
+                    "to be either 'VALID' or 'SAME', but got: ", padding)
+  # Convert params from int/Sequence[int] to list of ints.
+  strides, lhs_dilation, rhs_dilation = _conv_general_param_type_converter(
+      window_strides, lhs_dilation, rhs_dilation
+  )
+  # Preprocess the shapes
+  dim_maps = {}
+  if isinstance(lhs_spec, str):
+    dim_maps['I'] = list(rhs_spec).index('I')
+    dim_maps['O'] = list(rhs_spec).index('O')
+    dim_maps['N'] = list(lhs_spec).index('N')
+    dim_maps['C'] = list(lhs_spec).index('C')
+  else:
+    dim_maps['I'] = rhs_spec[1]
+    dim_maps['O'] = rhs_spec[0]
+    dim_maps['N'] = lhs_spec[0]
+    dim_maps['C'] = lhs_spec[1]
+
+  lhs = np.moveaxis(lhs, (dim_maps['N'], dim_maps['C']), (0, dim + 1))
+  # Adjust the filters, put the dimension 'I' and 'O' at last.
+  rhs = np.moveaxis(rhs, (dim_maps['O'], dim_maps['I']), (dim + 1, dim))
+  spatial_dim_maps = {1: 'W', 2: "HW", 3: "DHW"}
+  data_format = 'N' + spatial_dim_maps[dim] + 'C'
+  tf_nn_APIs = {1: [nn.conv1d, nn.conv1d_transpose],
+                2: [nn.conv2d, nn.conv2d_transpose],
+                3: [nn.conv3d, nn.conv3d_transpose]}
+
+  output = None
+  if rhs_dilation or (lhs_dilation is None and rhs_dilation is None):
+    output = tf_nn_APIs[dim][0](lhs, rhs, strides, padding, data_format,
+                                rhs_dilation)
+  else:
+    output = tf_nn_APIs[dim][1](lhs, rhs, tf.constant(output_shape), strides,
+                                padding, data_format, lhs_dilation)
+  output = np.moveaxis(output, (0, dim + 1), (dim_maps['N'], dim_maps['C']))
+  return np.asarray(output)
+
+
 def conv(inp,
          fltr,
          window_strides,
diff --git a/trax/tf_numpy/extensions/extensions_test.py b/trax/tf_numpy/extensions/extensions_test.py
index 05c416466..0c98f5b88 100644
--- a/trax/tf_numpy/extensions/extensions_test.py
+++ b/trax/tf_numpy/extensions/extensions_test.py
@@ -24,6 +24,7 @@
 from absl.testing import parameterized
 
 from jax import lax
+import jax.numpy as jnp
 import numpy as np
 import tensorflow.compat.v2 as tf
 
@@ -421,6 +422,53 @@ def test_tf_dot_general(self, lhs_np, rhs_np, dims):
     result = extensions.tf_dot_general(lhs_np, rhs_np, dims)
     self.assertAllClose(result, np.array(ans))
 
+
+  # TODO (Zhibo Zhang): Run pylint on this function.
+  @parameterized.named_parameters([
+      ("_lhs_shape={}_rhs_shape={}_strides={}_padding={}"
+       "_lhs_dilation={}_rhs_dilation={}"
+       "_feature_group_count={}_batch_group_count={}_dims={}"
+       "_perms={}".format(lhs_shape, rhs_shape,
+           strides, padding, lhs_dilation, rhs_dilation,
+           feature_group_count, batch_group_count, ",".join(dimension_numbers), perms),
+       lhs_shape, rhs_shape, strides, padding, lhs_dilation, rhs_dilation,
+       feature_group_count, batch_group_count, dimension_numbers, perms)
+      for batch_group_count, feature_group_count in [(1, 1)]
+      for lhs_shape, rhs_shape in [
+          ((b * batch_group_count, i * feature_group_count, 9, w),
+           (j * feature_group_count * batch_group_count, i, 4, 5))
+          for w in [0, 10]
+          for b, i, j in itertools.product([2, 3], repeat=3)]
+      for strides in [(1, 1), (2, 1)]
+      for padding in ['SAME']
+      for lhs_dilation, rhs_dilation in [
+          (None, (1, 1))
+      ]
+      for dimension_numbers, perms in [
+          (("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0]))
+      ]])
+  def testConvGeneralDilated(self, lhs_shape, rhs_shape, strides,
+                             padding, lhs_dilation, rhs_dilation,
+                             feature_group_count, batch_group_count,
+                             dimension_numbers, perms):
+    tf.print("dimension_numbers: {}".format(dimension_numbers), output_stream=sys.stdout)
+    lhs_perm, rhs_perm = perms  # permute to compatible shapes
+
+    lhs_tf = tf_np.transpose(tf_np.ones(lhs_shape), lhs_perm)
+    rhs_tf = tf_np.transpose(tf_np.ones(rhs_shape), rhs_perm)
+
+    lhs_jax = jnp.transpose(jnp.ones(lhs_shape), lhs_perm)
+    rhs_jax = jnp.transpose(jnp.ones(rhs_shape), rhs_perm)
+
+    jax_conv = jax.lax.conv_general_dilated(lhs_jax, rhs_jax, strides, padding, lhs_dilation,
+        rhs_dilation, dimension_numbers, feature_group_count, batch_group_count)
+
+    tf_conv = lax.conv_general_dilated(lhs_tf, rhs_tf, strides, padding, jax_conv.shape, lhs_dilation,
+        rhs_dilation, dimension_numbers, feature_group_count, batch_group_count)
+
+    self.assertAllEqual(tf_conv, tf_np.asarray(jax_conv))
+
+
   def testConv(self):
     y = extensions.conv(
         np.ones([5, 320, 480, 3], dtype=np.float32),

From 8c5a7b948d54e9d61d7dd7ffc033f60b9bd56560 Mon Sep 17 00:00:00 2001
From: DarrenZhang01 <18633059886@163.com>
Date: Thu, 27 Aug 2020 18:48:33 -0400
Subject: [PATCH 3/8] Add the helper function `_eval_output_shape`.

---
 trax/tf_numpy/extensions/extensions.py | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/trax/tf_numpy/extensions/extensions.py b/trax/tf_numpy/extensions/extensions.py
index 153ccc689..42a4552ca 100644
--- a/trax/tf_numpy/extensions/extensions.py
+++ b/trax/tf_numpy/extensions/extensions.py
@@ -565,6 +565,20 @@ def tf_dot_general(lhs, rhs, dimension_numbers):
   return tf.einsum(equation, lhs, rhs)
 
 
+# TODO (Zhibo Zhang): Run pylint and complement the docstring.
+def _eval_output_shape(lhs_shape, rhs_shape, padding, window_strides):
+  """ Evaluate the output shape for transpose convolutions.
+  """
+  output_shape = [lhs_shape[0]]
+  for i in range(1, len(lhs_shape) - 1):
+    if padding == "SAME":
+      output_shape.append((lhs_shape[i] - 1) * window_strides[i-1] + rhs_shape[i])
+    if padding == "VALID":
+      output_shape.append((lhs_shape[i] - 1) * window_strides[i-1])
+  output_shape.append(lhs_shape[-1])
+  return tf.constant(output_shape)
+
+
+# TODO (Zhibo Zhang): Run pylint and complement the docstring.
 def _conv_general_param_type_converter(window_strides, lhs_dilation, rhs_dilation):
   """ Convert the inputs strides, lhs_dilation, rhs_dilation to the standard

From e169591adb24d506b26bfe94e1c02a959bf1e49d Mon Sep 17 00:00:00 2001
From: DarrenZhang01 <18633059886@163.com>
Date: Thu, 27 Aug 2020 19:58:02 -0400
Subject: [PATCH 4/8] Revise some format problems according to pylint.

---
 trax/tf_numpy/extensions/extensions.py | 35 ++++++++++++++------------
 1 file changed, 19 insertions(+), 16 deletions(-)

diff --git a/trax/tf_numpy/extensions/extensions.py b/trax/tf_numpy/extensions/extensions.py
index 42a4552ca..f0aa9a7a1 100644
--- a/trax/tf_numpy/extensions/extensions.py
+++ b/trax/tf_numpy/extensions/extensions.py
@@ -28,6 +28,7 @@
 from more_itertools import sort_together
 
 import tensorflow.compat.v2 as tf
+from tensorflow import nn
 
 import trax.tf_numpy.numpy as tf_np
 
@@ -565,22 +566,24 @@ def tf_dot_general(lhs, rhs, dimension_numbers):
   return tf.einsum(equation, lhs, rhs)
 
 
-# TODO (Zhibo Zhang): Run pylint and complement the docstring.
+# TODO (DarrenZhang01): Complement the docstring.
 def _eval_output_shape(lhs_shape, rhs_shape, padding, window_strides):
   """ Evaluate the output shape for transpose convolutions.
   """
   output_shape = [lhs_shape[0]]
   for i in range(1, len(lhs_shape) - 1):
     if padding == "SAME":
-      output_shape.append((lhs_shape[i] - 1) * window_strides[i-1] + rhs_shape[i])
+      output_shape.append((lhs_shape[i] - 1) * window_strides[i-1] +
+                          rhs_shape[i])
     if padding == "VALID":
       output_shape.append((lhs_shape[i] - 1) * window_strides[i-1])
   output_shape.append(lhs_shape[-1])
   return tf.constant(output_shape)
 
 
-# TODO (Zhibo Zhang): Run pylint and complement the docstring.
-def _conv_general_param_type_converter(window_strides, lhs_dilation, rhs_dilation):
+# TODO (DarrenZhang01): Complement the docstring.
+def _conv_general_param_type_converter(window_strides, lhs_dilation,
+                                       rhs_dilation):
   """ Convert the inputs strides, lhs_dilation, rhs_dilation to the standard
   TF conv inputs.
   For example,
@@ -598,11 +601,11 @@ def _conv_general_param_type_converter(window_strides, lhs_dilatio
   return (strides, lhs_dilation, rhs_dilation)
 
 
-# TODO (Zhibo Zhang): Run pylint and complement the docstring.
-# TODO (Zhibo Zhang): Expand the test cases of general convolution and revise
+# TODO (DarrenZhang01): Expand the test cases of general convolution and revise
 # the corresponding bugs.
-# TODO (Zhibo Zhang): Support feature_group_count, batch_group_count and precision, and
-# allow lhs_dilation and rhs_dilation to happen at the same time.
+# TODO (DarrenZhang01): Support feature_group_count, batch_group_count and
+# precision, and allow lhs_dilation and rhs_dilation to happen at the
+# same time.
 def conv_general_dilated(lhs, rhs, window_strides, padding, output_shape,
                          lhs_dilation=None, rhs_dilation=None,
                          dimension_numbers=None, feature_group_count=1,
                          batch_group_count=1, precision=None):
   """ A general conv API that integrates normal conv, deconvolution,
   dilated convolution, etc."""
@@ -612,26 +615,26 @@ def conv_general_dilated(lhs, rhs, window_strides, padding, output_shape,
   dim = None
   lhs_spec, rhs_spec, out_spec = dimension_numbers
   if lhs_spec != out_spec:
-    raise TypeError("Current implementation requires the `data_format` of the "
+    raise ValueError("Current implementation requires the `data_format` of the "
                     "inputs and outputs to be the same.")
   if len(lhs_spec) >= 6:
-    raise TypeError("Current implementation does not support 4 or higher "
+    raise ValueError("Current implementation does not support 4 or higher "
                     "dimensional convolution, but got: ", len(lhs_spec) - 2)
   dim = len(lhs_spec) - 2
   if lhs_dilation and rhs_dilation:
     if lhs_dilation == (1,) * dim and rhs_dilation == (1,) * dim:
       lhs_dilation, rhs_dilation = None, None
     else:
-      raise TypeError("Current implementation does not support deconvolution "
-                      "and dilation performed at the same time, but got"
-                      " lhs_dilation: {}, rhs_dilation: {}".format(lhs_dilation,
-                                                                   rhs_dilation))
+      raise ValueError("Current implementation does not support "
+                       "deconvolution and dilation performed at the same "
+                       "time, but got lhs_dilation: {}, rhs_dilation: {}".format(
+                           lhs_dilation, rhs_dilation))
   if padding not in ["SAME", "VALID"]:
-    raise TypeError("Current implementation requires the padding parameter "
+    raise ValueError("Current implementation requires the padding parameter "
                     "to be either 'VALID' or 'SAME', but got: ", padding)
   # Convert params from int/Sequence[int] to list of ints.
   strides, lhs_dilation, rhs_dilation = _conv_general_param_type_converter(
-      window_strides, lhs_dilation, rhs_dilation
+    window_strides, lhs_dilation, rhs_dilation
   )
   # Preprocess the shapes
   dim_maps = {}

From ea29343f9af70d7d20f0a1ab3cf56e40d993a074 Mon Sep 17 00:00:00 2001
From: DarrenZhang01 <18633059886@163.com>
Date: Thu, 27 Aug 2020 20:34:15 -0400
Subject: [PATCH 5/8] Define inner functions as pointed out in the code review.

---
 trax/tf_numpy/extensions/extensions.py | 38 ++++++++++++--------------
 1 file changed, 18 insertions(+), 20 deletions(-)

diff --git a/trax/tf_numpy/extensions/extensions.py b/trax/tf_numpy/extensions/extensions.py
index f0aa9a7a1..b205d322c 100644
--- a/trax/tf_numpy/extensions/extensions.py
+++ b/trax/tf_numpy/extensions/extensions.py
@@ -39,6 +39,9 @@
 tf.int64, tf.int32, tf.int16, tf.int8, tf.uint8, tf.uint16,
 tf.uint32, tf.uint64
 ]
+_tf_nn_APIs = {1: [nn.conv1d, nn.conv1d_transpose],
+               2: [nn.conv2d, nn.conv2d_transpose],
+               3: [nn.conv3d, nn.conv3d_transpose]}
 
 
 def most_precise_int_dtype(x):
@@ -583,22 +586,20 @@ def _eval_output_shape(lhs_shape, rhs_shape, padding, window_strides):
 
 # TODO (DarrenZhang01): Complement the docstring.
 def _conv_general_param_type_converter(window_strides, lhs_dilation,
-                                       rhs_dilation):
+                                       rhs_dilation, dim):
   """ Convert the inputs strides, lhs_dilation, rhs_dilation to the standard
   TF conv inputs.
For example, in the 3D case, if lhs_dilation = 2, then convert it to [2, 2, 2] if lhs_dilation = (2, 2, 2), convert it also to [2, 2, 2] """ - strides = [window_strides] * dim if isinstance(window_strides, int) else \ - list(window_strides) - if lhs_dilation: - lhs_dilation = [lhs_dilation] * dim if isinstance(lhs_dilation, int) else \ - list(lhs_dilation) - if rhs_dilation: - rhs_dilation = [rhs_dilation] * dim if isinstance(rhs_dilation, int) else \ - list(rhs_dilation) - return (strides, lhs_dilation, rhs_dilation) + def _as_list_of_size(item, size): + if item is None: + return None + return [item] * size if isinstance(item, int) else list(item) + return (_as_list_of_size(window_strides, dim), + _as_list_of_size(lhs_dilation, dim), + _as_list_of_size(rhs_dilation, dim)) # TODO (DarrenZhang01): Expand the test cases of general convolution and revise @@ -634,7 +635,7 @@ def conv_general_dilated(lhs, rhs, window_strides, padding, output_shape, "to be either 'VALID' or 'SAME', but got: ", padding) # Convert params from int/Sequence[int] to list of ints. strides, lhs_dilation, rhs_dilation = _conv_general_param_type_converter( - window_strides, lhs_dilation, rhs_dilation + window_strides, lhs_dilation, rhs_dilation, dim ) # Preprocess the shapes dim_maps = {} @@ -649,24 +650,21 @@ def conv_general_dilated(lhs, rhs, window_strides, padding, output_shape, dim_maps['N'] = lhs_spec[0] dim_maps['C'] = lhs_spec[1] - lhs = np.moveaxis(lhs, (dim_maps['N'], dim_maps['C']), (0, dim + 1)) + lhs = tf_np.moveaxis(lhs, (dim_maps['N'], dim_maps['C']), (0, dim + 1)) # Adjust the filters, put the dimension 'I' and 'O' at last. - rhs = np.moveaxis(rhs, (dim_maps['O'], dim_maps['I']), (dim + 1, dim)) + rhs = tf_np.moveaxis(rhs, (dim_maps['O'], dim_maps['I']), (dim + 1, dim)) spatial_dim_maps = {1: 'W', 2: "HW", 3: "DHW"} data_format = 'N' + spatial_dim_maps[dim] + 'C' - tf_nn_APIs = {1: [nn.conv1d, nn.conv1d_transpose], - 2: [nn.conv2d, nn.conv2d_transpose], - 3: [nn.conv3d, nn.conv3d_transpose]} output = None if rhs_dilation or (lhs_dilation is None and rhs_dilation is None): - output = tf_nn_APIs[dim][0](lhs, rhs, strides, padding, data_format, + output = _tf_nn_APIs[dim][0](lhs, rhs, strides, padding, data_format, rhs_dilation) else: - output = tf_nn_APIs[dim][1](lhs, rhs, tf.constant(output_shape), strides, + output = _tf_nn_APIs[dim][1](lhs, rhs, tf.constant(output_shape), strides, padding, data_format, lhs_dilation) - output = np.moveaxis(output, (0, dim + 1), (dim_maps['N'], dim_maps['C'])) - return np.asarray(output) + output = tf_np.moveaxis(output, (0, dim + 1), (dim_maps['N'], dim_maps['C'])) + return tf_np.asarray(output) def conv(inp, From a13a238041d5fa760b13c9c0c72f9c9db9c0eeb5 Mon Sep 17 00:00:00 2001 From: DarrenZhang01 <18633059886@163.com> Date: Thu, 27 Aug 2020 21:02:17 -0400 Subject: [PATCH 6/8] Revise the rest of the issues according to the code review. 
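
To make the preprocessing easier to review: the function reduces string
`dimension_numbers` to plain axis indices before the `moveaxis` calls. A
small standalone sketch of that step (illustrative only, not part of this
diff; the ("NHWC", "HWIO", "NHWC") layout is the one used in the tests):

    lhs_spec, rhs_spec, out_spec = ("NHWC", "HWIO", "NHWC")
    dim_maps = {}
    dim_maps['I'] = list(rhs_spec).index('I')  # 2: kernel input-channel axis
    dim_maps['O'] = list(rhs_spec).index('O')  # 3: kernel output-channel axis
    dim_maps['N'] = list(lhs_spec).index('N')  # 0: batch axis of the input
    dim_maps['C'] = list(lhs_spec).index('C')  # 3: channel axis of the input
    dim = len(lhs_spec) - 2                    # 2 spatial dims -> "NHWC" format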
---
 trax/tf_numpy/extensions/extensions.py      | 74 ++++++++++++---------
 trax/tf_numpy/extensions/extensions_test.py | 45 +++++++------
 2 files changed, 66 insertions(+), 53 deletions(-)

diff --git a/trax/tf_numpy/extensions/extensions.py b/trax/tf_numpy/extensions/extensions.py
index b205d322c..c615e50d1 100644
--- a/trax/tf_numpy/extensions/extensions.py
+++ b/trax/tf_numpy/extensions/extensions.py
@@ -25,7 +25,6 @@ import threading
 import numpy as np
 import six
-from more_itertools import sort_together
 
 import tensorflow.compat.v2 as tf
 from tensorflow import nn
@@ -569,22 +568,6 @@ def tf_dot_general(lhs, rhs, dimension_numbers):
   return tf.einsum(equation, lhs, rhs)
 
 
-# TODO (DarrenZhang01): Complement the docstring.
-def _eval_output_shape(lhs_shape, rhs_shape, padding, window_strides):
-  """ Evaluate the output shape for transpose convolutions.
-  """
-  output_shape = [lhs_shape[0]]
-  for i in range(1, len(lhs_shape) - 1):
-    if padding == "SAME":
-      output_shape.append((lhs_shape[i] - 1) * window_strides[i-1] +
-                          rhs_shape[i])
-    if padding == "VALID":
-      output_shape.append((lhs_shape[i] - 1) * window_strides[i-1])
-  output_shape.append(lhs_shape[-1])
-  return tf.constant(output_shape)
-
-
-# TODO (DarrenZhang01): Complement the docstring.
 def _conv_general_param_type_converter(window_strides, lhs_dilation,
                                        rhs_dilation, dim):
   """ Convert the inputs strides, lhs_dilation, rhs_dilation to the standard
@@ -607,32 +590,58 @@ def _as_list_of_size(item, size):
 # TODO (DarrenZhang01): Support feature_group_count, batch_group_count and
 # precision, and allow lhs_dilation and rhs_dilation to happen at the
 # same time.
-def conv_general_dilated(lhs, rhs, window_strides, padding, output_shape,
-                         lhs_dilation=None, rhs_dilation=None,
-                         dimension_numbers=None, feature_group_count=1,
-                         batch_group_count=1, precision=None):
-  """ A general conv API that integrates normal conv, deconvolution,
-  dilated convolution, etc."""
+def tf_conv_general_dilated(lhs, rhs, window_strides, padding, output_shape,
+                            lhs_dilation=None, rhs_dilation=None,
+                            dimension_numbers=None, feature_group_count=1,
+                            batch_group_count=1, precision=None):
+  """ A general conv API for TensorFlow.
+
+  Corresponding JAX version:
+  https://jax.readthedocs.io/en/stable/_autosummary/jax.lax.conv_general_dilated.html
+
+  Args: (Use JAX documentation as a reference)
+    lhs: a rank n+2 dimensional input array.
+    rhs: a rank n+2 dimensional array of kernel weights.
+    window_strides: a sequence of n integers, representing the inter-window
+      strides.
+    padding: either the string ‘SAME’, the string ‘VALID’, or a sequence of n
+      (low, high) integer pairs that give the padding to apply before and
+      after each spatial dimension.
+    output_shape: the output shape of the convolution.
+    lhs_dilation: None, or a sequence of n integers, giving the dilation factor
+      to apply in each spatial dimension of lhs. LHS dilation is
+      also known as transposed convolution.
+    rhs_dilation: None, or a sequence of n integers, giving the dilation factor
+      to apply in each spatial dimension of rhs. RHS dilation is
+      also known as atrous convolution.
+    dimension_numbers: either None, a ConvDimensionNumbers object, or a 3-tuple
+      (lhs_spec, rhs_spec, out_spec), where each element is a
+      string of length n+2.
+    feature_group_count: integer, default 1.
+    batch_group_count: integer, default 1.
+    precision: Optional. Either None, which means the default precision for the
+      backend, or a Precision enum value.
+ """ dim = None lhs_spec, rhs_spec, out_spec = dimension_numbers if lhs_spec != out_spec: raise ValueError("Current implementation requires the `data_format` of the " - "inputs and outputs to be the same.") + "inputs and outputs to be the same.") if len(lhs_spec) >= 6: raise ValueError("Current implmentation does not support 4 or higher" - "dimensional convolution, but got: ", len(lhs_spec) - 2) + "dimensional convolution, but got: ", len(lhs_spec) - 2) dim = len(lhs_spec) - 2 if lhs_dilation and rhs_dilation: if lhs_dilation == (1,) * dim and rhs_dilation == (1,) * dim: lhs_dilation, rhs_dilation = None, None else: raise ValueError("Current implementation does not support that " - "deconvolution and dilation to be performed at the same " - "time, but got lhs_dilation: {}, rhs_dilation: {}".format( - lhs_dilation, rhs_dilation)) + "deconvolution and dilation to be performed at the same " + "time, but got lhs_dilation: {}, rhs_dilation: {}" + .format(lhs_dilation, rhs_dilation)) if padding not in ["SAME", "VALID"]: raise ValueError("Current implementation requires the padding parameter" - "to be either 'VALID' or 'SAME', but got: ", padding) + "to be either 'VALID' or 'SAME', but got: ", padding) # Convert params from int/Sequence[int] to list of ints. strides, lhs_dilation, rhs_dilation = _conv_general_param_type_converter( window_strides, lhs_dilation, rhs_dilation, dim @@ -656,15 +665,14 @@ def conv_general_dilated(lhs, rhs, window_strides, padding, output_shape, spatial_dim_maps = {1: 'W', 2: "HW", 3: "DHW"} data_format = 'N' + spatial_dim_maps[dim] + 'C' - output = None if rhs_dilation or (lhs_dilation is None and rhs_dilation is None): output = _tf_nn_APIs[dim][0](lhs, rhs, strides, padding, data_format, - rhs_dilation) + rhs_dilation) else: output = _tf_nn_APIs[dim][1](lhs, rhs, tf.constant(output_shape), strides, - padding, data_format, lhs_dilation) + padding, data_format, lhs_dilation) output = tf_np.moveaxis(output, (0, dim + 1), (dim_maps['N'], dim_maps['C'])) - return tf_np.asarray(output) + return output def conv(inp, diff --git a/trax/tf_numpy/extensions/extensions_test.py b/trax/tf_numpy/extensions/extensions_test.py index 0c98f5b88..ba8a06388 100644 --- a/trax/tf_numpy/extensions/extensions_test.py +++ b/trax/tf_numpy/extensions/extensions_test.py @@ -21,16 +21,18 @@ import functools from absl import flags +import itertools from absl.testing import parameterized +import jax from jax import lax -import jax.numpy as jnp import numpy as np import tensorflow.compat.v2 as tf from trax.tf_numpy import extensions import trax.tf_numpy.numpy as tf_np + FLAGS = flags.FLAGS flags.DEFINE_bool("requires_tpu", False, "Requires TPU.") @@ -423,16 +425,16 @@ def test_tf_dot_general(self, lhs_np, rhs_np, dims): self.assertAllClose(result, np.array(ans)) - # TODO (Zhibo Zhang): Run pylint on this function. 
@parameterized.named_parameters([ ("_lhs_shape={}_rhs_shape={}_strides={}_padding={}" "_lhs_dilation={}_rhs_dilation={}" "_feature_group_count={}_batch_group_count={}_dims={}" "_perms={}".format(lhs_shape, rhs_shape, - strides, padding, lhs_dilation, rhs_dilation, - feature_group_count, batch_group_count, ",".join(dimension_numbers), perms), - lhs_shape, rhs_shape, strides, padding, lhs_dilation, rhs_dilation, - feature_group_count, batch_group_count, dimension_numbers, perms) + strides, padding, lhs_dilation, rhs_dilation, + feature_group_count, batch_group_count, ",".join( + dimension_numbers), perms), + lhs_shape, rhs_shape, strides, padding, lhs_dilation, rhs_dilation, + feature_group_count, batch_group_count, dimension_numbers, perms) for batch_group_count, feature_group_count in [(1, 1)] for lhs_shape, rhs_shape in [ ((b * batch_group_count, i * feature_group_count, 9, w), @@ -442,29 +444,32 @@ def test_tf_dot_general(self, lhs_np, rhs_np, dims): for strides in [(1, 1), (2, 1)] for padding in ['SAME'] for lhs_dilation, rhs_dilation in [ - (None, (1, 1)) + (None, (1, 1)) ] for dimension_numbers, perms in [ - (("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0])) + (("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0])) ]]) def testConvGeneralDilated(self, lhs_shape, rhs_shape, strides, padding, lhs_dilation, rhs_dilation, feature_group_count, batch_group_count, dimension_numbers, perms): - tf.print("dimension_numbers: {}".format(dimension_numbers), output_stream=sys.stdout) lhs_perm, rhs_perm = perms # permute to compatible shapes - lhs_tf = tf_np.transpose(tf_np.ones(lhs_shape), lhs_perm) - rhs_tf = tf_np.transpose(tf_np.ones(rhs_shape), rhs_perm) - - lhs_jax = jnp.transpose(jnp.ones(lhs_shape), lhs_perm) - rhs_jax = jnp.transpose(jnp.ones(rhs_shape), rhs_perm) - - jax_conv = jax.lax.conv_general_dilated(lhs_jax, rhs_jax, strides, padding, lhs_dilation, - rhs_dilation, dimension_numbers, feature_group_count, batch_group_count) - - tf_conv = lax.conv_general_dilated(lhs_tf, rhs_tf, strides, padding, jax_conv.shape, lhs_dilation, - rhs_dilation, dimension_numbers, feature_group_count, batch_group_count) + lhs = np.transpose(np.ones(lhs_shape), lhs_perm) + rhs = np.transpose(np.ones(rhs_shape), rhs_perm) + + jax_conv = jax.lax.conv_general_dilated(lhs, rhs, strides, padding, + lhs_dilation, rhs_dilation, + dimension_numbers, + feature_group_count, + batch_group_count) + + tf_conv = extensions.tf_conv_general_dilated(lhs, rhs, strides, + padding, jax_conv.shape, + lhs_dilation, rhs_dilation, + dimension_numbers, + feature_group_count, + batch_group_count) self.assertAllEqual(tf_conv, tf_np.asarray(jax_conv)) From 5887a7ac29a8420139bc2f09bc5e38848b9b4821 Mon Sep 17 00:00:00 2001 From: DarrenZhang01 <18633059886@163.com> Date: Thu, 27 Aug 2020 21:39:12 -0400 Subject: [PATCH 7/8] Revise some details on the format, etc. 
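
Note that the test now exercises the forward-convolution branch only: with
the test's `rhs_dilation=(1, 1)` set and `lhs_dilation=None`, execution never
reaches the transpose path that reads `output_shape`. As a quick reference, a
standalone sketch (illustrative, not part of this diff) of how the parameter
normalization introduced in PATCH 5/8 treats the values used in the test:

    # Mirrors the inner helper `_as_list_of_size` of
    # `_conv_general_param_type_converter`; `size` is the number of
    # spatial dimensions.
    def _as_list_of_size(item, size):
      if item is None:
        return None
      return [item] * size if isinstance(item, int) else list(item)

    assert _as_list_of_size(2, 3) == [2, 2, 2]    # an int is broadcast
    assert _as_list_of_size((1, 1), 2) == [1, 1]  # a sequence becomes a list
    assert _as_list_of_size(None, 2) is None      # None passes through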
---
 trax/tf_numpy/extensions/extensions.py      | 25 +++++++++++++------
 trax/tf_numpy/extensions/extensions_test.py |  5 ++---
 2 files changed, 20 insertions(+), 10 deletions(-)

diff --git a/trax/tf_numpy/extensions/extensions.py b/trax/tf_numpy/extensions/extensions.py
index c615e50d1..b532d1b73 100644
--- a/trax/tf_numpy/extensions/extensions.py
+++ b/trax/tf_numpy/extensions/extensions.py
@@ -27,7 +27,6 @@
 import six
 
 import tensorflow.compat.v2 as tf
-from tensorflow import nn
 
 import trax.tf_numpy.numpy as tf_np
 
@@ -38,9 +37,9 @@
 tf.int64, tf.int32, tf.int16, tf.int8, tf.uint8, tf.uint16,
 tf.uint32, tf.uint64
 ]
-_tf_nn_APIs = {1: [nn.conv1d, nn.conv1d_transpose],
-               2: [nn.conv2d, nn.conv2d_transpose],
-               3: [nn.conv3d, nn.conv3d_transpose]}
+_tf_nn_APIs = {1: [tf.nn.conv1d, tf.nn.conv1d_transpose],
+               2: [tf.nn.conv2d, tf.nn.conv2d_transpose],
+               3: [tf.nn.conv3d, tf.nn.conv3d_transpose]}
 
 
 def most_precise_int_dtype(x):
@@ -607,7 +606,8 @@ def tf_conv_general_dilated(lhs, rhs, window_strides, padding, output_shape,
     padding: either the string ‘SAME’, the string ‘VALID’, or a sequence of n
       (low, high) integer pairs that give the padding to apply before and
       after each spatial dimension.
-    output_shape: the output shape of the convolution.
+    output_shape: the output shape of the convolution (only required for
+      transpose convolution).
     lhs_dilation: None, or a sequence of n integers, giving the dilation factor
       to apply in each spatial dimension of lhs. LHS dilation is
       also known as transposed convolution.
@@ -617,10 +617,15 @@ def tf_conv_general_dilated(lhs, rhs, window_strides, padding, output_shape,
     dimension_numbers: either None, a ConvDimensionNumbers object, or a 3-tuple
       (lhs_spec, rhs_spec, out_spec), where each element is a
       string of length n+2.
-    feature_group_count: integer, default 1.
-    batch_group_count: integer, default 1.
+    feature_group_count: integer, default 1. Changing this is currently not
+      supported.
+    batch_group_count: integer, default 1. Changing this is currently not
+      supported.
     precision: Optional. Either None, which means the default precision for the
       backend, or a Precision enum value.
+
+  Returns:
+    A TF NumPy array that contains the convolution result.
   """
   dim = None
   lhs_spec, rhs_spec, out_spec = dimension_numbers
@@ -642,6 +647,12 @@ def tf_conv_general_dilated(lhs, rhs, window_strides, padding, output_shape,
   if padding not in ["SAME", "VALID"]:
     raise ValueError("Current implementation requires the padding parameter "
                      "to be either 'VALID' or 'SAME', but got: ", padding)
+  if batch_group_count != 1 or feature_group_count != 1:
+    raise NotImplementedError("batch_group_count and feature_group_count "
+                              "other than 1 are currently not supported, but"
+                              " got feature_group_count: {}, batch_group_count"
+                              ": {}".format(feature_group_count,
+                                            batch_group_count))
   # Convert params from int/Sequence[int] to list of ints.
strides, lhs_dilation, rhs_dilation = _conv_general_param_type_converter( window_strides, lhs_dilation, rhs_dilation, dim diff --git a/trax/tf_numpy/extensions/extensions_test.py b/trax/tf_numpy/extensions/extensions_test.py index ba8a06388..d74a179ab 100644 --- a/trax/tf_numpy/extensions/extensions_test.py +++ b/trax/tf_numpy/extensions/extensions_test.py @@ -24,7 +24,6 @@ import itertools from absl.testing import parameterized -import jax from jax import lax import numpy as np import tensorflow.compat.v2 as tf @@ -458,14 +457,14 @@ def testConvGeneralDilated(self, lhs_shape, rhs_shape, strides, lhs = np.transpose(np.ones(lhs_shape), lhs_perm) rhs = np.transpose(np.ones(rhs_shape), rhs_perm) - jax_conv = jax.lax.conv_general_dilated(lhs, rhs, strides, padding, + jax_conv = lax.conv_general_dilated(lhs, rhs, strides, padding, lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count, batch_group_count) tf_conv = extensions.tf_conv_general_dilated(lhs, rhs, strides, - padding, jax_conv.shape, + padding, None, lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count, From 9943cf0c38b2cb91df74940bd0afcaaad9a8bef2 Mon Sep 17 00:00:00 2001 From: DarrenZhang01 <18633059886@163.com> Date: Fri, 28 Aug 2020 15:38:48 -0400 Subject: [PATCH 8/8] Check if `precision` is in default setting, raise `NotImplementedError` otherwise. --- trax/tf_numpy/extensions/extensions.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/trax/tf_numpy/extensions/extensions.py b/trax/tf_numpy/extensions/extensions.py index b532d1b73..bfaf4e4a5 100644 --- a/trax/tf_numpy/extensions/extensions.py +++ b/trax/tf_numpy/extensions/extensions.py @@ -653,6 +653,9 @@ def tf_conv_general_dilated(lhs, rhs, window_strides, padding, output_shape, " got feature_group_count: {}, batch_group_count" ": {}".format(feature_group_count, batch_group_count)) + if precision is not None: + raise NotImplementedError("precision other than `None` is currently not " + "supported, but got: {}".format(precision)) # Convert params from int/Sequence[int] to list of ints. strides, lhs_dilation, rhs_dilation = _conv_general_param_type_converter( window_strides, lhs_dilation, rhs_dilation, dim
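
After the full series is applied, the entry point is
`extensions.tf_conv_general_dilated`. A minimal end-to-end usage sketch,
mirroring the final test (the shapes and the NHWC/HWIO layout are
illustrative choices, not the only supported ones):

    import numpy as np
    from jax import lax
    from trax.tf_numpy import extensions

    lhs = np.ones((2, 9, 10, 3), dtype=np.float32)  # batch, H, W, channels
    rhs = np.ones((4, 5, 3, 2), dtype=np.float32)   # H, W, in-chan, out-chan

    jax_conv = lax.conv_general_dilated(
        lhs, rhs, (1, 1), "SAME", None, (1, 1), ("NHWC", "HWIO", "NHWC"))
    tf_conv = extensions.tf_conv_general_dilated(
        lhs, rhs, (1, 1), "SAME", None, lhs_dilation=None,
        rhs_dilation=(1, 1), dimension_numbers=("NHWC", "HWIO", "NHWC"))

    np.testing.assert_allclose(np.asarray(tf_conv), np.asarray(jax_conv))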