Skip to content

Commit

Permalink
Residual Shuffle Exchange implementation
Browse files Browse the repository at this point in the history
https://arxiv.org/pdf/2004.04662.pdf

PiperOrigin-RevId: 344425889
  • Loading branch information
kkanska authored and copybara-github committed Nov 26, 2020
1 parent 7a3537d commit 3a8f402
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 11 deletions.
2 changes: 2 additions & 0 deletions trax/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from trax.models.research import funnel_transformer
from trax.models.research import layerdrop_transformer
from trax.models.research import rezero
from trax.models.research import rse
from trax.models.research import transformer2


Expand Down Expand Up @@ -90,3 +91,4 @@ def model_configure(*args, **kwargs):
funnel_transformer.FunnelTransformerEncoder)
FunnelTransformer = model_configure(
funnel_transformer.FunnelTransformer)
ResidualShuffleExchange = model_configure(rse.ResidualShuffleExchange)
68 changes: 60 additions & 8 deletions trax/models/rse.py → trax/models/research/rse.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def _inverse_sigmoid(x):


@assert_shape('...->...')
class ClippedScaling(tl.Layer):
class _ClippedScaling(tl.Layer):
"""Pointwise multiplies by sigmoid(S) with a learnable vector S."""

def __init__(self,
Expand Down Expand Up @@ -89,7 +89,7 @@ def ResidualSwitchUnit(
tl.Fn('Scaling',
lambda x: x * np.sqrt(1 - residual_weight**2) * 0.25,
n_out=1),
shortcut=ClippedScaling(residual_weight)),
shortcut=_ClippedScaling(residual_weight)),
tl.Fn(
'UnPair',
lambda x: jnp.reshape(x, (x.shape[0], x.shape[1] * 2, -1)),
Expand All @@ -98,7 +98,7 @@ def ResidualSwitchUnit(
)


def ror(x, n, p=1):
def _ror(x, n, p=1):
"""Bitwise right rotation.
Args:
Expand All @@ -117,7 +117,7 @@ def ror(x, n, p=1):
return a + d


def rol(x, n, p=1):
def _rol(x, n, p=1):
"""Bitwise left rotation.
Args:
Expand All @@ -136,7 +136,7 @@ def rol(x, n, p=1):
return np.bitwise_or(c, d)


def shuffle_layer(inputs, shuffle_fn):
def _shuffle_layer(inputs, shuffle_fn):
"""Shuffles the elements according to bitwise left or right rotation.
Args:
Expand All @@ -147,7 +147,7 @@ def shuffle_layer(inputs, shuffle_fn):
tf.Tensor: Inputs shifted according to shuffle_fn
"""
seq_length = inputs.shape[1]
n_bits = np.int32(np.ceil(np.log(seq_length - 1) / np.log(2.0)))
n_bits = np.int32(np.log(seq_length - 1) / np.log(2.0)) + 1

indices = np.arange(0, seq_length).astype('int32')
rev_indices = shuffle_fn(indices, n_bits)
Expand All @@ -157,10 +157,62 @@ def shuffle_layer(inputs, shuffle_fn):
@assert_shape('bld->bld')
def ShuffleLayer():
return tl.Fn(
'ShuffleLayer', lambda x: shuffle_layer(x, ror), n_out=1)
'ShuffleLayer', lambda x: _shuffle_layer(x, _rol), n_out=1)


@assert_shape('bld->bld')
def ReverseShuffleLayer():
return tl.Fn(
'ReverseShuffleLayer', lambda x: shuffle_layer(x, rol), n_out=1)
'ReverseShuffleLayer', lambda x: _shuffle_layer(x, _ror), n_out=1)


@assert_shape('...,bld->...,bld')
def _ForwardStep(d_model, dropout, mode):
  """Builds one forward step: a residual switch unit, then a shuffle.

  Maps (n_layer, state) to (n_layer, shuffle_layer(rsu(state))).

  Args:
    d_model: Forwarded unchanged to ResidualSwitchUnit.
    dropout: Forwarded unchanged to ResidualSwitchUnit.
    mode: Forwarded unchanged to ResidualSwitchUnit.

  Returns:
    A two-branch tl.Parallel layer whose first branch is the empty list `[]`
    and whose second branch applies ResidualSwitchUnit followed by
    ShuffleLayer.
  """
  # Only the second input (the state) is transformed; n_layer is carried along.
  switch_then_shuffle = tl.Serial(
      ResidualSwitchUnit(d_model, dropout, mode),
      ShuffleLayer(),
  )
  return tl.Parallel([], switch_then_shuffle)


