Commit

Add attention modules.
PiperOrigin-RevId: 701245045
jan-matthis authored and copybara-github committed Nov 29, 2024
1 parent 21c2943 commit c4dc8ff
Showing 2 changed files with 383 additions and 0 deletions.
287 changes: 287 additions & 0 deletions connectomics/jax/models/attention.py
@@ -0,0 +1,287 @@
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Attention layers for n-dimensional spatial inputs.
Currently, contains pixel-based, patch-based, axis-based attention mechanisms
with typical defaults and options used for vision transformers, such as position
biases and learnable positional embeddings.
"""

from collections.abc import Callable
import functools
from absl import logging
import einops
from flax import linen as nn
from flax.linen import initializers
import jax.numpy as jnp
from scenic.model_lib.layers import attention_layers as scenic_attn

Array = jnp.ndarray
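
# Overview sketch (illustrative, not exhaustive): the attention modules in this
# file differ only in how the n-dimensional input volume is mapped to a
# sequence. `rng` is a jax.random.PRNGKey and `x` a float array of the shape
# indicated below.
#
#   model = VoxelAttention(num_heads=2, qkv_features=4)
#   model = PatchAttention(num_heads=2, qkv_features=4, patch_sizes=(2, 2, 2))
#   variables = model.init(rng, x)   # x: [batch, *spatial, channels]
#   y = model.apply(variables, x)    # y has the same shape as x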


class PositionalEmbedding(nn.Module):
  """Adds learnable positional embeddings to [b, ..., l, d] inputs."""

  @nn.compact
  def __call__(self, x):
    *_, l, d = x.shape
    initializer = initializers.normal(stddev=d**-0.5)
    pos_embed = self.param('pos_embed', initializer, (l, d))
    return x + pos_embed


class Attention(nn.Module):
  """Multi-head attention customized from scenic.

  Attributes:
    num_heads: Number of attention heads. The input feature dimension (or
      `qkv_features`, if set) must be divisible by the number of heads.
    qkv_features: Dimension of the key, query, and value.
    dropout: Dropout rate.
    positional_embed: Whether to add positional embeddings.
    relative_attention_bias: Whether to use relative attention bias.
    seq_shard_fn: Function applied to the sequence after `to_seq`, e.g. for
      sharding; defaults to the identity.
  """

  num_heads: int = 32
  qkv_features: int | None = None
  dropout: float = 0.0
  positional_embed: bool = True
  relative_attention_bias: bool = True
  seq_shard_fn: Callable[[Array], Array] = lambda x: x

  def to_seq(self, x: Array) -> tuple[Array, tuple[int, ...]]:
    """Reshape input to sequence and return spatial shape."""
    return x, x.shape[1:-1]

  def from_seq(self, x: Array, spatial_shape: tuple[int, ...]) -> Array:
    """Reshape input to spatial and return output."""
    return x

  def get_attn_shape(self, spatial_shape: tuple[int, ...]) -> tuple[int, ...]:
    """Get n-dimensional shape that attention is applied to."""
    return spatial_shape

  @nn.compact
  def __call__(self, x: Array, train: bool = False) -> jnp.ndarray:
    """Applies multi-head dot-product attention to the input data.

    Projects the inputs into multi-headed query, key, and value vectors,
    applies dot-product attention, and projects the results to an output
    vector. Attention is applied as self-attention over the sequence
    produced by `to_seq`.

    Args:
      x: Inputs of shape `[bs, ..., features]`.
      train: Whether the model is in train mode.

    Returns:
      Output of shape `[bs, ..., features]`.
    """
    x, spatial_shape = self.to_seq(x)
    x = self.seq_shard_fn(x)
    out_features = x.shape[-1]
    qkv_features = self.qkv_features or x.shape[-1]
    assert (
        qkv_features % self.num_heads == 0
    ), 'Memory dimension must be divisible by number of heads.'
    head_dim = qkv_features // self.num_heads

    if self.positional_embed:
      x = PositionalEmbedding(name='pos_embed')(x)

    # Project the inputs to multi-headed q/k/v with dimensions
    # [..., l, num_heads, num_features_per_head].
    dense = functools.partial(
        nn.DenseGeneral,
        axis=-1,
        features=(self.num_heads, head_dim),
    )
    query, key, value = (
        dense(name='query')(x),
        dense(name='key')(x),
        dense(name='value')(x),
    )
    query = nn.LayerNorm(name='query_ln', use_bias=False)(query)
    key = nn.LayerNorm(name='key_ln', use_bias=False)(key)

    if self.relative_attention_bias:
      attention_bias = scenic_attn.RelativeAttentionBias(
          self.num_heads, self.get_attn_shape(spatial_shape)
      )()
    else:
      attention_bias = None

    if train and self.dropout > 0:
      dropout_rng = self.make_rng('dropout')
    else:
      dropout_rng = None

    x = scenic_attn.dot_product_attention(
        query,
        key,
        value,
        bias=attention_bias,
        dropout_rate=self.dropout,
        dropout_rng=dropout_rng,
        deterministic=not train,
    )

    # Back to the original input dimensions.
    out = nn.DenseGeneral(
        features=out_features, axis=(-2, -1), use_bias=True, name='out'
    )(x)

    return self.from_seq(out, spatial_shape)


class VoxelAttention(Attention):
  """Multi-head attention with voxels as sequence elements."""

  def to_seq(self, x: Array) -> tuple[Array, tuple[int, ...]]:
    b, *_, c = x.shape
    return x.reshape(b, -1, c), x.shape[1:-1]

  def from_seq(self, x: Array, spatial_shape: tuple[int, ...]) -> Array:
    b, *_, c = x.shape
    return x.reshape(b, *spatial_shape, c)


class EinopAttention(Attention):
  """Multi-head attention with einops patterns."""

  def _get_pattern(
      self, spatial_shape: tuple[int, ...]
  ) -> tuple[str, str, dict[str, int]]:
    raise NotImplementedError()

  def to_seq(self, x: Array) -> tuple[Array, tuple[int, ...]]:
    spatial_shape = x.shape[1:-1]
    in_pattern, out_pattern, axes = self._get_pattern(spatial_shape)
    pattern = in_pattern + ' -> ' + out_pattern
    logging.info(
        'Volume to sequence pattern %r for shape %r', pattern, spatial_shape
    )
    x = einops.rearrange(x, pattern, **axes)
    return x, spatial_shape

  def from_seq(self, x: Array, spatial_shape: tuple[int, ...]) -> Array:
    in_pattern, out_pattern, axes = self._get_pattern(spatial_shape)
    pattern = out_pattern + ' -> ' + in_pattern
    x = einops.rearrange(x, pattern, **axes)
    return x

  def get_attn_shape(self, spatial_shape: tuple[int, ...]) -> tuple[int, ...]:
    raise NotImplementedError()


class PatchAttention(EinopAttention):
  """Multi-head attention between contiguous patches of voxels.

  For 2D and with p{i} denoting patch sizes and d{i} patched dimensions,
  the pattern is 'b (d1 p1) (d2 p2) c -> b (d1 d2) (p1 p2 c)' so we have
  d1 x d2 sequence items with p1 p2 c feature dimension.
  """

  patch_sizes: tuple[int, ...] = (8,)

  def _get_pattern(
      self, spatial_shape: tuple[int, ...]
  ) -> tuple[str, str, dict[str, int]]:
    spatial_in, spatial_out, feature_out = '', '', ''
    axes = dict()
    if len(spatial_shape) != len(self.patch_sizes):
      raise ValueError('spatial_shape and patch_sizes must have same length')
    for i, (dim, patch_size) in enumerate(zip(spatial_shape, self.patch_sizes)):
      spatial_in += f'(d{i+1} p{i+1}) '
      spatial_out += f'd{i+1} '
      feature_out += f'p{i+1} '
      axes[f'p{i+1}'] = patch_size
      axes[f'd{i+1}'] = dim // patch_size
    in_pattern = f'b {spatial_in}c'
    out_pattern = f'b ({spatial_out}) ({feature_out}c)'
    return in_pattern, out_pattern, axes

  def get_attn_shape(self, spatial_shape: tuple[int, ...]) -> tuple[int, ...]:
    return tuple(d // p for d, p in zip(spatial_shape, self.patch_sizes))


class BlockAttention(EinopAttention):
  """Multi-head attention within contiguous patches of voxels.

  For 2D and with p{i} denoting patch sizes and d{i} patched dimensions,
  the pattern is 'b (d1 p1) (d2 p2) c -> b d1 d2 (p1 p2) c' so we have
  p1 x p2 sequence items with c feature dimension (d1, d2 treated as batch).
  """

  patch_sizes: tuple[int, ...] = (8,)

  def _get_pattern(
      self, spatial_shape: tuple[int, ...]
  ) -> tuple[str, str, dict[str, int]]:
    spatial_in, spatial_out, feature_out = '', '', ''
    axes = dict()
    if len(spatial_shape) != len(self.patch_sizes):
      raise ValueError('spatial_shape and patch_sizes must have same length')
    for i, (dim, patch_size) in enumerate(zip(spatial_shape, self.patch_sizes)):
      spatial_in += f'(d{i+1} p{i+1}) '
      spatial_out += f'd{i+1} '
      feature_out += f'p{i+1} '
      axes[f'p{i+1}'] = patch_size
      axes[f'd{i+1}'] = dim // patch_size
    in_pattern = f'b {spatial_in}c'
    out_pattern = f'b {spatial_out} ({feature_out}) c'
    logging.info(
        'block attention patterns %r %r and axes %r',
        in_pattern,
        out_pattern,
        axes,
    )
    return in_pattern, out_pattern, axes

  def get_attn_shape(self, spatial_shape: tuple[int, ...]) -> tuple[int, ...]:
    return self.patch_sizes


class GridAttention(EinopAttention):
  """Multi-head attention within a strided grid spanning the input volume.

  For 2D and with p{i} denoting patch sizes and d{i} patched dimensions,
  the pattern is 'b (p1 d1) (p2 d2) c -> b d1 d2 (p1 p2) c' so we have
  p1 x p2 sequence items with c feature dimension (d1, d2 treated as batch).
  """

  patch_sizes: tuple[int, ...] = (8,)

  def _get_pattern(
      self, spatial_shape: tuple[int, ...]
  ) -> tuple[str, str, dict[str, int]]:
    spatial_in, spatial_out, feature_out = '', '', ''
    axes = dict()
    if len(spatial_shape) != len(self.patch_sizes):
      raise ValueError('spatial_shape and patch_sizes must have same length')
    for i, (dim, patch_size) in enumerate(zip(spatial_shape, self.patch_sizes)):
      spatial_in += f'(p{i+1} d{i+1}) '
      spatial_out += f'd{i+1} '
      feature_out += f'p{i+1} '
      axes[f'p{i+1}'] = patch_size
      axes[f'd{i+1}'] = dim // patch_size
    in_pattern = f'b {spatial_in}c'
    out_pattern = f'b {spatial_out} ({feature_out}) c'
    return in_pattern, out_pattern, axes

  def get_attn_shape(self, spatial_shape: tuple[int, ...]) -> tuple[int, ...]:
    return self.patch_sizes
96 changes: 96 additions & 0 deletions connectomics/jax/models/attention_test.py
@@ -0,0 +1,96 @@
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test vision attention modules."""

from absl.testing import absltest
from connectomics.jax.models import attention as attn
import einops
import jax


class AttentionTest(absltest.TestCase):

  def setUp(self):
    super().setUp()
    self.batch_size = 2
    self.features = 3
    self.rng = jax.random.PRNGKey(0)

  def test_base_attention(self):
    seq = jax.random.normal(self.rng, (self.batch_size, 128, self.features))
    model = attn.Attention(num_heads=2, qkv_features=64)
    variables = model.init(self.rng, seq)
    seq_out = model.apply(variables, seq, train=True)
    self.assertSequenceEqual(seq_out.shape, seq.shape)

  def test_attention_dropout(self):
    seq = jax.random.normal(self.rng, (self.batch_size, 128, self.features))
    model = attn.Attention(num_heads=2, qkv_features=64, dropout=0.2)
    variables = model.init(self.rng, seq)
    seq_out_determined = model.apply(variables, seq, train=False).mean()
    seq_out_dropout = model.apply(
        variables, seq, train=True, rngs={"dropout": self.rng}
    ).mean()
    self.assertNotEqual(seq_out_dropout, seq_out_determined)

  def test_voxel_attention(self):
    x = jax.random.normal(self.rng, (self.batch_size, 4, 2, self.features))
    model = attn.VoxelAttention(num_heads=2, qkv_features=4)
    variables = model.init(self.rng, x)
    x_out = model.apply(variables, x)
    self.assertSequenceEqual(x_out.shape, x.shape)

  def test_patch_attention(self):
    x = jax.random.normal(self.rng, (self.batch_size, 8, 4, 2, self.features))
    ps = (2, 2, 2)
    model = attn.PatchAttention(num_heads=2, qkv_features=4, patch_sizes=ps)
    variables = model.init(self.rng, x)
    x_out = model.apply(variables, x)
    self.assertSequenceEqual(x_out.shape, x.shape)

  def test_grid_attention(self):
    x = jax.random.normal(self.rng, (self.batch_size, 8, 4, 2, self.features))
    ps = (2, 2, 2)
    model = attn.GridAttention(num_heads=2, qkv_features=4, patch_sizes=ps)
    variables = model.init(self.rng, x)
    x_out = model.apply(variables, x)
    self.assertSequenceEqual(x_out.shape, x.shape)

  def test_block_attention(self):
    x = jax.random.normal(self.rng, (self.batch_size, 8, 4, 2, self.features))
    ps = (2, 2, 2)
    model = attn.BlockAttention(num_heads=2, qkv_features=4, patch_sizes=ps)
    variables = model.init(self.rng, x)
    x_out = model.apply(variables, x)
    self.assertSequenceEqual(x_out.shape, x.shape)

  def test_grid_non_uniform_patches(self):
    x = jax.random.normal(self.rng, (self.batch_size, 1, 8, 2, self.features))
    ps = (1, 4, 2)
    model = attn.GridAttention(num_heads=2, qkv_features=4, patch_sizes=ps)
    variables = model.init(self.rng, x)
    x_out = model.apply(variables, x)
    self.assertSequenceEqual(x_out.shape, x.shape)

  def test_non_divisible_patches_fail(self):
    x = jax.random.normal(self.rng, (self.batch_size, 1, 8, 2, self.features))
    ps = (1, 3, 2)
    model = attn.GridAttention(num_heads=2, qkv_features=4, patch_sizes=ps)
    with self.assertRaises(einops.EinopsError):
      model.init(self.rng, x)


if __name__ == "__main__":
  absltest.main()
