Skip to content

Commit

Permalink
Use XLA sum-reduce instead of CNN for CNN kernel propagation when `diagonal_spatial == True`. This gives a +25% speedup on V100!
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 306550220
  • Loading branch information
romanngg committed Apr 15, 2020
1 parent 65215c5 commit 252ed85
Showing 1 changed file with 18 additions and 13 deletions.
31 changes: 18 additions & 13 deletions neural_tangents/stax.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,7 @@ def kernel_fn(kernels):

if kernels.diagonal_spatial:
def conv_unscaled(x, batch_ndim):
x = _conv_kernel_over_spatial(
x = _conv_kernel_diagonal_spatial(
x, filter_shape_kernel, strides_kernel, padding, batch_ndim)
return x

Expand All @@ -659,7 +659,7 @@ def conv_unscaled(x, batch_ndim):
is_reversed = not is_reversed

def conv_unscaled(x, batch_ndim):
x = _conv_kernel(
x = _conv_kernel_diagonal_full(
x, filter_shape_kernel, strides_kernel, padding, batch_ndim)
return x

Expand Down Expand Up @@ -2522,7 +2522,11 @@ def _pad_one_side(x, pads, axes, mode):
return x


def _conv_kernel(mat, filter_shape, strides, padding, batch_ndim):
def _conv_kernel_diagonal_full(mat,
filter_shape,
strides,
padding,
batch_ndim):
"""Compute covariance of the CNN outputs given inputs with covariance `mat`.
Used when `kernel.diagonal_spatial == False`.
Expand Down Expand Up @@ -2589,7 +2593,11 @@ def _conv_kernel(mat, filter_shape, strides, padding, batch_ndim):
return mat


def _conv_kernel_over_spatial(mat, filter_shape, strides, padding, batch_ndim):
def _conv_kernel_diagonal_spatial(mat,
filter_shape,
strides,
padding,
batch_ndim):
"""Compute covariance of the CNN outputs given inputs with covariance `mat`.
Used when `kernel.diagonal_spatial == True`.
Expand All @@ -2615,20 +2623,17 @@ def _conv_kernel_over_spatial(mat, filter_shape, strides, padding, batch_ndim):
if not utils.is_array(mat):
return mat

spatial_axes = tuple(range(mat.ndim)[batch_ndim:])

if padding == Padding.CIRCULAR:
spatial_axes = tuple(range(mat.ndim)[batch_ndim:])
mat = _same_pad_for_filter_shape(mat, filter_shape, strides,
spatial_axes, 'wrap')
padding = Padding.VALID

ker = np.full((1, 1) + filter_shape, 1. / np.prod(filter_shape), mat.dtype)

batch_shape, spatial_shape = mat.shape[:batch_ndim], mat.shape[batch_ndim:]
mat = np.reshape(mat, (-1,) + spatial_shape)
mat = np.expand_dims(mat, 1)
mat = lax.conv_general_dilated(mat, ker, strides, padding.name)
mat = mat.reshape(batch_shape + mat.shape[2:])
filter_shape = (1,) * batch_ndim + filter_shape
filter_size = functools.reduce(op.mul, filter_shape, 1)
strides = (1,) * batch_ndim + strides
mat = lax._reduce_window_sum(mat, filter_shape, strides, padding.name)
mat /= filter_size
return mat


Expand Down

0 comments on commit 252ed85

Please sign in to comment.