Hybrid_ViT

A reproduction of the ViT Paper
Explore the docs »

Table of Contents
  1. About The Project
  2. Getting Started
  3. Training and Testing

About The Project

The model combines a ResNet34 backbone with a ViT encoder, augmented with Re_Zero and GatedAttention optimizations: both the residual connections and the self-attention mechanism are modified. Given an input face image, the model outputs a gender prediction.
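As a rough illustration of the Re_Zero idea (not this repository's actual code; the class and parameter names below are hypothetical), the residual branch is scaled by a learnable scalar initialised to zero, so every layer starts out as an identity mapping:

```python
import torch
import torch.nn as nn

class ReZeroResidual(nn.Module):
    """Wrap a sub-block so its residual branch is gated by a learnable
    scalar alpha initialised to zero (the Re_Zero trick)."""

    def __init__(self, block: nn.Module):
        super().__init__()
        self.block = block
        self.alpha = nn.Parameter(torch.zeros(1))  # branch starts "off"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # At initialisation this returns x unchanged; alpha is learned.
        return x + self.alpha * self.block(x)
```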

ResNet34

The backbone follows ResNet. The original paper uses ResNet50; ResNet34 is used here instead to reduce the parameter count, and the last two residual stages are merged. In PatchEmbedding, a 1×1 convolution projects the feature maps to EmbeddingSize channels.
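A minimal sketch of what such a patch embedding might look like; the channel counts below are placeholders, not the values this repository actually uses:

```python
import torch.nn as nn

class PatchEmbedding(nn.Module):
    """Project ResNet34 feature maps to the transformer embedding size with
    a 1x1 convolution, then flatten the spatial grid into a token sequence."""

    def __init__(self, in_channels: int = 256, embed_dim: int = 768):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=1)

    def forward(self, x):                    # x: (B, C, H, W) from the CNN
        x = self.proj(x)                     # (B, embed_dim, H, W)
        return x.flatten(2).transpose(1, 2)  # (B, H*W, embed_dim) tokens
```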

ViT

The structure follows the original paper; at the end, only the CLS token's features are extracted and used as the classification vector.
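In code, keeping only the CLS token looks roughly like the following sketch (the embedding size and the 2-class output are assumptions based on the gender-prediction task, not confirmed values from this repository):

```python
import torch
import torch.nn as nn

class ClsHead(nn.Module):
    """Take the CLS token (position 0) from the encoder output and map it
    to class logits, as in the original ViT paper."""

    def __init__(self, embed_dim: int = 768, num_classes: int = 2):
        super().__init__()
        self.norm = nn.LayerNorm(embed_dim)
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # tokens: (B, 1 + num_patches, embed_dim); index 0 is the CLS token
        return self.fc(self.norm(tokens[:, 0]))
```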

(back to top)

Getting Started

Requirements

My code works with the following environment.

  • python=3.7
  • pytorch=1.12.1+cu116
  • tqdm
  • numpy
  • pandas
  • sklearn
  • einops
  • PIL
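If you are setting up the environment from scratch, something like pip install tqdm numpy pandas scikit-learn einops pillow plus a matching PyTorch 1.12.1/cu116 build from pytorch.org should cover these; the exact install commands are not prescribed by the repository. Note that sklearn and PIL are the import names of the scikit-learn and Pillow packages.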

Dataset

Download the data from Bitmojidata, and put all downloaded files under ./data/Bitmojidata.

(back to top)

Training and Testing

Train

  • You can run python train.py on the command line to train a Hybrid_ViT model, and python train.py -h to get help.

  • To train a ResNet34-only model, run python train.py --use_only_res34 True.

  • To train a pure ViT model, run python train.py --use_only_res34 False --use_hybrid False.

Here are some important parameters:

  • --batch_size
  • --re_zero: Use Re_Zero in the ResNet if True
  • --learning_rate
  • --epochs
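As an illustration, a run combining these flags might look like python train.py --batch_size 64 --re_zero True --learning_rate 1e-4 --epochs 50; the values here are placeholders, not recommended settings, so check python train.py -h for the actual defaults.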

Test

The trained model, named 'model_epoch.pth', and the loss/accuracy figures are saved under ./checkpoint; the model is saved every 5 epochs. To predict on the test data with a saved model, you need to adapt the code in test(args, model, data, device) in train.py. The results are saved as 'out_epoch.csv'.
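A hypothetical sketch of restoring a checkpoint before calling the test function; it assumes the checkpoint is a plain state_dict, which may differ from how train.py actually saves it:

```python
import torch

def load_checkpoint(model: torch.nn.Module, path: str,
                    device: str = 'cpu') -> torch.nn.Module:
    """Restore weights saved during training, e.g. './checkpoint/model_epoch.pth'."""
    model.load_state_dict(torch.load(path, map_location=device))
    model.to(device)
    model.eval()  # disable dropout / batch-norm updates for inference
    return model
```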

Here is my training example: TrainPic

(back to top)

License

This project is licensed under the MIT License - see the LICENSE file for details.
