Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the general convolution operation to extensions #954

Merged
merged 8 commits into from
Aug 28, 2020
122 changes: 122 additions & 0 deletions trax/tf_numpy/extensions/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@
tf.int64, tf.int32, tf.int16, tf.int8, tf.uint8, tf.uint16, tf.uint32,
tf.uint64
]
_tf_nn_APIs = {1: [tf.nn.conv1d, tf.nn.conv1d_transpose],
2: [tf.nn.conv2d, tf.nn.conv2d_transpose],
3: [tf.nn.conv3d, tf.nn.conv3d_transpose]}


def most_precise_int_dtype(x):
Expand Down Expand Up @@ -564,6 +567,125 @@ def tf_dot_general(lhs, rhs, dimension_numbers):
return tf.einsum(equation, lhs, rhs)


def _conv_general_param_type_converter(window_strides, lhs_dilation,
rhs_dilation, dim):
""" Convert the inputs strides, lhs_dilation, rhs_dilation to the standard
TF conv inputs.
For example,
in the 3D case, if lhs_dilation = 2, then convert it to [2, 2, 2]
if lhs_dilation = (2, 2, 2), convert it also to [2, 2, 2]
"""
def _as_list_of_size(item, size):
if item is None:
return None
return [item] * size if isinstance(item, int) else list(item)
return (_as_list_of_size(window_strides, dim),
_as_list_of_size(lhs_dilation, dim),
_as_list_of_size(rhs_dilation, dim))


# TODO (DarrenZhang01): Expand the test cases of general convolution and revise
# the according bugs.
# TODO (DarrenZhang01): Support feature_group_count, batch_group_count and
# precision, and allow lhs_dilation and rhs_dilation to happen at the
# same time.
def tf_conv_general_dilated(lhs, rhs, window_strides, padding, output_shape,
lhs_dilation=None, rhs_dilation=None,
dimension_numbers=None, feature_group_count=1,
batch_group_count=1, precision=None):
""" A general conv API for TensorFlow.

According JAX version:
https://jax.readthedocs.io/en/stable/_autosummary/jax.lax.conv_general_dilated.html

Args: (Use JAX documentation as a reference)
lhs: a rank n+2 dimensional input array.
rhs: a rank n+2 dimensional array of kernel weights.
window_strides: a sequence of n integers, representing the inter-window
strides.
padding: either the string ‘SAME’, the string ‘VALID’, or a sequence of n
(low, high) integer pairs that give the padding to apply before and
after each spatial dimension.
output_shape: the output shape of the convolution (only required for
transpose convolution).
lhs_dilation: None, or a sequence of n integers, giving the dilation factor
to apply in each spatial dimension of lhs. LHS dilation is
also known as transposed convolution.
rhs_dilation: None, or a sequence of n integers, giving the dilation factor
to apply in each spatial dimension of rhs. RHS dilation is
also known as atrous convolution.
dimension_numbers: either None, a ConvDimensionNumbers object, or a 3-tuple
(lhs_spec, rhs_spec, out_spec), where each element is a
string of length n+2.
feature_group_count: integer, default 1. Changing this is currently not
supported.
batch_group_count: integer, default 1. Changing this is currently not
supported.
precision: Optional. Either None, which means the default precision for the
backend, or a Precision enum value.
DarrenZhang01 marked this conversation as resolved.
Show resolved Hide resolved

Returns:
A TF NumPy array that contains the convolution result.
"""
dim = None
DarrenZhang01 marked this conversation as resolved.
Show resolved Hide resolved
lhs_spec, rhs_spec, out_spec = dimension_numbers
if lhs_spec != out_spec:
raise ValueError("Current implementation requires the `data_format` of the "
"inputs and outputs to be the same.")
if len(lhs_spec) >= 6:
raise ValueError("Current implmentation does not support 4 or higher"
"dimensional convolution, but got: ", len(lhs_spec) - 2)
dim = len(lhs_spec) - 2
if lhs_dilation and rhs_dilation:
if lhs_dilation == (1,) * dim and rhs_dilation == (1,) * dim:
lhs_dilation, rhs_dilation = None, None
else:
raise ValueError("Current implementation does not support that "
"deconvolution and dilation to be performed at the same "
"time, but got lhs_dilation: {}, rhs_dilation: {}"
.format(lhs_dilation, rhs_dilation))
if padding not in ["SAME", "VALID"]:
raise ValueError("Current implementation requires the padding parameter"
"to be either 'VALID' or 'SAME', but got: ", padding)
DarrenZhang01 marked this conversation as resolved.
Show resolved Hide resolved
if batch_group_count != 1 or feature_group_count != 1:
raise NotImplementedError("batch_group_count and feature_group_count "
"other than 1 is currently not supported, but"
" got feature_group_count: {}, batch_group_count"
": {}".format(feature_group_count,
batch_group_count))
DarrenZhang01 marked this conversation as resolved.
Show resolved Hide resolved
# Convert params from int/Sequence[int] to list of ints.
strides, lhs_dilation, rhs_dilation = _conv_general_param_type_converter(
window_strides, lhs_dilation, rhs_dilation, dim
)
# Preprocess the shapes
dim_maps = {}
if isinstance(lhs_spec, str):
dim_maps['I'] = list(rhs_spec).index('I')
dim_maps['O'] = list(rhs_spec).index('O')
dim_maps['N'] = list(lhs_spec).index('N')
dim_maps['C'] = list(lhs_spec).index('C')
else:
dim_maps['I'] = rhs_spec[1]
dim_maps['O'] = rhs_spec[0]
dim_maps['N'] = lhs_spec[0]
dim_maps['C'] = lhs_spec[1]

