From 139182ea52725aa3c9214dc18082b9837e32f9a2 Mon Sep 17 00:00:00 2001
From: NabJa
Date: Thu, 18 Apr 2024 14:37:32 +0200
Subject: [PATCH 1/5] Add dimensionality of heads argument to SABlock

---
 monai/networks/blocks/selfattention.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py
index 7c81c1704f..74bf8b4d4e 100644
--- a/monai/networks/blocks/selfattention.py
+++ b/monai/networks/blocks/selfattention.py
@@ -32,6 +32,7 @@ def __init__(
         dropout_rate: float = 0.0,
         qkv_bias: bool = False,
         save_attn: bool = False,
+        dim_head: int = 64
     ) -> None:
         """
         Args:
@@ -40,6 +41,7 @@ def __init__(
             dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
             qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False.
             save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
+            dim_head (int, optional): dimension of each head. Defaults to 64.

         """

@@ -52,8 +54,11 @@ def __init__(
             raise ValueError("hidden size should be divisible by num_heads.")

         self.num_heads = num_heads
-        self.out_proj = nn.Linear(hidden_size, hidden_size)
-        self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias)
+        self.dim_head = dim_head
+        self.inner_dim = dim_head * num_heads
+
+        self.out_proj = nn.Linear(self.inner_dim, hidden_size)
+        self.qkv = nn.Linear(hidden_size, self.inner_dim * 3, bias=qkv_bias)
         self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads)
         self.out_rearrange = Rearrange("b h l d -> b l (h d)")
         self.drop_output = nn.Dropout(dropout_rate)
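The sketch below is illustrative only and not part of the patch series; it shows what this first change enables, using only names that appear in the diff (SABlock in monai/networks/blocks/selfattention.py with hidden_size, num_heads and dim_head keyword arguments). With an explicit dim_head, the qkv and output projections are sized by inner_dim = dim_head * num_heads rather than by hidden_size, while the block still maps (batch, sequence, hidden_size) inputs back to the same shape.

# Usage sketch (illustrative, not part of the patch):
# dim_head decouples the per-head width from hidden_size // num_heads.
import torch
from monai.networks.blocks.selfattention import SABlock

block = SABlock(hidden_size=128, num_heads=8, dim_head=32)  # inner_dim = 8 * 32 = 256

# qkv projects hidden_size -> 3 * inner_dim, out_proj projects inner_dim -> hidden_size
print(block.qkv.weight.shape)       # torch.Size([768, 128])
print(block.out_proj.weight.shape)  # torch.Size([128, 256])

# The output keeps the (batch, sequence, hidden_size) shape of the input.
x = torch.randn(2, 16, 128)
print(block(x).shape)               # torch.Size([2, 16, 128])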
""" @@ -54,8 +54,8 @@ def __init__( raise ValueError("hidden size should be divisible by num_heads.") self.num_heads = num_heads - self.dim_head = dim_head - self.inner_dim = dim_head * num_heads + self.dim_head = hidden_size // num_heads if dim_head is None else dim_head + self.inner_dim = self.dim_head * num_heads self.out_proj = nn.Linear(self.inner_dim, hidden_size) self.qkv = nn.Linear(hidden_size, self.inner_dim * 3, bias=qkv_bias) From 71a55b35e09fc08ebd22822dc67ef3961e16128f Mon Sep 17 00:00:00 2001 From: NabJa Date: Wed, 24 Apr 2024 17:47:24 +0200 Subject: [PATCH 3/5] Compute scale based on updated dim_head Signed-off-by: NabJa --- monai/networks/blocks/selfattention.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index f8a738b36d..7b410b1a7c 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -63,8 +63,7 @@ def __init__( self.out_rearrange = Rearrange("b h l d -> b l (h d)") self.drop_output = nn.Dropout(dropout_rate) self.drop_weights = nn.Dropout(dropout_rate) - self.head_dim = hidden_size // num_heads - self.scale = self.head_dim**-0.5 + self.scale = self.dim_head**-0.5 self.save_attn = save_attn self.att_mat = torch.Tensor() From 55d0cdccada36564816cd1deecde43214457514e Mon Sep 17 00:00:00 2001 From: NabJa Date: Fri, 26 Apr 2024 10:40:22 +0200 Subject: [PATCH 4/5] Add tests for SABlock number of parameters Signed-off-by: NabJa --- tests/test_selfattention.py | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py index b8be4fd1b6..73b2a91326 100644 --- a/tests/test_selfattention.py +++ b/tests/test_selfattention.py @@ -74,6 +74,40 @@ def test_access_attn_matrix(self): matrix_acess_blk(torch.randn(input_shape)) assert matrix_acess_blk.att_mat.shape == (input_shape[0], input_shape[0], input_shape[1], input_shape[1]) + def test_number_of_parameters(self): + + def count_params(model): + return sum([x.numel() for x in model.parameters() if x.requires_grad]) + + hidden_size = 128 + num_heads = 8 + default_dim_head = hidden_size // num_heads + + nparams_default = count_params(SABlock(hidden_size=hidden_size, num_heads=num_heads)) + nparams_like_default = count_params( + SABlock(hidden_size=hidden_size, num_heads=num_heads, dim_head=default_dim_head) + ) + nparams_custom_large = count_params( + SABlock(hidden_size=hidden_size, num_heads=num_heads, dim_head=default_dim_head * 2) + ) + nparams_custom_small = count_params( + SABlock(hidden_size=hidden_size, num_heads=num_heads, dim_head=default_dim_head // 2) + ) + + nparams_default_more_heads = count_params(SABlock(hidden_size=hidden_size, num_heads=num_heads * 2)) + + # Default is the same as hidden_size // num_heads + self.assertEqual(nparams_default, nparams_like_default) + + # Increasing dim_head should increase the number of parameters + self.assertGreater(nparams_custom_large, nparams_default) + + # Decreasing dim_head should decrease the number of parameters + self.assertGreater(nparams_default, nparams_custom_small) + + # Increasing the number of heads with the default behaviour should not change the number of params. 
From 55d0cdccada36564816cd1deecde43214457514e Mon Sep 17 00:00:00 2001
From: NabJa
Date: Fri, 26 Apr 2024 10:40:22 +0200
Subject: [PATCH 4/5] Add tests for SABlock number of parameters

Signed-off-by: NabJa
---
 tests/test_selfattention.py | 34 ++++++++++++++++++++++++++++++++++
 1 file changed, 34 insertions(+)

diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py
index b8be4fd1b6..73b2a91326 100644
--- a/tests/test_selfattention.py
+++ b/tests/test_selfattention.py
@@ -74,6 +74,40 @@ def test_access_attn_matrix(self):
         matrix_acess_blk(torch.randn(input_shape))
         assert matrix_acess_blk.att_mat.shape == (input_shape[0], input_shape[0], input_shape[1], input_shape[1])

+    def test_number_of_parameters(self):
+
+        def count_params(model):
+            return sum([x.numel() for x in model.parameters() if x.requires_grad])
+
+        hidden_size = 128
+        num_heads = 8
+        default_dim_head = hidden_size // num_heads
+
+        nparams_default = count_params(SABlock(hidden_size=hidden_size, num_heads=num_heads))
+        nparams_like_default = count_params(
+            SABlock(hidden_size=hidden_size, num_heads=num_heads, dim_head=default_dim_head)
+        )
+        nparams_custom_large = count_params(
+            SABlock(hidden_size=hidden_size, num_heads=num_heads, dim_head=default_dim_head * 2)
+        )
+        nparams_custom_small = count_params(
+            SABlock(hidden_size=hidden_size, num_heads=num_heads, dim_head=default_dim_head // 2)
+        )
+
+        nparams_default_more_heads = count_params(SABlock(hidden_size=hidden_size, num_heads=num_heads * 2))
+
+        # Default is the same as hidden_size // num_heads
+        self.assertEqual(nparams_default, nparams_like_default)
+
+        # Increasing dim_head should increase the number of parameters
+        self.assertGreater(nparams_custom_large, nparams_default)
+
+        # Decreasing dim_head should decrease the number of parameters
+        self.assertGreater(nparams_default, nparams_custom_small)
+
+        # Increasing the number of heads with the default behaviour should not change the number of params.
+        self.assertEqual(nparams_default, nparams_default_more_heads)
+
 if __name__ == "__main__":
     unittest.main()

From 36fa8107360227a77a94d97f8bcba971ef0992c0 Mon Sep 17 00:00:00 2001
From: NabJa
Date: Fri, 26 Apr 2024 12:34:36 +0200
Subject: [PATCH 5/5] Refactor test_selfattention for better readability

Signed-off-by: NabJa
---
 tests/test_selfattention.py | 30 +++++++++++++++---------------
 1 file changed, 15 insertions(+), 15 deletions(-)

diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py
index 73b2a91326..0ebed84159 100644
--- a/tests/test_selfattention.py
+++ b/tests/test_selfattention.py
@@ -76,36 +76,36 @@ def test_access_attn_matrix(self):

     def test_number_of_parameters(self):

-        def count_params(model):
-            return sum([x.numel() for x in model.parameters() if x.requires_grad])
+        def count_sablock_params(*args, **kwargs):
+            """Count the number of parameters in a SABlock."""
+            sablock = SABlock(*args, **kwargs)
+            return sum([x.numel() for x in sablock.parameters() if x.requires_grad])

         hidden_size = 128
         num_heads = 8
         default_dim_head = hidden_size // num_heads

-        nparams_default = count_params(SABlock(hidden_size=hidden_size, num_heads=num_heads))
-        nparams_like_default = count_params(
-            SABlock(hidden_size=hidden_size, num_heads=num_heads, dim_head=default_dim_head)
+        # Default dim_head is hidden_size // num_heads
+        nparams_default = count_sablock_params(hidden_size=hidden_size, num_heads=num_heads)
+        nparams_like_default = count_sablock_params(
+            hidden_size=hidden_size, num_heads=num_heads, dim_head=default_dim_head
         )
-        nparams_custom_large = count_params(
-            SABlock(hidden_size=hidden_size, num_heads=num_heads, dim_head=default_dim_head * 2)
-        )
-        nparams_custom_small = count_params(
-            SABlock(hidden_size=hidden_size, num_heads=num_heads, dim_head=default_dim_head // 2)
-        )
-
-        nparams_default_more_heads = count_params(SABlock(hidden_size=hidden_size, num_heads=num_heads * 2))
-
-        # Default is the same as hidden_size // num_heads
         self.assertEqual(nparams_default, nparams_like_default)

         # Increasing dim_head should increase the number of parameters
+        nparams_custom_large = count_sablock_params(
+            hidden_size=hidden_size, num_heads=num_heads, dim_head=default_dim_head * 2
+        )
         self.assertGreater(nparams_custom_large, nparams_default)

         # Decreasing dim_head should decrease the number of parameters
+        nparams_custom_small = count_sablock_params(
+            hidden_size=hidden_size, num_heads=num_heads, dim_head=default_dim_head // 2
+        )
         self.assertGreater(nparams_default, nparams_custom_small)

         # Increasing the number of heads with the default behaviour should not change the number of params.
+        nparams_default_more_heads = count_sablock_params(hidden_size=hidden_size, num_heads=num_heads * 2)
         self.assertEqual(nparams_default, nparams_default_more_heads)
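The parameter-count behaviour exercised by the tests in patches 4 and 5 can also be checked outside the test suite. A standalone sketch (illustrative only, not part of the series) follows: with the default dim_head, inner_dim always equals hidden_size, so doubling the head count leaves the weight shapes, and hence the parameter count, unchanged, while an explicit dim_head changes it.

# Standalone sketch of the behaviour covered by the tests (illustrative only):
from monai.networks.blocks.selfattention import SABlock


def count_params(module):
    # Count trainable parameters, mirroring the helper used in the tests.
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


base = count_params(SABlock(hidden_size=128, num_heads=8))
more_heads = count_params(SABlock(hidden_size=128, num_heads=16))
wider_heads = count_params(SABlock(hidden_size=128, num_heads=8, dim_head=32))

print(base == more_heads)   # True: default dim_head keeps inner_dim == hidden_size
print(wider_heads > base)   # True: a larger dim_head grows the qkv and output projections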