From 7e3916468e50d49dd0ec648f62073ec0c0ce2b71 Mon Sep 17 00:00:00 2001
From: Mark Graham
Date: Thu, 7 Dec 2023 11:23:48 +0000
Subject: [PATCH 1/6] Commit network and tests

Signed-off-by: Mark Graham
---
 monai/networks/nets/__init__.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py
index db3c77c717..08384b4d52 100644
--- a/monai/networks/nets/__init__.py
+++ b/monai/networks/nets/__init__.py
@@ -106,6 +106,7 @@
 from .swin_unetr import PatchMerging, PatchMergingV2, SwinUNETR
 from .torchvision_fc import TorchVisionFCModel
 from .transchex import BertAttention, BertMixedLayer, BertOutput, BertPreTrainedModel, MultiModal, Pooler, Transchex
+from .transformer import DecoderOnlyTransformer
 from .unet import UNet, Unet
 from .unetr import UNETR
 from .varautoencoder import VarAutoEncoder

From 7ac5a1dffa8e46ead2fd045e27008e1c3584609e Mon Sep 17 00:00:00 2001
From: Mark Graham
Date: Thu, 7 Dec 2023 11:24:24 +0000
Subject: [PATCH 2/6] Commit network and tests

Signed-off-by: Mark Graham
---
 monai/networks/nets/transformer.py | 319 +++++++++++++++++++++++++++++
 tests/test_transformer.py          |  45 ++++
 2 files changed, 364 insertions(+)
 create mode 100644 monai/networks/nets/transformer.py
 create mode 100644 tests/test_transformer.py

