
DAQuant

The official implementation of "Domain Aware Post Training Quantization for Vision Transformers in Deployment" in PyTorch.

Despite the increasing popularity of Vision Transformers (ViTs) for vision tasks, their deployment on mobile devices faces two main challenges: performance degradation from the model compression required under tight computational constraints, and accuracy loss caused by domain shift. Although existing post-training quantization (PTQ) methods reduce the computational load of ViTs, they often fail under extremely low-bit settings and domain shift. To address these two challenges, this paper introduces Domain Aware Post-training Quantization (DAQuant), a novel approach that simultaneously tackles extreme model compression and domain adaptation for ViTs in deployment. DAQuant applies a distribution-aware smoothing technique to mitigate outlier effects in ViT activations and uses learnable activation clipping (LAC) to minimize quantization error. Additionally, we propose an effective domain alignment strategy that improves the model's generalizability, preserving performance on the source domain while enhancing generalization to the target domain. DAQuant achieves superior results in both quantization error and generalization capacity, significantly outperforming existing quantization methods in real-device deployment scenarios.
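For intuition, below is a rough sketch of learnable activation clipping in the PACT style. This is a generic illustration of the technique, not DAQuant's exact formulation: activations are clamped to a learnable bound `alpha` before uniform quantization, so optimizing `alpha` trades clipping error against rounding error.

# Generic PACT-style learnable activation clipping (illustration only,
# not DAQuant's exact formulation): alpha is trained to balance
# clipping error against rounding error.
import torch

class LearnableActClip(torch.nn.Module):
    def __init__(self, n_bits=4, init_alpha=6.0):
        super().__init__()
        self.n_bits = n_bits
        self.alpha = torch.nn.Parameter(torch.tensor(init_alpha))

    def forward(self, x):
        levels = 2 ** self.n_bits - 1
        x_c = torch.minimum(torch.relu(x), self.alpha)  # learnable clipping bound
        scale = self.alpha / levels
        x_q = torch.round(x_c / scale) * scale          # uniform quantization
        # Straight-through estimator: gradients bypass the rounding.
        return x_c + (x_q - x_c).detach()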

Usage

We provide full scripts to run DAQuant, using DeiT-S as the example below. You can download the weights of deit-small-patch16-224 from Hugging Face.
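For reference, one way to fetch the checkpoint locally (a sketch assuming the facebook/deit-small-patch16-224 repository on the Hugging Face Hub; substitute whatever local path you pass as --model):

# Download the DeiT-S checkpoint locally (sketch; assumes the
# facebook/deit-small-patch16-224 repo on the Hugging Face Hub).
from huggingface_hub import snapshot_download

local_dir = snapshot_download(
    repo_id="facebook/deit-small-patch16-224",
    local_dir="./deit-small-patch16-224",
)
print("weights saved to", local_dir)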

1. Install packages
conda create -n daquant python=3.11.0 -y
conda activate daquant
pip install --upgrade pip  
pip install -r requirements.txt
2. Obtain the channel-wise scales and shifts required for initialization:
python generate_act_scale_shift.py --model /PATH/TO/DeiT/deit-small-patch16-224
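Conceptually, this step feeds calibration images through the model and records per-channel activation statistics for each linear layer. A minimal sketch of the idea with forward hooks (illustrative only; generate_act_scale_shift.py is the authoritative implementation):

# Illustrative sketch of collecting channel-wise activation statistics
# with forward hooks; the actual logic lives in generate_act_scale_shift.py.
import torch

def collect_act_stats(model, calib_loader, device="cuda"):
    stats, hooks = {}, []

    def make_hook(name):
        def hook(module, inputs, output):
            x = inputs[0].detach().reshape(-1, inputs[0].shape[-1]).float()
            cmax, cmin = x.max(dim=0).values, x.min(dim=0).values
            if name in stats:
                stats[name]["max"] = torch.maximum(stats[name]["max"], cmax)
                stats[name]["min"] = torch.minimum(stats[name]["min"], cmin)
            else:
                stats[name] = {"max": cmax, "min": cmin}
        return hook

    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            hooks.append(module.register_forward_hook(make_hook(name)))

    model.eval().to(device)
    with torch.no_grad():
        for images, _ in calib_loader:
            model(images.to(device))
    for h in hooks:
        h.remove()

    # scale ~ per-channel range, shift ~ per-channel midpoint
    return {n: {"scale": (s["max"] - s["min"]) / 2,
                "shift": (s["max"] + s["min"]) / 2} for n, s in stats.items()}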
3. Model quantization
# W4A4 
CUDA_VISIBLE_DEVICES=0 python main.py \
--model /PATH/TO/DeiT/deit-small-patch16-224  \
--epochs 20 --output_dir ./log/deit-small-patch16-224-w4a4 \
--wbits 4 --abits 4 --dga --lwc --lac --wrc

# W6A6
CUDA_VISIBLE_DEVICES=0 python main.py \
--model /PATH/TO/DeiT/deit-small-patch16-224  \
--epochs 20 --output_dir ./log/deit-small-patch16-224-w6a6 \
--wbits 6 --abits 6 --dga --lwc --lac --wrc

# W4A16
CUDA_VISIBLE_DEVICES=0 python main.py \
--model /PATH/TO/DeiT/deit-small-patch16-224  \
--epochs 20 --output_dir ./log/deit-small-patch16-224-w4a16 \
--wbits 4 --abits 16 --dga --lwc --lac --wrc

# W3A16
CUDA_VISIBLE_DEVICES=0 python main.py \
--model /PATH/TO/DeiT/deit-small-patch16-224  \
--epochs 20 --output_dir ./log/deit-small-patch16-224-w3a16 \
--wbits 3 --abits 16 --dga --lwc --lac --wrc
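In these commands, WxAy means x-bit weights and y-bit activations, so W4A16 quantizes only the weights. For intuition, here is a minimal uniform fake-quantization sketch of what --wbits/--abits control (the repo's quantizer additionally learns clipping via --lwc/--lac and applies smoothing, which this omits):

# Minimal uniform fake-quantization sketch: WxAy means x-bit weights,
# y-bit activations. DAQuant additionally learns clipping ranges
# (--lwc for weights, --lac for activations); this omits that.
import torch

def fake_quant(x, n_bits, symmetric=True):
    if symmetric:
        qmax = 2 ** (n_bits - 1) - 1
        scale = x.abs().max() / qmax
        return torch.clamp(torch.round(x / scale), -qmax - 1, qmax) * scale
    qmax = 2 ** n_bits - 1
    scale = (x.max() - x.min()) / qmax
    zero_point = torch.round(-x.min() / scale)
    q = torch.clamp(torch.round(x / scale) + zero_point, 0, qmax)
    return (q - zero_point) * scale

w = torch.randn(384, 384)
print((fake_quant(w, 4) - w).abs().mean())  # mean W4 quantization error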
4. Domain adaptation

Below is the script for domain adaptation; we will release the pre-trained model weights shortly.

# W4A4
CUDA_VISIBLE_DEVICES=7 python main.py \
--model /PATH/TO/DeiT/deit-small-patch16-224  \
--source_model /PATH/TO/Pre-train-in-office/DeiT/DeiT-S \
--epochs 10 --output_dir ./log/deit-small-patch16-224-w4a4-da  \
--wbits 4 --abits 4 --dga --lwc --lac --wrc --tl \
--calib_dataset amazon --target_dataset webcam \
--tl_loss --tl_weight 1.5
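
The --tl_loss and --tl_weight flags suggest a transfer-learning term, weighted by --tl_weight, that keeps the quantized model close to the full-precision source model while aligning it with the target domain. The sketch below is a hypothetical illustration of such a weighted alignment objective; the actual loss is defined by the paper and code, not here.

# Hypothetical weighted alignment objective (illustration only; the
# actual --tl_loss implementation is defined in the DAQuant code).
import torch.nn.functional as F

def transfer_loss(quant_feats, source_feats, target_feats, tl_weight=1.5):
    # Stay faithful to the full-precision source model on calibration data.
    fidelity = F.mse_loss(quant_feats, source_feats)
    # Align first-order feature statistics with the target domain.
    alignment = F.mse_loss(quant_feats.mean(dim=0), target_feats.mean(dim=0))
    return fidelity + tl_weight * alignment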

5. Real quantization

We use the kernels from AutoGPTQ to enable real quantization. To accelerate and compress your model with real quantization, follow these steps.

pip install auto-gptq==0.6.0

CUDA_VISIBLE_DEVICES=0 python main.py \
--model /PATH/TO/DeiT/deit-small-patch16-224  \
--epochs 20 --output_dir ./log/deit-small-patch16-224-w4a16 \
--wbits 4 --abits 16 --lwc --lac --wrc \
--real_quant --save_dir ./real_quant/deit-small-patch16-224-w4a16
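
Unlike fake quantization, which keeps floating-point tensors and only simulates rounding, real quantization stores packed low-bit integers, e.g. eight 4-bit weights per int32 word, which is where the memory savings come from. A small packing sketch (illustrative; AutoGPTQ's kernels handle packing and the low-bit matmuls internally):

# Why real quantization shrinks the model: eight 4-bit values pack into
# one int32 word (illustrative; AutoGPTQ handles this internally).
import torch

def pack_int4(q):
    assert q.numel() % 8 == 0 and 0 <= q.min() and q.max() <= 15
    q = q.to(torch.int32).reshape(-1, 8)
    packed = torch.zeros(q.shape[0], dtype=torch.int32)
    for i in range(8):
        packed |= q[:, i] << (4 * i)
    return packed

q = torch.randint(0, 16, (1024,))
print(q.numel(), "int4 values ->", pack_int4(q).numel(), "int32 words")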

Related Projects

SmoothQuant: Accurate and Efficient Post-Training Quantization for Large Language Models

OmniQuant: Omnidirectionally Calibrated Quantization for Large Language Models
