A reproduction of the ViT Paper
The model consists of a ResNet34 backbone and a ViT, extended with Re_Zero and GatedAttention optimizations, which modify the model's residual connections and self-attention, respectively. The final output is a gender prediction for the input face image.
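As a rough sketch of the Re_Zero idea (the module and parameter names below are illustrative, not this repo's exact code): each residual branch is scaled by a learnable scalar initialized to zero, so the block starts as an identity mapping and gradually learns how much of the branch to mix in. GatedAttention applies a similar learned gate inside the self-attention block.

```python
import torch
import torch.nn as nn

class ReZeroResidual(nn.Module):
    """Wraps a residual branch with a learnable scale, initialized to 0 (Re_Zero)."""
    def __init__(self, branch: nn.Module):
        super().__init__()
        self.branch = branch
        self.alpha = nn.Parameter(torch.zeros(1))  # starts at 0, so the block begins as identity

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Standard residual is x + F(x); Re_Zero uses x + alpha * F(x)
        return x + self.alpha * self.branch(x)
```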
The backbone follows ResNet. The original paper uses ResNet50; to reduce the parameter count, ResNet34 is used here, and the last two residual stages are merged. In PatchEmbedding, a 1*1 convolution maps the feature maps to EmbeddingSize channels.
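A minimal sketch of such a patch embedding, assuming hypothetical names and channel counts (the repo's actual in/out sizes may differ):

```python
import torch
import torch.nn as nn

class HybridPatchEmbedding(nn.Module):
    """Maps ResNet feature maps to EmbeddingSize channels with a 1x1 conv,
    then flattens the spatial grid into a sequence of patch tokens."""
    def __init__(self, in_channels: int = 256, embed_size: int = 768):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, embed_size, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, C, H, W) feature map from the ResNet34 backbone
        x = self.proj(x)                      # (B, embed_size, H, W)
        return x.flatten(2).transpose(1, 2)   # (B, H*W, embed_size) token sequence
```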
The ViT structure follows the original paper; at the end, only the CLS token's features are extracted as the classification vector.
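In code terms, the classification step looks roughly like this (shapes and names are illustrative, not this repo's exact API):

```python
import torch
import torch.nn as nn

batch, num_patches, embed_size = 8, 196, 768
tokens = torch.randn(batch, 1 + num_patches, embed_size)  # encoder output, CLS token prepended at index 0
classifier_head = nn.Linear(embed_size, 2)                 # hypothetical head: 2 gender classes

cls_feature = tokens[:, 0]             # (batch, embed_size): keep only the CLS token
logits = classifier_head(cls_feature)  # (batch, 2) gender logits
```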
My code works with the following environment:
- python=3.7
- pytorch=1.12.1+cu116
- tqdm
- numpy
- pandas
- sklearn
- einops
- PIL
Download the data from Bitmojidata and put all downloaded files under ./data/Bitmojidata.
- You can run `python train.py` to train a Hybrid_ViT model from the command line, and `python train.py -h` to get help.
- To train a ResNet34 model, you can run `python train.py --use_only_res34 True`.
- To train a pure ViT model, you can run `python train.py --use_only_res34 False --use_hybrid False`.
Here are some important parameters:
- `--batch_size`
- `--re_zero`: use Re_Zero in ResNet if True
- `--learning_rate`
- `--epochs`
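For example, a run combining these flags might look like this (the values are illustrative, not recommended defaults):

`python train.py --batch_size 64 --re_zero True --learning_rate 0.0001 --epochs 50`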
Your trained model, named 'model_epoch.pth', and the loss/accuracy figures are saved under ./checkpoint; the model is saved every 5 epochs. To predict on the test data with your saved model, replace the code in test(args, model, data, device) in train.py. The result file will be saved as 'out_epoch.csv'.
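A minimal sketch of what a replaced test() body might do, assuming the checkpoint path, loader format, and CSV column name below (none of which are this repo's confirmed API):

```python
import torch
import pandas as pd

# model, data, and device are the parameters of test(args, model, data, device).
model.load_state_dict(torch.load('./checkpoint/model_epoch.pth', map_location=device))
model.eval()

preds = []
with torch.no_grad():
    for images in data:                          # assumed: the loader yields image batches
        logits = model(images.to(device))
        preds.extend(logits.argmax(dim=1).cpu().tolist())

pd.DataFrame({'label': preds}).to_csv('out_epoch.csv', index=False)
```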
This project is licensed under the MIT License - see the LICENSE file for details.