lhs = tf_np.moveaxis(lhs, (dim_maps['N'], dim_maps['C']), (0, dim + 1))
# Adjust the filters, put the dimension 'I' and 'O' at last.
rhs = tf_np.moveaxis(rhs, (dim_maps['O'], dim_maps['I']), (dim + 1, dim))
spatial_dim_maps = {1: 'W', 2: "HW", 3: "DHW"}
data_format = 'N' + spatial_dim_maps[dim] + 'C'

if rhs_dilation or (lhs_dilation is None and rhs_dilation is None):
output = _tf_nn_APIs[dim][0](lhs, rhs, strides, padding, data_format,
rhs_dilation)
else:
output = _tf_nn_APIs[dim][1](lhs, rhs, tf.constant(output_shape), strides,
padding, data_format, lhs_dilation)
output = tf_np.moveaxis(output, (0, dim + 1), (dim_maps['N'], dim_maps['C']))
return output


def conv(inp,
fltr,
window_strides,
Expand Down
52 changes: 52 additions & 0 deletions trax/tf_numpy/extensions/extensions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import functools
from absl import flags
import itertools
from absl.testing import parameterized

from jax import lax
Expand All @@ -30,6 +31,7 @@
from trax.tf_numpy import extensions
import trax.tf_numpy.numpy as tf_np


FLAGS = flags.FLAGS

flags.DEFINE_bool("requires_tpu", False, "Requires TPU.")
Expand Down Expand Up @@ -421,6 +423,56 @@ def test_tf_dot_general(self, lhs_np, rhs_np, dims):
result = extensions.tf_dot_general(lhs_np, rhs_np, dims)
self.assertAllClose(result, np.array(ans))


@parameterized.named_parameters([
("_lhs_shape={}_rhs_shape={}_strides={}_padding={}"
"_lhs_dilation={}_rhs_dilation={}"
"_feature_group_count={}_batch_group_count={}_dims={}"
"_perms={}".format(lhs_shape, rhs_shape,
strides, padding, lhs_dilation, rhs_dilation,
feature_group_count, batch_group_count, ",".join(
dimension_numbers), perms),
lhs_shape, rhs_shape, strides, padding, lhs_dilation, rhs_dilation,
feature_group_count, batch_group_count, dimension_numbers, perms)
for batch_group_count, feature_group_count in [(1, 1)]
for lhs_shape, rhs_shape in [
((b * batch_group_count, i * feature_group_count, 9, w),
(j * feature_group_count * batch_group_count, i, 4, 5))
for w in [0, 10]
for b, i, j in itertools.product([2, 3], repeat=3)]
for strides in [(1, 1), (2, 1)]
for padding in ['SAME']
for lhs_dilation, rhs_dilation in [
(None, (1, 1))
]
for dimension_numbers, perms in [
(("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0]))
]])
def testConvGeneralDilated(self, lhs_shape, rhs_shape, strides,
padding, lhs_dilation, rhs_dilation,
feature_group_count, batch_group_count,
dimension_numbers, perms):
lhs_perm, rhs_perm = perms # permute to compatible shapes

lhs = np.transpose(np.ones(lhs_shape), lhs_perm)
rhs = np.transpose(np.ones(rhs_shape), rhs_perm)

jax_conv = lax.conv_general_dilated(lhs, rhs, strides, padding,
lhs_dilation, rhs_dilation,
dimension_numbers,
feature_group_count,
batch_group_count)

tf_conv = extensions.tf_conv_general_dilated(lhs, rhs, strides,
padding, None,
lhs_dilation, rhs_dilation,
dimension_numbers,
feature_group_count,
batch_group_count)

self.assertAllEqual(tf_conv, tf_np.asarray(jax_conv))


def testConv(self):
y = extensions.conv(
np.ones([5, 320, 480, 3], dtype=np.float32),
Expand Down