diff --git a/trax/tf_numpy/extensions/extensions.py b/trax/tf_numpy/extensions/extensions.py
index ed154c2b2..bfaf4e4a5 100644
--- a/trax/tf_numpy/extensions/extensions.py
+++ b/trax/tf_numpy/extensions/extensions.py
@@ -37,6 +37,9 @@
     tf.int64, tf.int32, tf.int16, tf.int8, tf.uint8, tf.uint16, tf.uint32,
     tf.uint64
 ]
+_tf_nn_APIs = {1: [tf.nn.conv1d, tf.nn.conv1d_transpose],
+               2: [tf.nn.conv2d, tf.nn.conv2d_transpose],
+               3: [tf.nn.conv3d, tf.nn.conv3d_transpose]}
 
 
 def most_precise_int_dtype(x):
@@ -564,6 +567,128 @@ def tf_dot_general(lhs, rhs, dimension_numbers):
   return tf.einsum(equation, lhs, rhs)
 
 
+def _conv_general_param_type_converter(window_strides, lhs_dilation,
+                                       rhs_dilation, dim):
+  """Converts `window_strides`, `lhs_dilation` and `rhs_dilation` to the
+  standard TF conv parameter format.
+
+  For example, in the 3-D case, if lhs_dilation = 2, convert it to
+  [2, 2, 2]; if lhs_dilation = (2, 2, 2), also convert it to [2, 2, 2].
+  """
+  def _as_list_of_size(item, size):
+    if item is None:
+      return None
+    return [item] * size if isinstance(item, int) else list(item)
+  return (_as_list_of_size(window_strides, dim),
+          _as_list_of_size(lhs_dilation, dim),
+          _as_list_of_size(rhs_dilation, dim))
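+
+# A minimal illustration of the normalization above (the argument values are
+# examples only, not defaults): an int is broadcast to a list of length `dim`,
+# a sequence becomes a list, and None passes through unchanged:
+#   _conv_general_param_type_converter(2, None, (2, 1), 2)
+#   => ([2, 2], None, [2, 1])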
+
+
+# TODO (DarrenZhang01): Expand the test cases of general convolution and fix
+# the corresponding bugs.
+# TODO (DarrenZhang01): Support feature_group_count, batch_group_count and
+# precision, and allow lhs_dilation and rhs_dilation to happen at the same
+# time.
+def tf_conv_general_dilated(lhs, rhs, window_strides, padding, output_shape,
+                            lhs_dilation=None, rhs_dilation=None,
+                            dimension_numbers=None, feature_group_count=1,
+                            batch_group_count=1, precision=None):
+  """A general convolution API for TensorFlow.
+
+  Corresponding JAX version:
+  https://jax.readthedocs.io/en/stable/_autosummary/jax.lax.conv_general_dilated.html
+
+  Args: (following the JAX documentation)
+    lhs: a rank n+2 dimensional input array.
+    rhs: a rank n+2 dimensional array of kernel weights.
+    window_strides: a sequence of n integers, representing the inter-window
+      strides.
+    padding: either the string 'SAME', the string 'VALID', or a sequence of n
+      (low, high) integer pairs that give the padding to apply before and
+      after each spatial dimension.
+    output_shape: the output shape of the convolution (only required for
+      transpose convolution).
+    lhs_dilation: None, or a sequence of n integers, giving the dilation factor
+      to apply in each spatial dimension of lhs. LHS dilation is also known as
+      transposed convolution.
+    rhs_dilation: None, or a sequence of n integers, giving the dilation factor
+      to apply in each spatial dimension of rhs. RHS dilation is also known as
+      atrous convolution.
+    dimension_numbers: either None, a ConvDimensionNumbers object, or a 3-tuple
+      (lhs_spec, rhs_spec, out_spec), where each element is a string of length
+      n+2.
+    feature_group_count: integer, default 1. Changing this is currently not
+      supported.
+    batch_group_count: integer, default 1. Changing this is currently not
+      supported.
+    precision: Optional. Either None, which means the default precision for the
+      backend, or a Precision enum value.
+
+  Returns:
+    A TF NumPy array that contains the convolution result.
+  """
+  lhs_spec, rhs_spec, out_spec = dimension_numbers
+  if lhs_spec != out_spec:
+    raise ValueError("Current implementation requires the `data_format` of "
+                     "the inputs and outputs to be the same.")
+  if len(lhs_spec) >= 6:
+    raise ValueError("Current implementation does not support 4 or higher "
+                     "dimensional convolution, but got: {}".format(
+                         len(lhs_spec) - 2))
+  dim = len(lhs_spec) - 2
+  if lhs_dilation and rhs_dilation:
+    if lhs_dilation == (1,) * dim and rhs_dilation == (1,) * dim:
+      lhs_dilation, rhs_dilation = None, None
+    else:
+      raise ValueError("Current implementation does not support deconvolution "
+                       "and dilation to be performed at the same time, but "
+                       "got lhs_dilation: {}, rhs_dilation: {}".format(
+                           lhs_dilation, rhs_dilation))
+  if padding not in ["SAME", "VALID"]:
+    raise ValueError("Current implementation requires the padding parameter "
+                     "to be either 'VALID' or 'SAME', but got: {}".format(
+                         padding))
+  if batch_group_count != 1 or feature_group_count != 1:
+    raise NotImplementedError("batch_group_count and feature_group_count "
+                              "other than 1 are currently not supported, but "
+                              "got feature_group_count: {}, "
+                              "batch_group_count: {}".format(
+                                  feature_group_count, batch_group_count))
+  if precision is not None:
+    raise NotImplementedError("precision other than `None` is currently not "
+                              "supported, but got: {}".format(precision))
+  # Convert the parameters from int/Sequence[int] to lists of ints.
+  strides, lhs_dilation, rhs_dilation = _conv_general_param_type_converter(
+      window_strides, lhs_dilation, rhs_dilation, dim
+  )
+  # Preprocess the shapes: locate the batch ('N'), channel ('C'), kernel
+  # input ('I') and kernel output ('O') dimensions in the specs.
+  dim_maps = {}
+  if isinstance(lhs_spec, str):
+    dim_maps['I'] = list(rhs_spec).index('I')
+    dim_maps['O'] = list(rhs_spec).index('O')
+    dim_maps['N'] = list(lhs_spec).index('N')
+    dim_maps['C'] = list(lhs_spec).index('C')
+  else:
+    dim_maps['I'] = rhs_spec[1]
+    dim_maps['O'] = rhs_spec[0]
+    dim_maps['N'] = lhs_spec[0]
+    dim_maps['C'] = lhs_spec[1]
+
+  # Move the input to channels-last format ('N' first, 'C' last).
+  lhs = tf_np.moveaxis(lhs, (dim_maps['N'], dim_maps['C']), (0, dim + 1))
+  # Adjust the filters: put the dimensions 'I' and 'O' last.
+  rhs = tf_np.moveaxis(rhs, (dim_maps['O'], dim_maps['I']), (dim + 1, dim))
+  spatial_dim_maps = {1: 'W', 2: 'HW', 3: 'DHW'}
+  data_format = 'N' + spatial_dim_maps[dim] + 'C'
+
+  if rhs_dilation or (lhs_dilation is None and rhs_dilation is None):
+    output = _tf_nn_APIs[dim][0](lhs, rhs, strides, padding, data_format,
+                                 rhs_dilation)
+  else:
+    output = _tf_nn_APIs[dim][1](lhs, rhs, tf.constant(output_shape), strides,
+                                 padding, data_format, lhs_dilation)
+  output = tf_np.moveaxis(output, (0, dim + 1), (dim_maps['N'], dim_maps['C']))
+  return output
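+
+# Example usage (a sketch; the shapes and values below are illustrative only):
+#   lhs = tf_np.ones([2, 9, 10, 3])   # batch 2, 9x10 spatial, 3 channels (NHWC)
+#   rhs = tf_np.ones([4, 5, 3, 7])    # 4x5 kernel, 3 in, 7 out channels (HWIO)
+#   out = tf_conv_general_dilated(
+#       lhs, rhs, window_strides=(1, 1), padding="SAME", output_shape=None,
+#       dimension_numbers=("NHWC", "HWIO", "NHWC"))
+#   # `out` is an NHWC array of shape (2, 9, 10, 7).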
+
+
 def conv(inp,
          fltr,
          window_strides,
diff --git a/trax/tf_numpy/extensions/extensions_test.py b/trax/tf_numpy/extensions/extensions_test.py
index 05c416466..d74a179ab 100644
--- a/trax/tf_numpy/extensions/extensions_test.py
+++ b/trax/tf_numpy/extensions/extensions_test.py
@@ -21,6 +21,7 @@ import functools
 
 from absl import flags
+import itertools
 from absl.testing import parameterized
 from jax import lax
@@ -30,6 +31,7 @@ from trax.tf_numpy import extensions
 import trax.tf_numpy.numpy as tf_np
 
+
 FLAGS = flags.FLAGS
 
 flags.DEFINE_bool("requires_tpu", False, "Requires TPU.")
@@ -421,6 +423,56 @@ def test_tf_dot_general(self, lhs_np, rhs_np, dims):
     result = extensions.tf_dot_general(lhs_np, rhs_np, dims)
     self.assertAllClose(result, np.array(ans))
 
+  @parameterized.named_parameters([
+      ("_lhs_shape={}_rhs_shape={}_strides={}_padding={}"
+       "_lhs_dilation={}_rhs_dilation={}"
+       "_feature_group_count={}_batch_group_count={}_dims={}"
+       "_perms={}".format(lhs_shape, rhs_shape, strides, padding,
+                          lhs_dilation, rhs_dilation, feature_group_count,
+                          batch_group_count, ",".join(dimension_numbers),
+                          perms),
+       lhs_shape, rhs_shape, strides, padding, lhs_dilation, rhs_dilation,
+       feature_group_count, batch_group_count, dimension_numbers, perms)
+      for batch_group_count, feature_group_count in [(1, 1)]
+      for lhs_shape, rhs_shape in [
+          ((b * batch_group_count, i * feature_group_count, 9, w),
+           (j * feature_group_count * batch_group_count, i, 4, 5))
+          for w in [0, 10]
+          for b, i, j in itertools.product([2, 3], repeat=3)]
+      for strides in [(1, 1), (2, 1)]
+      for padding in ['SAME']
+      for lhs_dilation, rhs_dilation in [(None, (1, 1))]
+      for dimension_numbers, perms in [
+          (("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0]))]])
+  def testConvGeneralDilated(self, lhs_shape, rhs_shape, strides,
+                             padding, lhs_dilation, rhs_dilation,
+                             feature_group_count, batch_group_count,
+                             dimension_numbers, perms):
+    lhs_perm, rhs_perm = perms  # permute to compatible shapes
+
+    lhs = np.transpose(np.ones(lhs_shape), lhs_perm)
+    rhs = np.transpose(np.ones(rhs_shape), rhs_perm)
+
+    jax_conv = lax.conv_general_dilated(lhs, rhs, strides, padding,
+                                        lhs_dilation, rhs_dilation,
+                                        dimension_numbers,
+                                        feature_group_count,
+                                        batch_group_count)
+
+    tf_conv = extensions.tf_conv_general_dilated(lhs, rhs, strides,
+                                                 padding, None,
+                                                 lhs_dilation, rhs_dilation,
+                                                 dimension_numbers,
+                                                 feature_group_count,
+                                                 batch_group_count)
+
+    self.assertAllEqual(tf_conv, tf_np.asarray(jax_conv))
+
   def testConv(self):
     y = extensions.conv(
         np.ones([5, 320, 480, 3], dtype=np.float32),