diff --git a/monai/networks/nets/transformer.py b/monai/networks/nets/transformer.py
new file mode 100644
index 0000000000..56f341e898
--- /dev/null
+++ b/monai/networks/nets/transformer.py
@@ -0,0 +1,319 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import importlib.util
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from monai.networks.blocks.mlp import MLPBlock
+
+if importlib.util.find_spec("xformers") is not None:
+    import xformers.ops as xops
+
+    has_xformers = True
+else:
+    has_xformers = False
+__all__ = ["DecoderOnlyTransformer"]
+
+
+class _SABlock(nn.Module):
+    """
+    NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make
+    use of this block as support is not guaranteed. For more information see:
+    https://github.com/Project-MONAI/MONAI/issues/7227
+
+    A self-attention block, based on: "Dosovitskiy et al.,
+    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
+
+    Args:
+        hidden_size: dimension of hidden layer.
+        num_heads: number of attention heads.
+        dropout_rate: dropout ratio. Defaults to no dropout.
+        qkv_bias: bias term for the qkv linear layer.
+        causal: whether to use causal attention.
+        sequence_length: if causal is True, it is necessary to specify the sequence length.
+        with_cross_attention: Whether to use cross attention for conditioning.
+        use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
+ """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + dropout_rate: float = 0.0, + qkv_bias: bool = False, + causal: bool = False, + sequence_length: int | None = None, + with_cross_attention: bool = False, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + self.scale = 1.0 / math.sqrt(self.head_dim) + self.causal = causal + self.sequence_length = sequence_length + self.with_cross_attention = with_cross_attention + self.use_flash_attention = use_flash_attention + + if not (0 <= dropout_rate <= 1): + raise ValueError("dropout_rate should be between 0 and 1.") + self.dropout_rate = dropout_rate + + if hidden_size % num_heads != 0: + raise ValueError("hidden size should be divisible by num_heads.") + + if causal and sequence_length is None: + raise ValueError("sequence_length is necessary for causal attention.") + + if use_flash_attention and not has_xformers: + raise ValueError("use_flash_attention is True but xformers is not installed.") + + # key, query, value projections + self.to_q = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) + self.to_k = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) + self.to_v = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) + + # regularization + self.drop_weights = nn.Dropout(dropout_rate) + self.drop_output = nn.Dropout(dropout_rate) + + # output projection + self.out_proj = nn.Linear(hidden_size, hidden_size) + + if causal and sequence_length is not None: + # causal mask to ensure that attention is only applied to the left in the input sequence + self.register_buffer( + "causal_mask", + torch.tril(torch.ones(sequence_length, sequence_length)).view(1, 1, sequence_length, sequence_length), + ) + self.causal_mask: torch.Tensor + + def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: + b, t, c = x.size() # batch size, sequence length, embedding dimensionality (hidden_size) + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + query = self.to_q(x) + + kv = context if context is not None else x + _, kv_t, _ = kv.size() + key = self.to_k(kv) + value = self.to_v(kv) + + query = query.view(b, t, self.num_heads, c // self.num_heads) # (b, t, nh, hs) + key = key.view(b, kv_t, self.num_heads, c // self.num_heads) # (b, kv_t, nh, hs) + value = value.view(b, kv_t, self.num_heads, c // self.num_heads) # (b, kv_t, nh, hs) + y: torch.Tensor + if self.use_flash_attention: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + y = xops.memory_efficient_attention( + query=query, + key=key, + value=value, + scale=self.scale, + p=self.dropout_rate, + attn_bias=xops.LowerTriangularMask() if self.causal else None, + ) + + else: + query = query.transpose(1, 2) # (b, nh, t, hs) + key = key.transpose(1, 2) # (b, nh, kv_t, hs) + value = value.transpose(1, 2) # (b, nh, kv_t, hs) + + # manual implementation of attention + query = query * self.scale + attention_scores = query @ key.transpose(-2, -1) + + if self.causal: + attention_scores = attention_scores.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf")) + + attention_probs = F.softmax(attention_scores, dim=-1) + attention_probs = self.drop_weights(attention_probs) + y = attention_probs @ value # (b, nh, t, kv_t) x (b, nh, kv_t, hs) -> (b, nh, t, hs) + + y = y.transpose(1, 2) # (b, nh, t, hs) -> (b, t, nh, hs) + + y = y.contiguous().view(b, t, c) # 
re-assemble all head outputs side by side + + y = self.out_proj(y) + y = self.drop_output(y) + return y + + +class _TransformerBlock(nn.Module): + """ + NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make + use of this block as support is not guaranteed. For more information see: + https://github.com/Project-MONAI/MONAI/issues/7227 + + A transformer block, based on: "Dosovitskiy et al., + An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " + + Args: + hidden_size: dimension of hidden layer. + mlp_dim: dimension of feedforward layer. + num_heads: number of attention heads. + dropout_rate: faction of the input units to drop. + qkv_bias: apply bias term for the qkv linear layer + causal: whether to use causal attention. + sequence_length: if causal is True, it is necessary to specify the sequence length. + with_cross_attention: Whether to use cross attention for conditioning. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + """ + + def __init__( + self, + hidden_size: int, + mlp_dim: int, + num_heads: int, + dropout_rate: float = 0.0, + qkv_bias: bool = False, + causal: bool = False, + sequence_length: int | None = None, + with_cross_attention: bool = False, + use_flash_attention: bool = False, + ) -> None: + self.with_cross_attention = with_cross_attention + super().__init__() + + if not (0 <= dropout_rate <= 1): + raise ValueError("dropout_rate should be between 0 and 1.") + + if hidden_size % num_heads != 0: + raise ValueError("hidden_size should be divisible by num_heads.") + + self.norm1 = nn.LayerNorm(hidden_size) + self.attn = _SABlock( + hidden_size=hidden_size, + num_heads=num_heads, + dropout_rate=dropout_rate, + qkv_bias=qkv_bias, + causal=causal, + sequence_length=sequence_length, + use_flash_attention=use_flash_attention, + ) + + if self.with_cross_attention: + self.norm2 = nn.LayerNorm(hidden_size) + self.cross_attn = _SABlock( + hidden_size=hidden_size, + num_heads=num_heads, + dropout_rate=dropout_rate, + qkv_bias=qkv_bias, + with_cross_attention=with_cross_attention, + causal=False, + use_flash_attention=use_flash_attention, + ) + self.norm3 = nn.LayerNorm(hidden_size) + self.mlp = MLPBlock(hidden_size, mlp_dim, dropout_rate) + + def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: + x = x + self.attn(self.norm1(x)) + if self.with_cross_attention: + x = x + self.cross_attn(self.norm2(x), context=context) + x = x + self.mlp(self.norm3(x)) + return x + + +class AbsolutePositionalEmbedding(nn.Module): + """Absolute positional embedding. + + Args: + max_seq_len: Maximum sequence length. + embedding_dim: Dimensionality of the embedding. + """ + + def __init__(self, max_seq_len: int, embedding_dim: int) -> None: + super().__init__() + self.max_seq_len = max_seq_len + self.embedding_dim = embedding_dim + self.embedding = nn.Embedding(max_seq_len, embedding_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + batch_size, seq_len = x.size() + positions = torch.arange(seq_len, device=x.device).repeat(batch_size, 1) + embedding: torch.Tensor = self.embedding(positions) + return embedding + + +class DecoderOnlyTransformer(nn.Module): + """Decoder-only (Autoregressive) Transformer model. + + Args: + num_tokens: Number of tokens in the vocabulary. + max_seq_len: Maximum sequence length. + attn_layers_dim: Dimensionality of the attention layers. + attn_layers_depth: Number of attention layers. 
+ attn_layers_heads: Number of attention heads. + with_cross_attention: Whether to use cross attention for conditioning. + embedding_dropout_rate: Dropout rate for the embedding. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + """ + + def __init__( + self, + num_tokens: int, + max_seq_len: int, + attn_layers_dim: int, + attn_layers_depth: int, + attn_layers_heads: int, + with_cross_attention: bool = False, + embedding_dropout_rate: float = 0.0, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.num_tokens = num_tokens + self.max_seq_len = max_seq_len + self.attn_layers_dim = attn_layers_dim + self.attn_layers_depth = attn_layers_depth + self.attn_layers_heads = attn_layers_heads + self.with_cross_attention = with_cross_attention + + self.token_embeddings = nn.Embedding(num_tokens, attn_layers_dim) + self.position_embeddings = AbsolutePositionalEmbedding(max_seq_len=max_seq_len, embedding_dim=attn_layers_dim) + self.embedding_dropout = nn.Dropout(embedding_dropout_rate) + + self.blocks = nn.ModuleList( + [ + _TransformerBlock( + hidden_size=attn_layers_dim, + mlp_dim=attn_layers_dim * 4, + num_heads=attn_layers_heads, + dropout_rate=0.0, + qkv_bias=False, + causal=True, + sequence_length=max_seq_len, + with_cross_attention=with_cross_attention, + use_flash_attention=use_flash_attention, + ) + for _ in range(attn_layers_depth) + ] + ) + + self.to_logits = nn.Linear(attn_layers_dim, num_tokens) + + def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: + tok_emb = self.token_embeddings(x) + pos_emb = self.position_embeddings(x) + x = self.embedding_dropout(tok_emb + pos_emb) + + for block in self.blocks: + x = block(x, context=context) + logits: torch.Tensor = self.to_logits(x) + return logits diff --git a/tests/test_transformer.py b/tests/test_transformer.py new file mode 100644 index 0000000000..3dd2d6621a --- /dev/null +++ b/tests/test_transformer.py @@ -0,0 +1,45 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import unittest + +import torch + +from monai.networks import eval_mode +from monai.networks.nets import DecoderOnlyTransformer + + +class TestDecoderOnlyTransformer(unittest.TestCase): + def test_unconditioned_models(self): + net = DecoderOnlyTransformer( + num_tokens=10, max_seq_len=16, attn_layers_dim=8, attn_layers_depth=2, attn_layers_heads=2 + ) + with eval_mode(net): + net.forward(torch.randint(0, 10, (1, 16))) + + def test_conditioned_models(self): + net = DecoderOnlyTransformer( + num_tokens=10, + max_seq_len=16, + attn_layers_dim=8, + attn_layers_depth=2, + attn_layers_heads=2, + with_cross_attention=True, + embedding_dropout_rate=0, + ) + with eval_mode(net): + net.forward(torch.randint(0, 10, (1, 16)), context=torch.randn(1, 4, 8)) + + +if __name__ == "__main__": + unittest.main() From b3d615bb61261f5eb5c94ab285cfd7e194fd519a Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Thu, 7 Dec 2023 11:45:41 +0000 Subject: [PATCH 3/6] Updates docs, corrects xformer optional import Signed-off-by: Mark Graham --- docs/source/networks.rst | 5 +++++ monai/networks/nets/transformer.py | 9 ++------- tests/test_transformer.py | 12 ++++++++++++ 3 files changed, 19 insertions(+), 7 deletions(-) diff --git a/docs/source/networks.rst b/docs/source/networks.rst index d8be26264b..06f60fe8af 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -613,6 +613,11 @@ Nets .. autoclass:: VarAutoEncoder :members: +`DecoderOnlyTransformer` +~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: DecoderOnlyTransformer + :members: + `ViT` ~~~~~ .. autoclass:: ViT diff --git a/monai/networks/nets/transformer.py b/monai/networks/nets/transformer.py index 56f341e898..b742c12205 100644 --- a/monai/networks/nets/transformer.py +++ b/monai/networks/nets/transformer.py @@ -11,7 +11,6 @@ from __future__ import annotations -import importlib.util import math import torch @@ -19,13 +18,9 @@ import torch.nn.functional as F from monai.networks.blocks.mlp import MLPBlock +from monai.utils import optional_import -if importlib.util.find_spec("xformers") is not None: - import xformers.ops as xops - - has_xformers = True -else: - has_xformers = False +xops, has_xformers = optional_import("xformers.ops") __all__ = ["DecoderOnlyTransformer"] diff --git a/tests/test_transformer.py b/tests/test_transformer.py index 3dd2d6621a..d17e04744f 100644 --- a/tests/test_transformer.py +++ b/tests/test_transformer.py @@ -27,6 +27,18 @@ def test_unconditioned_models(self): with eval_mode(net): net.forward(torch.randint(0, 10, (1, 16))) + def test_models_with_flash_attention(self): + net = DecoderOnlyTransformer( + num_tokens=10, + max_seq_len=16, + attn_layers_dim=8, + attn_layers_depth=2, + attn_layers_heads=2, + use_flash_attention=True, + ).to(torch.device("cuda:0")) + with eval_mode(net): + net.forward(torch.randint(0, 10, (1, 16)).to(torch.device("cuda:0"))) + def test_conditioned_models(self): net = DecoderOnlyTransformer( num_tokens=10, From 6220db6234b1c8ebf0ee4dedcbe4e53562cf3939 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Thu, 7 Dec 2023 11:47:10 +0000 Subject: [PATCH 4/6] Removes memory_efficient_attention test as it depends on xformers Signed-off-by: Mark Graham --- tests/test_transformer.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/tests/test_transformer.py b/tests/test_transformer.py index d17e04744f..3dd2d6621a 100644 --- a/tests/test_transformer.py +++ b/tests/test_transformer.py @@ -27,18 +27,6 @@ def test_unconditioned_models(self): with eval_mode(net): 
net.forward(torch.randint(0, 10, (1, 16))) - def test_models_with_flash_attention(self): - net = DecoderOnlyTransformer( - num_tokens=10, - max_seq_len=16, - attn_layers_dim=8, - attn_layers_depth=2, - attn_layers_heads=2, - use_flash_attention=True, - ).to(torch.device("cuda:0")) - with eval_mode(net): - net.forward(torch.randint(0, 10, (1, 16)).to(torch.device("cuda:0"))) - def test_conditioned_models(self): net = DecoderOnlyTransformer( num_tokens=10, From f593be177184bd5fdde39b284c31465bb5ce42c3 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Fri, 8 Dec 2023 15:12:19 +0000 Subject: [PATCH 5/6] Adds more test cases Signed-off-by: Mark Graham --- tests/test_transformer.py | 48 ++++++++++++++++++++++++++++++--------- 1 file changed, 37 insertions(+), 11 deletions(-) diff --git a/tests/test_transformer.py b/tests/test_transformer.py index 3dd2d6621a..29a505c9c5 100644 --- a/tests/test_transformer.py +++ b/tests/test_transformer.py @@ -13,32 +13,58 @@ import unittest +import numpy as np import torch +from parameterized import parameterized from monai.networks import eval_mode from monai.networks.nets import DecoderOnlyTransformer +TEST_CASES = [] +for dropout_rate in np.linspace(0, 1, 2): + for attention_layer_dim in [360, 480, 600, 768]: + for num_heads in [4, 6, 8, 12]: + TEST_CASES.append( + [{ "num_tokens": 10, + "max_seq_len": 16, + "attn_layers_dim": attention_layer_dim, + "attn_layers_depth": 2, + "attn_layers_heads": num_heads, + "embedding_dropout_rate": dropout_rate, + + }] + ) class TestDecoderOnlyTransformer(unittest.TestCase): - def test_unconditioned_models(self): + @parameterized.expand(TEST_CASES) + + def test_unconditioned_models(self, input_param): net = DecoderOnlyTransformer( - num_tokens=10, max_seq_len=16, attn_layers_dim=8, attn_layers_depth=2, attn_layers_heads=2 + **input_param ) with eval_mode(net): net.forward(torch.randint(0, 10, (1, 16))) + @parameterized.expand(TEST_CASES) - def test_conditioned_models(self): + def test_conditioned_models(self, input_param): net = DecoderOnlyTransformer( - num_tokens=10, - max_seq_len=16, - attn_layers_dim=8, - attn_layers_depth=2, - attn_layers_heads=2, - with_cross_attention=True, - embedding_dropout_rate=0, + **input_param, with_cross_attention=True ) with eval_mode(net): - net.forward(torch.randint(0, 10, (1, 16)), context=torch.randn(1, 4, 8)) + net.forward(torch.randint(0, 10, (1, 16)), context=torch.randn(1, 3, input_param["attn_layers_dim"])) + + + def test_attention_dim_not_multiple_of_heads(self): + with self.assertRaises(ValueError): + DecoderOnlyTransformer( + num_tokens=10, max_seq_len=16, attn_layers_dim=8, attn_layers_depth=2, attn_layers_heads=3 + ) + + def test_dropout_rate_negative(self): + with self.assertRaises(ValueError): + DecoderOnlyTransformer( + num_tokens=10, max_seq_len=16, attn_layers_dim=8, attn_layers_depth=2, attn_layers_heads=2, embedding_dropout_rate=-1 + ) if __name__ == "__main__": From 107e00ac1f25e1974aa4d9079beb5be61713d5b3 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Fri, 8 Dec 2023 15:45:33 +0000 Subject: [PATCH 6/6] Formatting Signed-off-by: Mark Graham --- tests/test_transformer.py | 38 ++++++++++++++++++++------------------ 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/tests/test_transformer.py b/tests/test_transformer.py index 29a505c9c5..ea6ebdf50f 100644 --- a/tests/test_transformer.py +++ b/tests/test_transformer.py @@ -25,35 +25,32 @@ for attention_layer_dim in [360, 480, 600, 768]: for num_heads in [4, 6, 8, 12]: TEST_CASES.append( - [{ "num_tokens": 10, - 
"max_seq_len": 16, - "attn_layers_dim": attention_layer_dim, - "attn_layers_depth": 2, - "attn_layers_heads": num_heads, - "embedding_dropout_rate": dropout_rate, - - }] + [ + { + "num_tokens": 10, + "max_seq_len": 16, + "attn_layers_dim": attention_layer_dim, + "attn_layers_depth": 2, + "attn_layers_heads": num_heads, + "embedding_dropout_rate": dropout_rate, + } + ] ) + class TestDecoderOnlyTransformer(unittest.TestCase): @parameterized.expand(TEST_CASES) - def test_unconditioned_models(self, input_param): - net = DecoderOnlyTransformer( - **input_param - ) + net = DecoderOnlyTransformer(**input_param) with eval_mode(net): net.forward(torch.randint(0, 10, (1, 16))) - @parameterized.expand(TEST_CASES) + @parameterized.expand(TEST_CASES) def test_conditioned_models(self, input_param): - net = DecoderOnlyTransformer( - **input_param, with_cross_attention=True - ) + net = DecoderOnlyTransformer(**input_param, with_cross_attention=True) with eval_mode(net): net.forward(torch.randint(0, 10, (1, 16)), context=torch.randn(1, 3, input_param["attn_layers_dim"])) - def test_attention_dim_not_multiple_of_heads(self): with self.assertRaises(ValueError): DecoderOnlyTransformer( @@ -63,7 +60,12 @@ def test_attention_dim_not_multiple_of_heads(self): def test_dropout_rate_negative(self): with self.assertRaises(ValueError): DecoderOnlyTransformer( - num_tokens=10, max_seq_len=16, attn_layers_dim=8, attn_layers_depth=2, attn_layers_heads=2, embedding_dropout_rate=-1 + num_tokens=10, + max_seq_len=16, + attn_layers_dim=8, + attn_layers_depth=2, + attn_layers_heads=2, + embedding_dropout_rate=-1, )