SABlock parameters when using more heads #7661

Closed
NabJa opened this issue Apr 18, 2024 · 0 comments · Fixed by #7664

Comments

@NabJa
Contributor

NabJa commented Apr 18, 2024

Describe the bug
The number of parameters in the SABlock should increase when the number of heads (num_heads) increases. However, it stays the same, which limits comparability with well-known model scales such as ViT-S and ViT-B.

To Reproduce
Steps to reproduce the behavior:

```
import torch.nn as nn

from monai.networks.nets import ViT


def count_trainable_parameters(model: nn.Module) -> int:
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# Create ViT models with different numbers of heads
# (in_channels=1, img_size=224, patch_size=16; all other arguments use their defaults)
vit_b = ViT(1, 224, 16, num_heads=12)
vit_s = ViT(1, 224, 16, num_heads=6)

print("ViT with 12 heads parameters:", count_trainable_parameters(vit_b))
print("ViT with 6 heads parameters:", count_trainable_parameters(vit_s))
```

Output:

```
ViT with 12 heads parameters: 90282240
ViT with 6 heads parameters: 90282240
```

Expected behavior
The number of trainable parameters should increase as the number of heads increases.
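A back-of-the-envelope count shows why the numbers above match. The sketch assumes a self-attention block made of a bias-free `qkv` projection from `hidden_size` to `3 * inner_dim` and an output projection from `inner_dim` back to `hidden_size` with bias, where `inner_dim = num_heads * dim_head`; `sablock_param_count` is a hypothetical helper, not a MONAI function. When `dim_head` is not given, it falls back to the equal split `hidden_size // num_heads`, so `inner_dim` is always `hidden_size` and the count cannot depend on `num_heads`. Fixing the per-head width instead makes the count grow with the number of heads.

```python
def sablock_param_count(hidden_size, num_heads, dim_head=None, qkv_bias=False):
    """Count parameters of a self-attention block (hypothetical helper).

    Assumes qkv: Linear(hidden_size -> 3 * inner_dim), optional bias,
    and out_proj: Linear(inner_dim -> hidden_size) with bias.
    """
    if dim_head is None:
        # Equal split: hidden_size is divided among the heads.
        dim_head = hidden_size // num_heads
    inner_dim = num_heads * dim_head
    qkv = hidden_size * 3 * inner_dim + (3 * inner_dim if qkv_bias else 0)
    out_proj = inner_dim * hidden_size + hidden_size
    return qkv + out_proj


# Equal split: identical counts for 12 and 6 heads.
print(sablock_param_count(768, 12), sablock_param_count(768, 6))  # 2360064 2360064
# Fixed per-head width: fewer heads means fewer parameters.
print(sablock_param_count(768, 12, dim_head=64), sablock_param_count(768, 6, dim_head=64))  # 2360064 1180416
```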

Environment

================================
Printing MONAI config...
================================
MONAI version: 0.8.1rc4+1384.g139182ea
Numpy version: 1.26.4
Pytorch version: 2.2.2+cpu
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 139182ea52725aa3c9214dc18082b9837e32f9a2
MONAI __file__: C:\Users\<username>\MONAI\monai\__init__.py

Optional dependencies:
Pytorch Ignite version: 0.4.11
ITK version: 5.3.0
Nibabel version: 5.2.1
scikit-image version: 0.23.1
scipy version: 1.13.0
Pillow version: 10.3.0
Tensorboard version: 2.16.2
gdown version: 4.7.3
TorchVision version: 0.17.2+cpu
tqdm version: 4.66.2
lmdb version: 1.4.1
psutil version: 5.9.8
pandas version: 2.2.2
einops version: 0.7.0
transformers version: 4.39.3
mlflow version: 2.12.1
pynrrd version: 1.0.0
clearml version: 1.15.1

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies


================================
Printing system config...
================================
System: Windows
Win32 version: ('10', '10.0.22621', 'SP0', 'Multiprocessor Free')
Win32 edition: Professional
Platform: Windows-10-10.0.22621-SP0
Processor: Intel64 Family 6 Model 142 Stepping 12, GenuineIntel
Machine: AMD64
Python version: 3.11.8
Process name: python.exe
Command: ['python', '-c', 'import monai; monai.config.print_debug_info()']
Open files: [popenfile(path='C:\\Windows\\System32\\de-DE\\KernelBase.dll.mui', fd=-1), popenfile(path='C:\\Windows\\System32\\de-DE\\kernel32.dll.mui', fd=-1), popenfile(path='C:\\Windows\\System32\\de-DE\\tzres.dll.mui', fd=-1)]
Num physical CPUs: 4
Num logical CPUs: 8
Num usable CPUs: 8
CPU usage (%): [3.9, 0.2, 3.7, 0.9, 3.9, 3.9, 2.8, 32.2]
CPU freq. (MHz): 1803
Load avg. in last 1, 5, 15 mins (%): [0.0, 0.0, 0.0]
Disk usage (%): 83.1
Avg. sensor temp. (Celsius): UNKNOWN for given OS
Total physical memory (GB): 15.8
Available memory (GB): 5.5
Used memory (GB): 10.2

================================
Printing GPU config...
================================
Num GPUs: 0
Has CUDA: False
cuDNN enabled: False
NVIDIA_TF32_OVERRIDE: None
TORCH_ALLOW_TF32_CUBLAS_OVERRIDE: None
KumoLiu added a commit that referenced this issue May 8, 2024
Fixes #7661.

### Description

This change adds a parameter, _dim_head_, that sets the output dimension of each head in the self-attention block (SABlock).
Currently, the combined output dimension of all heads is fixed to _hidden_size_, so increasing the number of heads only splits that dimension more finely among them.
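To illustrate what the parameter controls, here is a minimal NumPy sketch of multi-head self-attention with a configurable per-head width. It is not MONAI's implementation: the function name `self_attention`, the random weights, and the omission of biases and dropout are assumptions made for brevity. The per-head width only determines the inner width `num_heads * dim_head`, which sizes the qkv projection and the input of the output projection; the block still maps `hidden_size` back to `hidden_size`.

```python
import numpy as np


def self_attention(x, w_qkv, w_out, num_heads, dim_head):
    """Multi-head self-attention forward pass (sketch: no biases, no dropout).

    x:     (batch, seq_len, hidden_size)
    w_qkv: (hidden_size, 3 * num_heads * dim_head)
    w_out: (num_heads * dim_head, hidden_size)
    """
    b, n, _ = x.shape
    qkv = x @ w_qkv  # (b, n, 3 * num_heads * dim_head)
    # Split into q/k/v and heads: (b, n, 3*heads*d) -> (3, b, heads, n, d)
    qkv = qkv.reshape(b, n, 3, num_heads, dim_head).transpose(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2]  # each (b, heads, n, d)
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(dim_head)  # (b, heads, n, n)
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability for softmax
    attn = np.exp(scores)
    attn /= attn.sum(axis=-1, keepdims=True)
    out = attn @ v  # (b, heads, n, d)
    # Merge heads back: (b, heads, n, d) -> (b, n, heads * d)
    out = out.transpose(0, 2, 1, 3).reshape(b, n, num_heads * dim_head)
    return out @ w_out  # (b, n, hidden_size)


rng = np.random.default_rng(0)
hidden_size, num_heads, dim_head = 128, 8, 32
x = rng.standard_normal((1, 256, hidden_size))
w_qkv = rng.standard_normal((hidden_size, 3 * num_heads * dim_head)) * 0.02
w_out = rng.standard_normal((num_heads * dim_head, hidden_size)) * 0.02
y = self_attention(x, w_qkv, w_out, num_heads, dim_head)
print(y.shape)  # (1, 256, 128)
```

Changing `dim_head` changes the shapes of `w_qkv` and `w_out` (and thus the parameter count) without changing the shape of the block's input or output.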

### Example
The original implementation derives **_equally_distributed_head_dim_** automatically, such that
3 (q, k, v) * num_heads * equally_distributed_head_dim = 3 * hidden_size
(in this example: 3 * 8 * 16 = 384).
```
import torch

from monai.networks.blocks import SABlock

block = SABlock(hidden_size=128, num_heads=8)
x = torch.zeros(1, 256, 128)
x = block.qkv(x)
print(x.shape)  # torch.Size([1, 256, 384])
x = block.input_rearrange(x)
print(x.shape)  # torch.Size([3, 1, 8, 256, 16]), i.e. (qkv, batch, num_heads, sequence_length, equally_distributed_head_dim)
```

The proposed implementation fixes this with the new argument **_dim_head_**:
```
import torch

from monai.networks.blocks import SABlock

block_new = SABlock(hidden_size=128, num_heads=8, dim_head=32)
x = torch.zeros(1, 256, 128)
x = block_new.qkv(x)
print(x.shape)  # torch.Size([1, 256, 768]), since 3 * num_heads * dim_head = 3 * 8 * 32 = 768
x = block_new.input_rearrange(x)
print(x.shape)  # torch.Size([3, 1, 8, 256, 32]), i.e. (qkv, batch, num_heads, sequence_length, dim_head)
```
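The shapes in both examples follow from simple arithmetic. This sketch uses a hypothetical helper, `qkv_shapes`, and assumes that omitting `dim_head` keeps the original equal split `hidden_size // num_heads` (consistent with the change being marked non-breaking). It returns the width of the `qkv` output and the rearranged layout `(qkv, batch, num_heads, sequence_length, dim_head)`.

```python
def qkv_shapes(batch, seq_len, hidden_size, num_heads, dim_head=None):
    """Return (qkv output shape, rearranged shape) for a self-attention block.

    Hypothetical helper: dim_head defaults to an equal split of hidden_size.
    """
    if dim_head is None:
        dim_head = hidden_size // num_heads
    inner_dim = num_heads * dim_head
    qkv_out = (batch, seq_len, 3 * inner_dim)  # output of the qkv projection
    rearranged = (3, batch, num_heads, seq_len, dim_head)  # after splitting q/k/v and heads
    return qkv_out, rearranged


print(qkv_shapes(1, 256, 128, 8))               # ((1, 256, 384), (3, 1, 8, 256, 16))
print(qkv_shapes(1, 256, 128, 8, dim_head=32))  # ((1, 256, 768), (3, 1, 8, 256, 32))
```

In both cases the product of the last three rearranged dimensions (3 * num_heads * dim_head) equals the qkv output width, which is what makes the rearrangement valid.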


### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: NabJa <nabil.jabareen@gmail.com>
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>