Swin Transformer: Hierarchical Vision Transformer using Shifted Windows, optimised for Graphcore's IPU. This implementation is based on the models provided by SWIN.
Framework | Domain | Model | Datasets | Tasks | Training | Inference | Reference |
---|---|---|---|---|---|---|---|
PyTorch | Vision | SWIN | ImageNet LSVRC 2012 | Image recognition, Image classification | ✅ | ❌ | Swin Transformer: Hierarchical Vision Transformer using Shifted Windows |
- Install and enable the Poplar SDK (see Poplar SDK setup)
- Install the system and Python requirements (see Environment setup)
- Download the ImageNet LSVRC 2012 dataset (see Dataset setup)
To check if your Poplar SDK has already been enabled, run:
echo $POPLAR_SDK_ENABLED
If no path is printed, follow these steps:
- Navigate to your Poplar SDK root directory
- Enable the Poplar SDK with:
cd poplar-<OS version>-<SDK version>-<hash>
. enable.sh
- Additionally, from the Poplar SDK root directory, enable PopART with:
cd popart-<OS version>-<SDK version>-<hash>
. enable.sh
More detailed instructions on setting up your Poplar environment are available in the Poplar quick start guide.
To prepare your environment, follow these steps:
- Create and activate a Python3 virtual environment:
python3 -m venv <venv name>
source <venv path>/bin/activate
- Navigate to the Poplar SDK root directory
- Install the PopTorch (PyTorch) wheel:
cd <poplar sdk root dir>
pip3 install poptorch...x86_64.whl
- Navigate to this example's root directory
- Install the Python requirements:
pip3 install -r requirements.txt
- Build the custom ops:
make all
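To confirm the steps above took effect, you can check from Python that a virtual environment is active and that the key packages are importable. A minimal sketch: the helper names are hypothetical, and the package list (`poptorch`, `torchvision`) is only an example; it does not verify that the custom ops were built.

```python
import importlib.util
import sys


def in_virtualenv() -> bool:
    # Inside a venv created with `python3 -m venv`, sys.prefix points at the
    # venv while sys.base_prefix still points at the base interpreter.
    return sys.prefix != sys.base_prefix


def missing_modules(names):
    # find_spec returns None when a top-level module cannot be found.
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    print("virtualenv active:", in_virtualenv())
    missing = missing_modules(["poptorch", "torchvision"])
    if missing:
        print("missing packages:", ", ".join(missing))
    else:
        print("poptorch and torchvision are importable")
```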
More detailed instructions on setting up your PyTorch environment are available in the PyTorch quick start guide.
Download the ImageNet LSVRC 2012 dataset from the source or via Kaggle (disk space required: 144GB). The extracted dataset should have the following structure:
.
├── bounding_boxes
├── imagenet_2012_bounding_boxes.csv
├── train
└── validation
3 directories, 1 file
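Before starting a long training run, it can help to check that the dataset root matches this layout. A minimal sketch: `missing_entries` is a hypothetical helper that only checks the top-level entries shown in the tree, not the image files inside `train` or `validation`.

```python
from pathlib import Path

# Top-level entries taken from the directory tree above.
EXPECTED_DIRS = ("bounding_boxes", "train", "validation")
EXPECTED_FILES = ("imagenet_2012_bounding_boxes.csv",)


def missing_entries(root):
    """Return the expected dataset entries that are missing under `root`."""
    root = Path(root)
    missing = [d for d in EXPECTED_DIRS if not (root / d).is_dir()]
    missing += [f for f in EXPECTED_FILES if not (root / f).is_file()]
    return missing
```

For example, `missing_entries("/path/to/imagenet")` returns an empty list when the layout is complete.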
To run a tested and optimised configuration and to reproduce the performance shown on our performance results page, use the examples_utils module (installed automatically as part of the environment setup) to run one or more benchmarks. The benchmarks are provided in the benchmarks.yml file in this example's root directory.
For example:
python3 -m examples_utils benchmark --spec <path to benchmarks.yml file>
Or, to run a specific benchmark from the provided benchmarks.yml file:
python3 -m examples_utils benchmark --spec <path to benchmarks.yml file> --benchmark <name of benchmark>
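To run several named benchmarks in sequence from a script, the command above can be wrapped with `subprocess`. A minimal sketch: `benchmark_command` and `run_benchmarks` are hypothetical helpers, they use only the `--spec` and `--benchmark` options shown above, they substitute `sys.executable` for `python3` so the active virtual environment is used, and the benchmark names must be taken from benchmarks.yml.

```python
import subprocess
import sys


def benchmark_command(spec, benchmark=None):
    """Build the examples_utils benchmark command shown above.

    Without `benchmark`, every benchmark in the spec file is run.
    """
    cmd = [sys.executable, "-m", "examples_utils", "benchmark", "--spec", str(spec)]
    if benchmark is not None:
        cmd += ["--benchmark", benchmark]
    return cmd


def run_benchmarks(spec, names):
    # Run each named benchmark in turn; check=True stops at the first failure.
    for name in names:
        subprocess.run(benchmark_command(spec, name), check=True)
```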
For more information on using the examples-utils benchmarking module, see the examples-utils README.
Training is supported in fp32.32 and fp16.16 precision, where the notation is input.weights (for example, fp16.16 means FP16 inputs and FP16 weights). Set the PRECISION parameter in the config file to select the precision.
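The input.weights notation can be illustrated with a small parser. This is a sketch under assumptions: how PRECISION is actually stored in the config files is not shown here, so `parse_precision` is a hypothetical helper that accepts strings such as "16.16" or "fp32.32", allows only the two documented combinations, and returns dtype names as strings to stay free of a PyTorch dependency.

```python
SUPPORTED = {"16.16", "32.32"}
_DTYPES = {"16": "float16", "32": "float32"}


def parse_precision(value):
    """Split an 'input.weights' precision string into (input, weights) dtype names."""
    text = value.strip().lower()
    if text.startswith("fp"):
        text = text[2:]  # "fp16.16" -> "16.16"
    if text not in SUPPORTED:
        raise ValueError(f"unsupported precision {value!r}; expected one of {sorted(SUPPORTED)}")
    inputs, weights = text.split(".")
    return _DTYPES[inputs], _DTYPES[weights]
```

For example, `parse_precision("fp16.16")` returns `("float16", "float16")`.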
The resulting top-1 accuracy on the ImageNet validation set is:
model | input size | precision | machine | top-1 acc |
---|---|---|---|---|
tiny | 224 | 16.16 | pod16 | 80.9% |
tiny | 224 | 32.32 | pod16 | 81.21% |
tiny | 224 | mixed | v100 | 81.3% (SOTA) |
base | 224 | 32.32 | pod16 | 83.5% |
base | 224 | mixed | v100 | 83.5% (SOTA) |
base | 384 | 32.32 | pod16 | 84.47% |
base | 384 | mixed | v100 | 84.5% (SOTA) |
large | 224 | 32.32 | pod16 | 86.27% |
large | 224 | mixed | v100 | 86.3% (SOTA) |
NOTE: Accuracy values marked SOTA are quoted from the Swin Transformer paper.
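To compare the IPU fp32.32 results with the paper, the gap in percentage points is simply the difference of the two columns. A small sketch using only the values from the table above (the fp16.16 tiny result, 80.9%, is 0.4 points below the paper and is left out to keep the comparison at one precision):

```python
# (model, input size) -> (IPU fp32.32 accuracy, paper accuracy), from the table above.
RESULTS = {
    ("tiny", 224): (81.21, 81.3),
    ("base", 224): (83.5, 83.5),
    ("base", 384): (84.47, 84.5),
    ("large", 224): (86.27, 86.3),
}


def gap_to_paper(results):
    # Difference in percentage points; negative means below the paper result.
    return {key: round(ipu - paper, 2) for key, (ipu, paper) in results.items()}


if __name__ == "__main__":
    for (model, size), gap in gap_to_paper(RESULTS).items():
        print(f"{model} {size}: {gap:+.2f} pp")
```

Every fp32.32 configuration lands within 0.1 percentage points of the published number.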
In the results above, base 384 and large 224 are fine-tuned models: they start from weights pretrained on ImageNet-21k, so you must provide the path to these weights with the "pretrained-model" parameter to reproduce the reported accuracy.
You can download the ImageNet-21k pretrained weights for the base 384 model with:
wget https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth
and those for the large 224 model with:
wget https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth
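If you prefer to fetch the weights from Python, the two links above share one naming pattern. A minimal sketch: `pretrained_22k_url` and `download_pretrained` are hypothetical helpers, only the "base" and "large" files listed above are accepted, and the returned path is what you would then pass as the "pretrained-model" parameter.

```python
from pathlib import Path
from urllib.request import urlretrieve

BASE_URL = "https://github.com/SwinTransformer/storage/releases/download/v1.0.0"
KNOWN_SIZES = {"base", "large"}  # the two checkpoints linked above


def pretrained_22k_url(size):
    # Follows the file naming of the two download links above.
    if size not in KNOWN_SIZES:
        raise ValueError(f"no ImageNet-21k checkpoint listed for {size!r}")
    return f"{BASE_URL}/swin_{size}_patch4_window7_224_22k.pth"


def download_pretrained(size, dest_dir="."):
    """Download the checkpoint once and return its local path."""
    url = pretrained_22k_url(size)
    dest = Path(dest_dir) / url.rsplit("/", 1)[1]
    if not dest.exists():
        urlretrieve(url, dest)
    return dest
```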
Once the training finishes, you can validate accuracy:
python3 validate.py --cfg YOUR_CONFIG --checkpoint /path/to/checkpoint.pth
This application is licensed under Apache License 2.0. Please see the LICENSE file in this directory for full details of the license conditions.
The following files are created by Graphcore and are licensed under the Apache License, Version 2.0:
- configs/*
- dataset/build_ipu.py
- models/build.py
- options.py
- README.md
- swin_test.py
- train_swin.py
- utils.py
- validate.py
The following files are based on code from the SWIN repository, which is licensed under the MIT License:
- config.py
- dataset/cached_image_folder.py
- dataset/samplers.py
- dataset/zipreader.py
- lr_scheduler.py
- model/swin_transformer.py
- optimizer.py
- train_swin.sh
See the headers in the files for details.
External packages:
- transformers is licensed under the Apache License, Version 2.0
- pytest is licensed under the MIT License
- torchvision is licensed under the BSD 3-Clause License