Support multi-output layers in the Sequential combinator
PiperOrigin-RevId: 495905245
aboSamoor authored and Flax Authors committed Dec 16, 2022
1 parent 4f24933 commit 6d5bc2a
Showing 2 changed files with 89 additions and 2 deletions.
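
In short, a layer inside nn.Sequential may now return multiple outputs: a tuple is forwarded to the next layer as positional arguments, and a dict as keyword arguments. Below is a minimal, self-contained sketch of the tuple case; the ScaleAndShift and AddShift modules are illustrative only and are not part of this commit.

import jax
import jax.numpy as jnp
import flax.linen as nn


class ScaleAndShift(nn.Module):
  """Toy layer that returns two outputs as a tuple."""

  @nn.compact
  def __call__(self, x, shift):
    scale = self.param('scale', nn.initializers.ones, (x.shape[-1],))
    return x * scale, shift  # Sequential calls the next layer with *outputs


class AddShift(nn.Module):
  """Consumes both positional outputs of the previous layer."""

  @nn.compact
  def __call__(self, x, shift):
    return x + shift


model = nn.Sequential([ScaleAndShift(), AddShift()])
x = jnp.ones((2, 3))
shift = jnp.full((2, 3), 0.5)
variables = model.init(jax.random.PRNGKey(0), x, shift)
y = model.apply(variables, x, shift)  # shape (2, 3)
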
34 changes: 32 additions & 2 deletions flax/linen/combinators.py
@@ -14,7 +14,7 @@

"""Combinators of modules, such as a Sequential."""

from typing import Any, Callable, Sequence
from typing import Any, Callable, Dict, Sequence

from flax.linen.module import Module

@@ -42,6 +42,31 @@ def __call__(self, x):
                              nn.relu,
                              nn.Dense(2),
                              nn.log_softmax])(x)

  This combinator also supports layers that return multiple outputs, as long as
  they are returned as a tuple or a dictionary.

  Example usage::

    class CrossAttentionBlock(nn.Module):
      num_heads: int = 2
      qkv_features: int = 16

      @nn.compact
      def __call__(self, query, key_value):
        output = nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads,
            qkv_features=self.qkv_features)(query, key_value)
        output = nn.Dense(self.qkv_features)(output)
        return dict(query=output, key_value=key_value)  # also works for tuples

    class CrossAttentionNetwork(nn.Module):
      num_layers: int

      @nn.compact
      def __call__(self, query, key_value):
        return nn.Sequential([CrossAttentionBlock()
                              for _ in range(self.num_layers)])(query, key_value)
"""
layers: Sequence[Callable[..., Any]]

@@ -51,5 +76,10 @@ def __call__(self, *args, **kwargs):

    outputs = self.layers[0](*args, **kwargs)
    for layer in self.layers[1:]:
      outputs = layer(outputs)
      if isinstance(outputs, tuple):
        outputs = layer(*outputs)
      elif isinstance(outputs, Dict):
        outputs = layer(**outputs)
      else:
        outputs = layer(outputs)
    return outputs
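
Because the dict branch forwards outputs with **, the dictionary keys must match the argument names of the next layer's __call__. A minimal sketch of the dict case follows; the SplitHalves and Merge modules are illustrative and not part of this commit.

import jax
import jax.numpy as jnp
import flax.linen as nn


class SplitHalves(nn.Module):
  """Toy layer returning a dict; its keys name the next layer's arguments."""

  @nn.compact
  def __call__(self, x):
    left, right = jnp.split(x, 2, axis=-1)
    return {'left': left, 'right': right}  # forwarded as layer(**outputs)


class Merge(nn.Module):
  """Consumes the dict outputs of the previous layer as keyword arguments."""

  @nn.compact
  def __call__(self, left, right):
    return nn.Dense(4)(jnp.concatenate([left, right], axis=-1))


model = nn.Sequential([SplitHalves(), Merge()])
variables = model.init(jax.random.PRNGKey(0), jnp.ones((2, 8)))
y = model.apply(variables, jnp.ones((2, 8)))  # shape (2, 4)
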
57 changes: 57 additions & 0 deletions tests/linen/linen_combinators_test.py
@@ -48,6 +48,30 @@ def __call__(self, inputs):
    return self.activation_final(x)


class AttentionTuple(nn.Module):
  num_heads: int = 2
  qkv_features: int = 16

  @nn.compact
  def __call__(self, query, key_value):
    output = nn.MultiHeadDotProductAttention(
        num_heads=self.num_heads,
        qkv_features=self.qkv_features)(query, key_value)
    return output, key_value


class AttentionDict(nn.Module):
  num_heads: int = 2
  qkv_features: int = 16

  @nn.compact
  def __call__(self, query, key_value):
    output = nn.MultiHeadDotProductAttention(
        num_heads=self.num_heads,
        qkv_features=self.qkv_features)(query, key_value)
    return dict(query=output, key_value=key_value)


class SequentialTest(absltest.TestCase):

  def test_construction(self):
@@ -103,5 +127,38 @@ def test_same_output_as_mlp_with_activation(self):
    np.testing.assert_array_equal(output_1, output_2)


  def test_tuple_output(self):
    sequential = nn.Sequential([
        AttentionTuple(),
        AttentionTuple(),
    ])

    key1, key2 = random.split(random.PRNGKey(0), 2)
    query = random.uniform(key1, (3, 5))
    key_value = random.uniform(key1, (9, 5))
    params_1 = sequential.init(key2, query, key_value)
    outputs = sequential.apply(params_1, query, key_value)
    np.testing.assert_equal(len(outputs), 2)
    out_query, out_key_value = outputs
    np.testing.assert_equal(out_query.shape, (3, 5))
    np.testing.assert_equal(out_key_value.shape, (9, 5))

  def test_dict_output(self):
    sequential = nn.Sequential([
        AttentionDict(),
        AttentionDict(),
    ])

    key1, key2 = random.split(random.PRNGKey(0), 2)
    query = random.uniform(key1, (3, 5))
    key_value = random.uniform(key1, (9, 5))
    params_1 = sequential.init(key2, query, key_value)
    outputs = sequential.apply(params_1, query, key_value)
    np.testing.assert_equal(len(outputs), 2)
    out_query, out_key_value = outputs['query'], outputs['key_value']
    np.testing.assert_equal(out_query.shape, (3, 5))
    np.testing.assert_equal(out_key_value.shape, (9, 5))


if __name__ == '__main__':
  absltest.main()
