Showing 2 changed files with 383 additions and 0 deletions.
@@ -0,0 +1,287 @@
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Attention layers for n-dimensional spatial inputs. | ||
Currently, contains pixel-based, patch-based, axis-based attention mechanisms | ||
with typical defaults and options used for vision transformers, such as position | ||
biases and learnable positional embeddings. | ||
""" | ||
|
||
from collections.abc import Callable
import functools
from absl import logging
import einops
from flax import linen as nn
from flax.linen import initializers
import jax.numpy as jnp
from scenic.model_lib.layers import attention_layers as scenic_attn

Array = jnp.ndarray


class PositionalEmbedding(nn.Module):
  """Adds learnable positional embeddings to [b, ..., l, d] inputs."""

  @nn.compact
  def __call__(self, x):
    *_, l, d = x.shape
    initializer = initializers.normal(stddev=d**-0.5)
    pos_embed = self.param('pos_embed', initializer, (l, d))
    return x + pos_embed
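
# Usage sketch (illustrative, assumes `import jax` alongside the imports
# above): for a [b, l, d] input the learnable 'pos_embed' parameter has shape
# (l, d) and broadcasts over all leading dimensions.
#
#   x = jnp.zeros((2, 128, 64))
#   module = PositionalEmbedding()
#   variables = module.init(jax.random.PRNGKey(0), x)
#   y = module.apply(variables, x)  # shape (2, 128, 64)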


class Attention(nn.Module):
  """Multi-head attention customized from scenic.

  Attributes:
    num_heads: Number of attention heads. The attention feature dimension
      (`qkv_features`, defaulting to the input feature dimension) must be
      divisible by the number of heads.
    qkv_features: Dimension of the key, query, and value.
    dropout: Dropout rate.
    positional_embed: Whether to add learnable positional embeddings.
    relative_attention_bias: Whether to use a relative attention bias.
    seq_shard_fn: Function applied to the sequence after `to_seq` (identity
      by default).
  """

  num_heads: int = 32
  qkv_features: int | None = None
  dropout: float = 0.0
  positional_embed: bool = True
  relative_attention_bias: bool = True
  seq_shard_fn: Callable[[Array], Array] = lambda x: x

  def to_seq(self, x: Array) -> tuple[Array, tuple[int, ...]]:
    """Reshapes the input to a sequence and returns the spatial shape."""
    return x, x.shape[1:-1]

  def from_seq(self, x: Array, spatial_shape: tuple[int, ...]) -> Array:
    """Reshapes the sequence back to its spatial form."""
    return x

  def get_attn_shape(self, spatial_shape: tuple[int, ...]) -> tuple[int, ...]:
    """Gets the n-dimensional shape that attention is applied to."""
    return spatial_shape

  @nn.compact
  def __call__(self, x: Array, train: bool = False) -> jnp.ndarray:
    """Applies multi-head dot-product self-attention to the input data.

    Projects the inputs into multi-headed query, key, and value vectors,
    applies dot-product attention, and projects the results back to an
    output vector. Subclasses control how spatial inputs are mapped to and
    from the attention sequence via `to_seq` and `from_seq`.

    Args:
      x: Inputs of shape `[bs, ..., features]`.
      train: Whether the model is in train mode.

    Returns:
      Output of shape `[bs, ..., features]`.
    """
    x, spatial_shape = self.to_seq(x)
    x = self.seq_shard_fn(x)
    out_features = x.shape[-1]
    qkv_features = self.qkv_features or x.shape[-1]
    assert (
        qkv_features % self.num_heads == 0
    ), 'Memory dimension must be divisible by number of heads.'
    head_dim = qkv_features // self.num_heads

    if self.positional_embed:
      x = PositionalEmbedding(name='pos_embed')(x)

    # Project the inputs to multi-headed q/k/v with dimensions
    # [..., l, num_heads, num_features_per_head].
    dense = functools.partial(
        nn.DenseGeneral,
        axis=-1,
        features=(self.num_heads, head_dim),
    )
    query, key, value = (
        dense(name='query')(x),
        dense(name='key')(x),
        dense(name='value')(x),
    )
    query = nn.LayerNorm(name='query_ln', use_bias=False)(query)
    key = nn.LayerNorm(name='key_ln', use_bias=False)(key)

    if self.relative_attention_bias:
      attention_bias = scenic_attn.RelativeAttentionBias(
          self.num_heads, self.get_attn_shape(spatial_shape)
      )()
    else:
      attention_bias = None

    if train and self.dropout > 0:
      dropout_rng = self.make_rng('dropout')
    else:
      dropout_rng = None

    x = scenic_attn.dot_product_attention(
        query,
        key,
        value,
        bias=attention_bias,
        dropout_rate=self.dropout,
        dropout_rng=dropout_rng,
        deterministic=not train,
    )

    # Back to the original input dimensions.
    out = nn.DenseGeneral(
        features=out_features, axis=(-2, -1), use_bias=True, name='out'
    )(x)

    return self.from_seq(out, spatial_shape)
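
# Self-attention usage sketch (illustrative values; assumes `import jax` and
# is similar to the accompanying unit tests). The base class treats the input
# as an already flattened [b, l, d] sequence:
#
#   rng = jax.random.PRNGKey(0)
#   x = jax.random.normal(rng, (2, 128, 32))
#   model = Attention(num_heads=4, qkv_features=64)
#   variables = model.init(rng, x)
#   y = model.apply(variables, x)  # shape (2, 128, 32)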


class VoxelAttention(Attention):
  """Multi-head attention with voxels as sequence elements."""

  def to_seq(self, x: Array) -> tuple[Array, tuple[int, ...]]:
    b, *_, c = x.shape
    return x.reshape(b, -1, c), x.shape[1:-1]

  def from_seq(self, x: Array, spatial_shape: tuple[int, ...]) -> Array:
    b, *_, c = x.shape
    return x.reshape(b, *spatial_shape, c)
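
# Shape sketch (illustrative): VoxelAttention flattens all spatial axes into a
# single sequence axis, so a [b, x, y, z, c] volume is attended over as
# x * y * z sequence elements with c features each. Assuming `import jax`:
#
#   rng = jax.random.PRNGKey(0)
#   x = jax.random.normal(rng, (2, 4, 4, 2, 3))   # to_seq -> (2, 32, 3)
#   model = VoxelAttention(num_heads=2, qkv_features=4)
#   y = model.apply(model.init(rng, x), x)        # shape (2, 4, 4, 2, 3)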


class EinopAttention(Attention):
  """Multi-head attention with einops patterns."""

  def _get_pattern(
      self, spatial_shape: tuple[int, ...]
  ) -> tuple[str, str, dict[str, int]]:
    raise NotImplementedError()

  def to_seq(self, x: Array) -> tuple[Array, tuple[int, ...]]:
    spatial_shape = x.shape[1:-1]
    in_pattern, out_pattern, axes = self._get_pattern(spatial_shape)
    pattern = in_pattern + ' -> ' + out_pattern
    logging.info(
        'Volume to sequence pattern %r for shape %r', pattern, spatial_shape
    )
    x = einops.rearrange(x, pattern, **axes)
    return x, spatial_shape

  def from_seq(self, x: Array, spatial_shape: tuple[int, ...]) -> Array:
    in_pattern, out_pattern, axes = self._get_pattern(spatial_shape)
    pattern = out_pattern + ' -> ' + in_pattern
    x = einops.rearrange(x, pattern, **axes)
    return x

  def get_attn_shape(self, spatial_shape: tuple[int, ...]) -> tuple[int, ...]:
    raise NotImplementedError()


class PatchAttention(EinopAttention):
  """Multi-head attention between contiguous patches of voxels.

  For 2D, with p{i} denoting patch sizes and d{i} patched dimensions, the
  pattern is 'b (d1 p1) (d2 p2) c -> b (d1 d2) (p1 p2 c)', so we get
  d1 x d2 sequence items, each with a p1 * p2 * c feature dimension.
  """

  patch_sizes: tuple[int, ...] = (8,)

  def _get_pattern(
      self, spatial_shape: tuple[int, ...]
  ) -> tuple[str, str, dict[str, int]]:
    spatial_in, spatial_out, feature_out = '', '', ''
    axes = dict()
    if len(spatial_shape) != len(self.patch_sizes):
      raise ValueError('spatial_shape and patch_sizes must have same length')
    for i, (dim, patch_size) in enumerate(zip(spatial_shape, self.patch_sizes)):
      spatial_in += f'(d{i+1} p{i+1}) '
      spatial_out += f'd{i+1} '
      feature_out += f'p{i+1} '
      axes[f'p{i+1}'] = patch_size
      axes[f'd{i+1}'] = dim // patch_size
    in_pattern = f'b {spatial_in}c'
    out_pattern = f'b ({spatial_out}) ({feature_out}c)'
    return in_pattern, out_pattern, axes

  def get_attn_shape(self, spatial_shape: tuple[int, ...]) -> tuple[int, ...]:
    return tuple(d // p for d, p in zip(spatial_shape, self.patch_sizes))
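
# Pattern sketch (illustrative values): with spatial_shape=(16, 16) and
# patch_sizes=(4, 4), _get_pattern yields, up to whitespace,
#   in_pattern  = 'b (d1 p1) (d2 p2) c'
#   out_pattern = 'b (d1 d2) (p1 p2 c)'
#   axes        = {'p1': 4, 'd1': 4, 'p2': 4, 'd2': 4}
# so a [2, 16, 16, 3] input is rearranged to a [2, 16, 48] sequence: the
# 4 x 4 patches become 16 sequence items with 4 * 4 * 3 = 48 features each,
# and get_attn_shape returns (4, 4).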


class BlockAttention(EinopAttention):
  """Multi-head attention within contiguous patches of voxels.

  For 2D, with p{i} denoting patch sizes and d{i} patched dimensions, the
  pattern is 'b (d1 p1) (d2 p2) c -> b d1 d2 (p1 p2) c', so we get
  p1 x p2 sequence items with a c feature dimension (d1, d2 are treated as
  batch dimensions).
  """

  patch_sizes: tuple[int, ...] = (8,)

  def _get_pattern(
      self, spatial_shape: tuple[int, ...]
  ) -> tuple[str, str, dict[str, int]]:
    spatial_in, spatial_out, feature_out = '', '', ''
    axes = dict()
    if len(spatial_shape) != len(self.patch_sizes):
      raise ValueError('spatial_shape and patch_sizes must have same length')
    for i, (dim, patch_size) in enumerate(zip(spatial_shape, self.patch_sizes)):
      spatial_in += f'(d{i+1} p{i+1}) '
      spatial_out += f'd{i+1} '
      feature_out += f'p{i+1} '
      axes[f'p{i+1}'] = patch_size
      axes[f'd{i+1}'] = dim // patch_size
    in_pattern = f'b {spatial_in}c'
    out_pattern = f'b {spatial_out} ({feature_out}) c'
    logging.info(
        'block attention patterns %r %r and axes %r',
        in_pattern,
        out_pattern,
        axes,
    )
    return in_pattern, out_pattern, axes

  def get_attn_shape(self, spatial_shape: tuple[int, ...]) -> tuple[int, ...]:
    return self.patch_sizes
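
# Contrast with PatchAttention (illustrative): for the same [2, 16, 16, 3]
# input with patch_sizes=(4, 4), the rearranged tensor is [2, 4, 4, 16, 3];
# attention runs over the 16 voxels inside each contiguous 4 x 4 block while
# d1 and d2 act as extra batch dimensions, and get_attn_shape returns (4, 4),
# i.e. the patch_sizes.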


class GridAttention(EinopAttention):
  """Multi-head attention within a strided grid spaced across the input volume.

  For 2D, with p{i} denoting patch sizes and d{i} patched dimensions, the
  pattern is 'b (p1 d1) (p2 d2) c -> b d1 d2 (p1 p2) c', so we get
  p1 x p2 sequence items with a c feature dimension (d1, d2 are treated as
  batch dimensions).
  """

  patch_sizes: tuple[int, ...] = (8,)

  def _get_pattern(
      self, spatial_shape: tuple[int, ...]
  ) -> tuple[str, str, dict[str, int]]:
    spatial_in, spatial_out, feature_out = '', '', ''
    axes = dict()
    if len(spatial_shape) != len(self.patch_sizes):
      raise ValueError('spatial_shape and patch_sizes must have same length')
    for i, (dim, patch_size) in enumerate(zip(spatial_shape, self.patch_sizes)):
      spatial_in += f'(p{i+1} d{i+1}) '
      spatial_out += f'd{i+1} '
      feature_out += f'p{i+1} '
      axes[f'p{i+1}'] = patch_size
      axes[f'd{i+1}'] = dim // patch_size
    in_pattern = f'b {spatial_in}c'
    out_pattern = f'b {spatial_out} ({feature_out}) c'
    return in_pattern, out_pattern, axes

  def get_attn_shape(self, spatial_shape: tuple[int, ...]) -> tuple[int, ...]:
    return self.patch_sizes
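
# Contrast with BlockAttention (illustrative): the grouping is '(p1 d1)'
# instead of '(d1 p1)', so for a [2, 16, 16, 3] input with patch_sizes=(4, 4)
# the 16 attended positions in each group are strided dim // patch_size = 4
# apart across the volume rather than contiguous; the rearranged shape is
# again [2, 4, 4, 16, 3] and get_attn_shape returns (4, 4).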
@@ -0,0 +1,96 @@
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test vision attention modules.""" | ||
|
||
from absl.testing import absltest | ||
from connectomics.jax.models import attention as attn | ||
import einops | ||
import jax | ||
|
||
|
||
class AttentionTest(absltest.TestCase):

  def setUp(self):
    super().setUp()
    self.batch_size = 2
    self.features = 3
    self.rng = jax.random.PRNGKey(0)

  def test_base_attention(self):
    seq = jax.random.normal(self.rng, (self.batch_size, 128, self.features))
    model = attn.Attention(num_heads=2, qkv_features=64)
    variables = model.init(self.rng, seq)
    seq_out = model.apply(variables, seq, train=True)
    self.assertSequenceEqual(seq_out.shape, seq.shape)

  def test_attention_dropout(self):
    seq = jax.random.normal(self.rng, (self.batch_size, 128, self.features))
    model = attn.Attention(num_heads=2, qkv_features=64, dropout=0.2)
    variables = model.init(self.rng, seq)
    seq_out_determined = model.apply(variables, seq, train=False).mean()
    seq_out_dropout = model.apply(
        variables, seq, train=True, rngs={"dropout": self.rng}
    ).mean()
    self.assertNotEqual(seq_out_dropout, seq_out_determined)

  def test_voxel_attention(self):
    x = jax.random.normal(self.rng, (self.batch_size, 4, 2, self.features))
    model = attn.VoxelAttention(num_heads=2, qkv_features=4)
    variables = model.init(self.rng, x)
    x_out = model.apply(variables, x)
    self.assertSequenceEqual(x_out.shape, x.shape)

  def test_patch_attention(self):
    x = jax.random.normal(self.rng, (self.batch_size, 8, 4, 2, self.features))
    ps = (2, 2, 2)
    model = attn.PatchAttention(num_heads=2, qkv_features=4, patch_sizes=ps)
    variables = model.init(self.rng, x)
    x_out = model.apply(variables, x)
    self.assertSequenceEqual(x_out.shape, x.shape)

  def test_grid_attention(self):
    x = jax.random.normal(self.rng, (self.batch_size, 8, 4, 2, self.features))
    ps = (2, 2, 2)
    model = attn.GridAttention(num_heads=2, qkv_features=4, patch_sizes=ps)
    variables = model.init(self.rng, x)
    x_out = model.apply(variables, x)
    self.assertSequenceEqual(x_out.shape, x.shape)

  def test_block_attention(self):
    x = jax.random.normal(self.rng, (self.batch_size, 8, 4, 2, self.features))
    ps = (2, 2, 2)
    model = attn.BlockAttention(num_heads=2, qkv_features=4, patch_sizes=ps)
    variables = model.init(self.rng, x)
    x_out = model.apply(variables, x)
    self.assertSequenceEqual(x_out.shape, x.shape)

  def test_grid_non_uniform_patches(self):
    x = jax.random.normal(self.rng, (self.batch_size, 1, 8, 2, self.features))
    ps = (1, 4, 2)
    model = attn.GridAttention(num_heads=2, qkv_features=4, patch_sizes=ps)
    variables = model.init(self.rng, x)
    x_out = model.apply(variables, x)
    self.assertSequenceEqual(x_out.shape, x.shape)

  def test_non_divisible_patches_fail(self):
    x = jax.random.normal(self.rng, (self.batch_size, 1, 8, 2, self.features))
    ps = (1, 3, 2)
    model = attn.GridAttention(num_heads=2, qkv_features=4, patch_sizes=ps)
    with self.assertRaises(einops.EinopsError):
      model.init(self.rng, x)


if __name__ == "__main__":
  absltest.main()