diff --git a/src/brevitas_examples/super_resolution/README.md b/src/brevitas_examples/super_resolution/README.md
index 9e0068c27..1c73873e5 100644
--- a/src/brevitas_examples/super_resolution/README.md
+++ b/src/brevitas_examples/super_resolution/README.md
@@ -1,6 +1,6 @@
 # Integer-Quantized Super Resolution Experiments with Brevitas
 
-This directory contains scripts demonstrating how to train integer-quantized super resolution models using [Brevitas](https://github.com/Xilinx/brevitas).
+This directory contains scripts demonstrating how to train integer-quantized super resolution models using Brevitas.
 Code is also provided to demonstrate accumulator-aware quantization (A2Q) as proposed in our ICCV 2023 paper "[A2Q: Accumulator-Aware Quantization with Guaranteed Overflow Avoidance](https://arxiv.org/abs/2308.13504)".
 
 ## Experiments
@@ -12,25 +12,32 @@
 During inference center cropping is applied.
 Inputs are then downscaled by 2x and then used to train the model directly in the RGB space.
 Note that this is a difference from many academic works that train only on the Y-channel in YCbCr format.
 
-| Model Name | Upscale Factor | Weight quantization | Activation quantization | Peak Signal-to-Noise Ratio |
+| Model Name                  | Upscale Factor | Weight quantization | Activation quantization | Peak Signal-to-Noise Ratio |
 |-----------------------------|----------------|---------------------|-------------------------|----------------------------|
-| bicubic_interp | x2 | N/A | N/A | 28.71 |
-| [float_espcn_x2]() | x2 | float32 | float32 | 31.03 |
+| bicubic_interp | x2 | N/A | N/A | 28.71 |
+| [float_espcn_x2](https://github.com/Xilinx/brevitas/releases/download/super_res_r1/float_espcn_x2-2f85a454.pth) | x2 | float32 | float32 | 31.03 |
 ||
-| [quant_espcn_x2_w8a8_base]() | x2 | int8 | (u)int8 | 30.96 |
-| [quant_espcn_x2_w8a8_a2q_32b]() | x2 | int8 | (u)int8 | 30.79 |
-| [quant_espcn_x2_w8a8_a2q_16b]() | x2 | int8 | (u)int8 | 30.56 |
+| [quant_espcn_x2_w8a8_base](https://github.com/Xilinx/brevitas/releases/download/super_res_r1/quant_espcn_x2_w8a8_base-f761e4a1.pth) | x2 | int8 | (u)int8 | 30.96 |
+| [quant_espcn_x2_w8a8_a2q_32b](https://github.com/Xilinx/brevitas/releases/download/super_res_r1/quant_espcn_x2_w8a8_a2q_32b-85470d9b.pth) | x2 | int8 | (u)int8 | 30.79 |
+| [quant_espcn_x2_w8a8_a2q_16b](https://github.com/Xilinx/brevitas/releases/download/super_res_r1/quant_espcn_x2_w8a8_a2q_16b-f9e1da66.pth) | x2 | int8 | (u)int8 | 30.56 |
 ||
-| [quant_espcn_x2_w4a4_base]() | x2 | int4 | (u)int4 | 30.30 |
-| [quant_espcn_x2_w4a4_a2q_32b]() | x2 | int4 | (u)int4 | 30.27 |
-| [quant_espcn_x2_w4a4_a2q_13b]() | x2 | int4 | (u)int4 | 30.24 |
+| [quant_espcn_x2_w4a4_base](https://github.com/Xilinx/brevitas/releases/download/super_res_r1/quant_espcn_x2_w4a4_base-80658e6d.pth) | x2 | int4 | (u)int4 | 30.30 |
+| [quant_espcn_x2_w4a4_a2q_32b](https://github.com/Xilinx/brevitas/releases/download/super_res_r1/quant_espcn_x2_w4a4_a2q_32b-8702a412.pth) | x2 | int4 | (u)int4 | 30.27 |
+| [quant_espcn_x2_w4a4_a2q_13b](https://github.com/Xilinx/brevitas/releases/download/super_res_r1/quant_espcn_x2_w4a4_a2q_13b-9fff234e.pth) | x2 | int4 | (u)int4 | 30.24 |
 
 ## Train
 
-To start training a model from scratch (*e.g.*, `quant_espcn_x2_w8a8_a2q_32b`) run:
+All models are trained from scratch as follows:
 ```bash
-python train_model.py --data_root=data --model=quant_espcn_x2_w8a8_a2q_32b
+python train_model.py \
+    --data_root=./data \
+    --model=quant_espcn_x2_w8a8_a2q_32b \
+    --batch_size=8 \
+    --learning_rate=0.001 \
+    --weight_decay=0.00001 \
+    --gamma=0.999 \
+    --step_size=1
 ```
 
 ## Evaluate
diff --git a/src/brevitas_examples/super_resolution/models/__init__.py b/src/brevitas_examples/super_resolution/models/__init__.py
index 872624a85..6d533fccb 100644
--- a/src/brevitas_examples/super_resolution/models/__init__.py
+++ b/src/brevitas_examples/super_resolution/models/__init__.py
@@ -45,13 +45,16 @@
         act_bit_width=4,
         acc_bit_width=13)}
 
-root_url = 'https://github.com/Xilinx/brevitas/releases/download/super_res-r0'
+root_url = 'https://github.com/Xilinx/brevitas/releases/download/super_res_r1'
 
 model_url = {
-    'float_espcn_x2': f'{root_url}/float_espcn_x2-2f3821e3.pth',
-    'quant_espcn_x2_w8a8_base': f'{root_url}/quant_espcn_x2_w8a8_base-7d54e29c.pth',
-    'quant_espcn_x2_w8a8_a2q_32b': f'{root_url}/quant_espcn_x2_w8a8_a2q_32b-0b1f361d.pth',
-    'quant_espcn_x2_w8a8_a2q_16b': f'{root_url}/quant_espcn_x2_w8a8_a2q_16b-3c4acd35.pth'}
+    'float_espcn_x2': f'{root_url}/float_espcn_x2-2f85a454.pth',
+    'quant_espcn_x2_w4a4_a2q_13b': f'{root_url}/quant_espcn_x2_w4a4_a2q_13b-9fff234e.pth',
+    'quant_espcn_x2_w4a4_a2q_32b': f'{root_url}/quant_espcn_x2_w4a4_a2q_32b-8702a412.pth',
+    'quant_espcn_x2_w4a4_base': f'{root_url}/quant_espcn_x2_w4a4_base-80658e6d.pth',
+    'quant_espcn_x2_w8a8_a2q_16b': f'{root_url}/quant_espcn_x2_w8a8_a2q_16b-f9e1da66.pth',
+    'quant_espcn_x2_w8a8_a2q_32b': f'{root_url}/quant_espcn_x2_w8a8_a2q_32b-85470d9b.pth',
+    'quant_espcn_x2_w8a8_base': f'{root_url}/quant_espcn_x2_w8a8_base-f761e4a1.pth'}
 
 
 def get_model_by_name(name: str, pretrained: bool = False) -> Union[FloatESPCN, QuantESPCN]:
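
For reference, a minimal usage sketch of the `get_model_by_name` entry point whose checkpoint URLs this patch updates. The import path follows the file layout shown in the diff; the model name is one of the `model_url` keys, and the 128x128 input size is illustrative only, not a requirement of the model.

```python
# Minimal sketch: load one of the pretrained checkpoints listed in model_url.
# Assumes Brevitas is installed together with the bundled brevitas_examples package.
import torch

from brevitas_examples.super_resolution.models import get_model_by_name

# Any key of model_url works here; pretrained=True fetches the released .pth weights.
model = get_model_by_name('quant_espcn_x2_w8a8_a2q_32b', pretrained=True)
model.eval()

# The README trains and evaluates directly in RGB space, so inputs are 3-channel images.
low_res = torch.rand(1, 3, 128, 128)  # illustrative low-resolution batch of one
with torch.no_grad():
    super_res = model(low_res)  # x2 upscaling -> spatial size doubles
print(super_res.shape)
```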