@assert_shape('...,bld->...,bld')
def _BackwardStep(d_model, dropout, mode):
  """Builds one backward step: a residual switch unit, then a reverse shuffle.

  Maps (n_layer, state) to (n_layer, reverse_shuffle_layer(rsu(state))).

  Args:
    d_model: Forwarded unchanged to ResidualSwitchUnit.
    dropout: Forwarded unchanged to ResidualSwitchUnit.
    mode: Forwarded unchanged to ResidualSwitchUnit.

  Returns:
    A two-branch tl.Parallel layer whose first branch is the empty list `[]`
    and whose second branch applies ResidualSwitchUnit followed by
    ReverseShuffleLayer.
  """
  # Only the second input (the state) is transformed; n_layer is carried along.
  switch_then_unshuffle = tl.Serial(
      ResidualSwitchUnit(d_model, dropout, mode),
      ReverseShuffleLayer(),
  )
  return tl.Parallel([], switch_then_unshuffle)


@assert_shape('bld->bld')
def BenesBlock(d_model, dropout, mode):
  """Returns a Benes block: scanned forward steps, then scanned backward steps.

  The block duplicates its input, replaces the top copy with the sequence
  0..n_bits-1 (n_bits being the number of bits needed to index positions along
  axis 1 of the input), runs tl.Scan over _ForwardStep and then over
  _BackwardStep, and finally keeps only the second stack item.

  NOTE(review): presumably the bit sequence is what tl.Scan iterates over, so
  each step runs n_bits times with the state carried between iterations --
  confirm against tl.Scan's axis/n_carry defaults.

  Args:
    d_model: Forwarded to _ForwardStep and _BackwardStep.
    dropout: Forwarded to _ForwardStep and _BackwardStep.
    mode: Forwarded to _ForwardStep and _BackwardStep.

  Returns:
    A tl.Serial layer declared (via assert_shape) to map 'bld' to 'bld'.
  """
  def bit_sequence(inputs):
    """Returns arange(n_bits) for the length of axis 1 of `inputs`."""
    seq_length = inputs.shape[1]
    # Exact integer form of floor(log2(seq_length - 1)) + 1. Unlike the
    # floating-point log quotient, it is well defined for seq_length == 1
    # (log(0) would be -inf) and cannot be off by one due to rounding.
    largest_index = max(int(seq_length) - 1, 0)
    n_bits = np.int32(largest_index.bit_length())
    return jnp.arange(0, n_bits)

  return tl.Serial(
      tl.Dup(),
      tl.Fn('BitSeq', bit_sequence, n_out=1),
      tl.Scan(_ForwardStep(d_model, dropout, mode)),
      tl.Scan(_BackwardStep(d_model, dropout, mode)),
      # Keep only the second stack item, dropping the bit sequence.
      tl.Select([1]),
  )


@assert_shape('bl->blv')
def ResidualShuffleExchange(vocab_size,
                            d_model,
                            dropout,
                            mode='train',
                            n_blocks=2):
  """Returns a Residual Shuffle Exchange Network model.

  The model embeds its input, applies `n_blocks` Benes blocks and one final
  residual switch unit, then projects to `vocab_size` with a dense layer
  followed by LogSoftmax.

  Args:
    vocab_size: Passed to tl.Embedding and used as the output size of the
      final tl.Dense layer.
    d_model: Passed to tl.Embedding, BenesBlock and ResidualSwitchUnit.
    dropout: Passed to BenesBlock and ResidualSwitchUnit.
    mode: Passed to BenesBlock and ResidualSwitchUnit; defaults to 'train'.
    n_blocks: Number of BenesBlock layers to stack.

  Returns:
    A tl.Serial layer declared (via assert_shape) to map 'bl' to 'blv'.
  """
  stacked_blocks = [
      BenesBlock(d_model, dropout, mode) for _ in range(n_blocks)
  ]

  layers = [tl.Embedding(vocab_size, d_model)]
  # Apply the Benes block n_blocks times.
  layers.extend(stacked_blocks)
  layers.append(ResidualSwitchUnit(d_model, dropout, mode))
  # Project to the vocabulary and emit log-probabilities.
  layers.append(tl.Dense(vocab_size))
  layers.append(tl.LogSoftmax())
  return tl.Serial(*layers)
