The MegEngine implementation of MAE (Masked Autoencoder).
Make sure you are using a GPU device, because the MegEngine gather API produces different outputs on GPU and CPU.
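A quick way to confirm that MegEngine can see a GPU before running anything else is sketched below. It assumes `megengine.is_cuda_available()` is present in your MegEngine version; the helper name `gpu_available` is only illustrative and is not part of this repo.

```python
def gpu_available() -> bool:
    """Return True if MegEngine is installed and reports a usable CUDA device."""
    try:
        import megengine as mge  # imported lazily so the check also works before installation
    except ImportError:
        return False
    # Assumption: is_cuda_available() exists in the installed MegEngine version.
    return bool(mge.is_cuda_available())


if not gpu_available():
    print("No GPU visible to MegEngine; CPU results may differ from the GPU ones.")
```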
pip install -r requirements.txt
If you don't want to compare the output error between the MegEngine implementation and the PyTorch one, you can ignore requirements.txt and install MegEngine from the command line:
python3 -m pip install --upgrade pip
python3 -m pip install megengine -f https://megengine.org.cn/whl/mge.html
Note: The PyTorch implementation is based on timm==0.3.2, for which a fix is needed to work with PyTorch 1.8.1+.
Convert the trained weights from PyTorch to MegEngine. The converted weights will be saved in ./pretained/. You need to specify the model architecture to convert and the path to a checkpoint provided by the official repo.
Pre-trained checkpoints:
ViT-Base | ViT-Large | ViT-Huge |
---|---|---|
download | download | download |
Visualization checkpoints:
ViT-Base | ViT-Large | ViT-Large-GanLoss | ViT-Huge |
---|---|---|---|
download | download | download | download |
python convert_weights.py -m mae_vit_base_patch16 -c /local/path/to/ckpt
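If you need to convert several checkpoints, the command above can be assembled in a small script. This is only a sketch: the `-m` and `-c` flags come from the command shown above, while the helper name and the checkpoint paths are placeholders you would replace with your own.

```python
import shlex


def build_convert_command(model: str, ckpt: str) -> list:
    """Build the argument list for one convert_weights.py run."""
    return ["python", "convert_weights.py", "-m", model, "-c", ckpt]


# Hypothetical local paths; point these at the checkpoints you downloaded.
jobs = {
    "mae_vit_base_patch16": "/local/path/to/base_ckpt",
    "mae_vit_large_patch16": "/local/path/to/large_ckpt",
}

for model, ckpt in jobs.items():
    cmd = build_convert_command(model, ckpt)
    # Print the command; pass `cmd` to subprocess.run(cmd, check=True) to execute it.
    print(shlex.join(cmd))
```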
Use python compare.py.
By default, the compare script converts the torch state_dict to the format that MegEngine needs.
If you want to compare the error using checkpoints, you need to load them manually.
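For a manual comparison, one simple metric is the maximum absolute difference between the two outputs. The sketch below is library-free and works on flat sequences of floats. It is an assumption about what a useful error measure looks like, not necessarily what compare.py computes, and `max_abs_error` is a hypothetical helper name.

```python
def max_abs_error(a, b) -> float:
    """Return the largest element-wise absolute difference between two flat sequences."""
    a, b = list(a), list(b)
    if len(a) != len(b):
        raise ValueError(f"length mismatch: {len(a)} vs {len(b)}")
    return max((abs(x - y) for x, y in zip(a, b)), default=0.0)


# Toy values standing in for flattened model outputs.
mge_out = [0.10, 0.20, 0.30]
torch_out = [0.10, 0.25, 0.30]
print(max_abs_error(mge_out, torch_out))
```

With real models you would flatten both outputs first, for example by converting each tensor to a NumPy array and calling `.ravel()` on it.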
Just read and run visualize.py.
Import from megengine.hub:
Way 1:
from functools import partial
import megengine.module as M
from megengine import hub
modelhub = hub.import_module(
    repo_info='asthestarsfalll/MAE-MegEngine:main', git_host='github.com')
# build an MAE model with your own configuration
model = modelhub.MAE(
    patch_size=16, embed_dim=768, depth=12, num_heads=12,
    decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
    mlp_ratio=4, norm_layer=partial(M.LayerNorm, eps=1e-6))
# load pretrained model
pretrained_model = modelhub.mae_vit_base_patch16(pretrained=True)
Way 2:
from megengine import hub
# load pretrained model
model_name = 'mae_vit_base_patch16'
pretrained_model = hub.load(
    repo_info='asthestarsfalll/MAE-MegEngine:main', entry=model_name,
    git_host='github.com', pretrained=True)
Currently, pretrained weights are only available for mae_vit_base_patch16.
But you can still load the model without pretrained weights like this:
model = modelhub.mae_vit_large_patch16()
# or
model_name = 'mae_vit_large_patch16'
model = hub.load(
    repo_info='asthestarsfalll/MAE-MegEngine:main', entry=model_name,
    git_host='github.com')
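To avoid requesting pretrained weights for an entry that does not provide them yet, the keyword arguments for `hub.load` can be built conditionally. This is a minimal sketch: `hub_load_kwargs` and `PRETRAINED_AVAILABLE` are hypothetical names, and the set only reflects the support stated above.

```python
REPO_INFO = "asthestarsfalll/MAE-MegEngine:main"

# Entries with pretrained weights, per the note above (assumed to be the only one for now).
PRETRAINED_AVAILABLE = {"mae_vit_base_patch16"}


def hub_load_kwargs(model_name: str) -> dict:
    """Build keyword arguments for megengine.hub.load for the given entry."""
    kwargs = {"repo_info": REPO_INFO, "entry": model_name, "git_host": "github.com"}
    if model_name in PRETRAINED_AVAILABLE:
        # Only ask for pretrained weights when the entry is known to provide them.
        kwargs["pretrained"] = True
    return kwargs


print(hub_load_kwargs("mae_vit_large_patch16"))
```

You could then call `hub.load(**hub_load_kwargs(model_name))` with `hub` imported from `megengine` as in the snippets above.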
- Add interfaces for visualization.
- Maybe some downstream tasks.
- Some introduction to MAE.