Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add dimensionality of heads argument to SABlock #7664

Merged
merged 13 commits into from
May 8, 2024
Merged
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions monai/networks/blocks/selfattention.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def __init__(
dropout_rate: float = 0.0,
qkv_bias: bool = False,
save_attn: bool = False,
dim_head: int | None = None,
) -> None:
"""
Args:
Expand All @@ -40,6 +41,7 @@ def __init__(
dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False.
save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
dim_head (int, optional): dimension of each head. Defaults to hidden_size // num_heads.

"""

Expand All @@ -52,8 +54,11 @@ def __init__(
raise ValueError("hidden size should be divisible by num_heads.")

self.num_heads = num_heads
self.out_proj = nn.Linear(hidden_size, hidden_size)
self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias)
self.dim_head = hidden_size // num_heads if dim_head is None else dim_head
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved
self.inner_dim = self.dim_head * num_heads

self.out_proj = nn.Linear(self.inner_dim, hidden_size)
self.qkv = nn.Linear(hidden_size, self.inner_dim * 3, bias=qkv_bias)
self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads)
self.out_rearrange = Rearrange("b h l d -> b l (h d)")
self.drop_output = nn.Dropout(dropout_rate)
Expand Down
Loading