16 changes: 13 additions & 3 deletions trax/models/rse_test.py → trax/models/research/rse_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import numpy as np

from trax import shapes
from trax.models import rse
from trax.models.research import rse


class RSETest(absltest.TestCase):
Expand All @@ -42,7 +42,7 @@ def test_shuffle_layer(self):
print(x.shape)
_, _ = shuffle_layer.init(shapes.signature(x))
y = shuffle_layer(x)
expected_output = np.array([[[0], [4], [1], [5], [2], [6], [3], [7]]])
expected_output = np.array([[[0], [2], [4], [6], [1], [3], [5], [7]]])
self._assert_equal_tensors(y, expected_output)

def test_shuffle_layer_log_times_is_identity(self):
Expand All @@ -62,7 +62,7 @@ def test_reverse_shuffle_layer(self):
print(x.shape)
_, _ = reverse_shuffle_layer.init(shapes.signature(x))
y = reverse_shuffle_layer(x)
expected_output = np.array([[[0], [2], [4], [6], [1], [3], [5], [7]]])
expected_output = np.array([[[0], [4], [1], [5], [2], [6], [3], [7]]])
self._assert_equal_tensors(y, expected_output)

def test_reverse_shuffle_layer_log_times_is_identity(self):
Expand All @@ -76,6 +76,16 @@ def test_reverse_shuffle_layer_log_times_is_identity(self):
y = reverse_shuffle_layer(y)
self._assert_equal_tensors(x, y)

def test_rse_forward_shape(self):
  """Checks that a (batch, length) int input yields (batch, length, vocab)."""
  batch_size, seq_len, vocab_size = 3, 32, 12
  model = rse.ResidualShuffleExchange(
      vocab_size=vocab_size, d_model=17, dropout=0.1, mode='train')
  token_ids = np.ones((batch_size, seq_len)).astype(np.int32)
  # init is expected to return a pair; both values are unused here.
  unused_first, unused_second = model.init(shapes.signature(token_ids))
  output = model(token_ids)
  self.assertEqual(output.shape, (batch_size, seq_len, vocab_size))

def _assert_equal_tensors(self, x, y):
self.assertEqual(y.shape, x.shape)
for i in range(x.shape[0]):
Expand Down
54 changes: 54 additions & 0 deletions trax/supervised/configs/rse_addition.gin
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright 2020 The Trax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import trax.models
import trax.optimizers
import trax.data.inputs
import trax.data.tf_inputs
import trax.supervised.trainer_lib

vocab_size = 5 # For arithmetic operations, base = vocab_size - 3.

# Parameters for addition_inputs:
# ==============================================================================
addition_inputs.vocab_size = %vocab_size
addition_inputs.batch_size = 128
addition_inputs.train_length = 16
addition_inputs.eval_min_length = 16
addition_inputs.eval_max_length = 16
addition_inputs.pad_to_multiple = 16
addition_inputs.encdec = True

# Parameters for multifactor:
# ==============================================================================
multifactor.constant = 0.5
multifactor.factors = 'constant * linear_warmup * rsqrt_decay'
multifactor.warmup_steps = 4000

# Parameters for RSE:
# ==============================================================================
ResidualShuffleExchange.vocab_size = %vocab_size
ResidualShuffleExchange.d_model = 192
ResidualShuffleExchange.dropout = 0.05
ResidualShuffleExchange.mode = 'train'
ResidualShuffleExchange.n_blocks = 2

# Parameters for train:
# ==============================================================================
train.inputs = @trax.data.inputs.addition_inputs
train.eval_frequency = 1000
train.eval_steps = 50
train.optimizer = @trax.optimizers.Adam
train.steps = 30000
train.model = @trax.models.ResidualShuffleExchange

0 comments on commit 3a8f402

Please sign in to comment.