diff --git a/deploy/slim/act/README.md b/deploy/slim/act/README.md new file mode 100644 index 0000000000..13330762ca --- /dev/null +++ b/deploy/slim/act/README.md @@ -0,0 +1,172 @@ +# 图像分类模型自动压缩示例 + +目录: +- [图像分类模型自动压缩示例](#图像分类模型自动压缩示例) + - [1. 简介](#1-简介) + - [2. Benchmark](#2-benchmark) + - [PaddleClas模型](#paddleclas模型) + - [3. 自动压缩流程](#3-自动压缩流程) + - [3.1 准备环境](#31-准备环境) + - [3.2 准备数据集](#32-准备数据集) + - [3.3 准备预测模型](#33-准备预测模型) + - [3.4 自动压缩并产出模型](#34-自动压缩并产出模型) + - [4.预测部署](#4预测部署) + - [4.1 Paddle Inference 验证性能](#41-paddle-inference-验证性能) + - [4.2 PaddleLite端侧部署](#42-paddlelite端侧部署) + - [5.FAQ](#5faq) + + +## 1. 简介 +本示例将以图像分类模型MobileNetV1为例,介绍如何使用PaddleClas中Inference部署模型进行自动压缩。本示例使用的自动压缩策略为量化训练和蒸馏。 + +## 2. Benchmark + +### PaddleClas模型 + +| 模型 | 策略 | Top-1 Acc | GPU 耗时(ms) | ARM CPU 耗时(ms) | 配置文件 | Inference模型 | +|:----------------------:|:------:|:---------:|:----------:|:--------------:|:------:|:-----:| +| MobileNetV3_small_x1_0 | Baseline | 68.19 | - | | - | [Model](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/MobileNetV3_small_x1_0_infer.tar) | +| MobileNetV3_small_x1_0 | 量化+蒸馏 | 64.90 | - | | [Config](./configs/MobileNetV3_small_x1_0/qat_dis.yaml) || +| ResNet50 | Baseline | 76.46 | | - | - | [Model](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/ResNet50_vd_infer.tar) | +| ResNet50_vd | 量化+蒸馏 | 76.08 | | - | [Config](./configs/ResNet50_vd/qat_dis.yaml) | | + + + +- ARM CPU 测试环境:`SDM865(4xA77+4xA55)` +- Nvidia GPU 测试环境: + - 硬件:NVIDIA Tesla T4 单卡 + - 软件:CUDA 11.2, cuDNN 8.0, TensorRT 8.4 + - 测试配置:batch_size: 1, image size: 224 + +## 3. 自动压缩流程 + +#### 3.1 准备环境 + +- python >= 3.6 +- PaddlePaddle >= 2.5 (可从[Paddle官网](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html)下载安装) +- PaddleSlim >= 2.5 + +安装paddlepaddle: +```shell +# CPU +pip install paddlepaddle==2.5.1 +# GPU 以Ubuntu、CUDA 11.2为例 +python -m pip install paddlepaddle-gpu==2.5.1.post112 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html +``` + +安装paddleslim: +```shell +pip install paddleslim +``` + +若使用`run_ppclas.py`脚本,需安装paddleclas: +```shell +git clone https://github.com/PaddlePaddle/PaddleClas.git -b release/2.5 +cd PaddleClas +pip install --upgrade -r requirements.txt +``` + +#### 3.2 准备数据集 +本案例默认以ImageNet1k数据进行自动压缩实验,如数据集为非ImageNet1k格式数据, 请参考[PaddleClas数据准备文档](https://github.com/PaddlePaddle/PaddleClas/blob/release/2.3/docs/zh_CN/data_preparation/classification_dataset.md)。将下载好的数据集放在当前目录下`./ILSVRC2012`。 + + +#### 3.3 准备预测模型 +预测模型的格式为:`model.pdmodel` 和 `model.pdiparams`两个,带`pdmodel`的是模型文件,带`pdiparams`后缀的是权重文件。 + +注:其他像`__model__`和`__params__`分别对应`model.pdmodel` 和 `model.pdiparams`文件。 + +可在[PaddleClas预训练模型库](https://github.com/PaddlePaddle/PaddleClas/blob/release/2.3/docs/zh_CN/algorithm_introduction/ImageNet_models.md)中直接获取Inference模型,具体可参考下方获取MobileNetV1模型示例: + +```shell +wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/ResNet50_infer.tar +tar -xf ResNet50_infer.tar +``` +也可根据[PaddleClas文档](https://github.com/PaddlePaddle/PaddleClas/blob/release/2.3/docs/zh_CN/inference_deployment/export_model.md)导出Inference模型。 + +#### 3.4 自动压缩并产出模型 + +蒸馏量化自动压缩示例通过run.py脚本启动,会使用接口 ```paddleslim.auto_compression.AutoCompression``` 对模型进行量化训练和蒸馏。配置config文件中模型路径、数据集路径、蒸馏、量化和训练等部分的参数,配置完成后便可开始自动压缩。 + +**单卡启动** + +```shell +export CUDA_VISIBLE_DEVICES=0 +python run_ppclas.py \ + --compression_config_path='./configs/ResNet50/qat_dis.yaml' \ + --reader_config_path='./configs/ResNet50/data_reader.yaml' \ + --save_dir='./save_quant_ResNet50/' +``` + +**多卡启动** + +图像分类训练任务中往往包含大量训练数据,以ImageNet为例,ImageNet22k数据集中包含1400W张图像,如果使用单卡训练,会非常耗时,使用分布式训练可以达到几乎线性的加速比。 + +```shell +export CUDA_VISIBLE_DEVICES=0,1,2,3 +python -m paddle.distributed.launch run.py --save_dir='./save_quant_resnet50/' --config_path='./configs/ResNet50/qat_dis.yaml' +``` +多卡训练指的是将训练任务按照一定方法拆分到多个训练节点完成数据读取、前向计算、反向梯度计算等过程,并将计算出的梯度上传至服务节点。服务节点在收到所有训练节点传来的梯度后,会将梯度聚合并更新参数。最后将参数发送给训练节点,开始新一轮的训练。多卡训练一轮训练能训练```batch size * num gpus```的数据,比如单卡的```batch size```为32,单轮训练的数据量即32,而四卡训练的```batch size```为32,单轮训练的数据量为128。 + +注意: + +- 参数设置:```learning rate``` 与 ```batch size``` 呈线性关系,这里单卡 ```batch size``` 为32,对应的 ```learning rate``` 为0.015,那么如果 ```batch size``` 减小4倍改为8,```learning rate``` 也需除以4;多卡时 ```batch size``` 为32,```learning rate``` 需乘上卡数。所以改变 ```batch size``` 或改变训练卡数都需要对应修改 ```learning rate```。 + +- 如需要使用`PaddleClas`中的数据预处理和`DataLoader`,可以使用`run_ppclas.py`脚本启动,启动方式跟以上示例相同,但配置需要对其```PaddleClas```,可参考[ViT配置文件](./configs/VIT/data_reader.yml)。 + + + +## 4.预测部署 + +#### 4.1 Paddle Inference 验证性能 + +量化模型在GPU上可以使用TensorRT进行加速,在CPU上可以使用MKLDNN进行加速。 + +以下字段用于配置预测参数: + +| 参数名 | 含义 | +|:------:|:------:| +| model_path | inference 模型文件所在目录,该目录下需要有文件 .pdmodel 和 .pdiparams 两个文件 | +| model_filename | inference_model_dir文件夹下的模型文件名称 | +| params_filename | inference_model_dir文件夹下的参数文件名称 | +| data_path | 数据集路径 | +| batch_size | 预测一个batch的大小 | +| image_size | 输入图像的大小 | +| use_gpu | 是否使用 GPU 预测 | +| use_trt | 是否使用 TesorRT 预测引擎 | +| use_mkldnn | 是否启用```MKL-DNN```加速库,注意```use_mkldnn```与```use_gpu```同时为```True```时,将忽略```use_mkldnn```,而使用```GPU```预测 | +| cpu_num_threads | CPU预测时,使用CPU线程数量,默认10 | +| use_fp16 | 使用TensorRT时,是否启用```FP16``` | +| use_int8 | 是否启用```INT8``` | + +注意: +- 请注意模型的输入数据尺寸,如InceptionV3输入尺寸为299,部分模型需要修改参数:```image_size``` + + +- TensorRT预测: + +环境配置:如果使用 TesorRT 预测引擎,需安装的是带有TensorRT的PaddlePaddle,使用以下指令查看本地cuda版本,并且在[下载链接](https://www.paddlepaddle.org.cn/inference/user_guides/download_lib.html#python)中下载对应cuda版本和对应python版本的PaddlePaddle安装包。 + + ```shell + cat /usr/local/cuda/version.txt ### CUDA Version 10.2.89 + ### 10.2.89 为cuda版本号,可以根据这个版本号选择需要安装的带有TensorRT的PaddlePaddle安装包。 + ``` + +```shell +python test_ppclas.py \ + --model_path=./save_quant_resnet50 \ + --use_trt=True \ + --use_int8=True \ + --use_gpu=True \ + --data_path=./dataset/ILSVRC2012/ +``` + +- MKLDNN预测: + +```shell +python test_ppclas \ + --model_path=./save_quant_resnet50 \ + --data_path=./dataset/ILSVRC2012/ \ + --cpu_num_threads=10 \ + --use_mkldnn=True \ + --use_int8=True +``` diff --git a/deploy/slim/act/configs/MobileNetV3_small_x1_0/data_reader.yaml b/deploy/slim/act/configs/MobileNetV3_small_x1_0/data_reader.yaml new file mode 100644 index 0000000000..9b5144abb5 --- /dev/null +++ b/deploy/slim/act/configs/MobileNetV3_small_x1_0/data_reader.yaml @@ -0,0 +1,56 @@ +# data loader for train and eval +DataLoader: + Train: + dataset: + name: ImageNetDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/train_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 256 + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + + sampler: + name: DistributedBatchSampler + batch_size: 16 + drop_last: False + shuffle: True + loader: + num_workers: 4 + use_shared_memory: True + + Eval: + dataset: + name: ImageNetDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/val_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 256 + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + sampler: + name: DistributedBatchSampler + batch_size: 64 + drop_last: False + shuffle: False + loader: + num_workers: 4 + use_shared_memory: True \ No newline at end of file diff --git a/deploy/slim/act/configs/MobileNetV3_small_x1_0/qat_dis.yaml b/deploy/slim/act/configs/MobileNetV3_small_x1_0/qat_dis.yaml new file mode 100644 index 0000000000..7ecacd33c0 --- /dev/null +++ b/deploy/slim/act/configs/MobileNetV3_small_x1_0/qat_dis.yaml @@ -0,0 +1,38 @@ +Global: + model_dir: MobileNetV3_small_x1_0_infer + model_filename: inference.pdmodel + params_filename: inference.pdiparams + batch_size: 16 + input_name: inputs + +Distillation: + alpha: 1.0 + loss: soft_label + +QuantAware: + use_pact: true + activation_bits: 8 + is_full_quantize: false + onnx_format: True + activation_quantize_type: moving_average_abs_max + weight_quantize_type: channel_wise_abs_max + not_quant_pattern: + - skip_quant + quantize_op_types: + - conv2d + - depthwise_conv2d + - matmul + - matmul_v2 + weight_bits: 8 + +TrainConfig: + epochs: 2 + eval_iter: 5000 + learning_rate: + type: CosineAnnealingDecay + learning_rate: 0.001 + optimizer_builder: + optimizer: + type: Momentum + weight_decay: 0.00002 + origin_metric: 0.6819 diff --git a/deploy/slim/act/configs/PPHGNet_small/data_reader.yaml b/deploy/slim/act/configs/PPHGNet_small/data_reader.yaml new file mode 100644 index 0000000000..f5f32e6428 --- /dev/null +++ b/deploy/slim/act/configs/PPHGNet_small/data_reader.yaml @@ -0,0 +1,59 @@ +# data loader for train and eval +DataLoader: + Train: + dataset: + name: ImageNetDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/val_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 236 + interpolation: bicubic + backend: pil + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + sampler: + name: DistributedBatchSampler + batch_size: 16 + drop_last: False + shuffle: True + loader: + num_workers: 16 + use_shared_memory: True + + Eval: + dataset: + name: ImageNetDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/val_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 236 + interpolation: bicubic + backend: pil + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + sampler: + name: DistributedBatchSampler + batch_size: 16 + drop_last: False + shuffle: False + loader: + num_workers: 12 + use_shared_memory: True \ No newline at end of file diff --git a/deploy/slim/act/configs/PPHGNet_small/qat_dis.yaml b/deploy/slim/act/configs/PPHGNet_small/qat_dis.yaml new file mode 100644 index 0000000000..43c75e6269 --- /dev/null +++ b/deploy/slim/act/configs/PPHGNet_small/qat_dis.yaml @@ -0,0 +1,38 @@ +Global: + model_dir: PPHGNet_small_infer + model_filename: inference.pdmodel + params_filename: inference.pdiparams + batch_size: 16 + input_name: x + +Distillation: + alpha: 1.0 + loss: soft_label + +QuantAware: + use_pact: true + activation_bits: 8 + is_full_quantize: false + onnx_format: True + activation_quantize_type: moving_average_abs_max + weight_quantize_type: channel_wise_abs_max + not_quant_pattern: + - skip_quant + quantize_op_types: + - conv2d + - depthwise_conv2d + - matmul + - matmul_v2 + weight_bits: 8 + +TrainConfig: + epochs: 2 + eval_iter: 5000 + learning_rate: + type: CosineAnnealingDecay + learning_rate: 0.001 + optimizer_builder: + optimizer: + type: Momentum + weight_decay: 0.00002 + origin_metric: 0.7959 \ No newline at end of file diff --git a/deploy/slim/act/configs/PPLCNet_x1_0/data_reader.yaml b/deploy/slim/act/configs/PPLCNet_x1_0/data_reader.yaml new file mode 100644 index 0000000000..b939143358 --- /dev/null +++ b/deploy/slim/act/configs/PPLCNet_x1_0/data_reader.yaml @@ -0,0 +1,56 @@ +# data loader for train and eval +DataLoader: + Train: + dataset: + name: ImageNetDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/val_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 256 + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + + sampler: + name: DistributedBatchSampler + batch_size: 16 + drop_last: False + shuffle: True + loader: + num_workers: 4 + use_shared_memory: True + + Eval: + dataset: + name: ImageNetDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/val_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 256 + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + sampler: + name: DistributedBatchSampler + batch_size: 16 + drop_last: False + shuffle: False + loader: + num_workers: 4 + use_shared_memory: True \ No newline at end of file diff --git a/deploy/slim/act/configs/PPLCNet_x1_0/qat_dis.yaml b/deploy/slim/act/configs/PPLCNet_x1_0/qat_dis.yaml new file mode 100644 index 0000000000..8464ad1eb1 --- /dev/null +++ b/deploy/slim/act/configs/PPLCNet_x1_0/qat_dis.yaml @@ -0,0 +1,38 @@ +Global: + model_dir: PPLCNet_x1_0_infer + model_filename: inference.pdmodel + params_filename: inference.pdiparams + batch_size: 16 + input_name: x + +Distillation: + alpha: 1.0 + node: + - softmax_1.tmp_0 + +QuantAware: + use_pact: true + activation_bits: 8 + is_full_quantize: false + onnx_format: True + activation_quantize_type: moving_average_abs_max + weight_quantize_type: channel_wise_abs_max + not_quant_pattern: + - skip_quant + quantize_op_types: + - conv2d + - depthwise_conv2d + - matmul + - matmul_v2 + weight_bits: 8 +TrainConfig: + epochs: 3 + eval_iter: 5000 + learning_rate: + type: CosineAnnealingDecay + learning_rate: 0.001 + optimizer_builder: + optimizer: + type: Momentum + weight_decay: 0.00002 + origin_metric: 0.7132 diff --git a/deploy/slim/act/configs/ResNet50/data_reader.yaml b/deploy/slim/act/configs/ResNet50/data_reader.yaml new file mode 100644 index 0000000000..ca0b7cad0a --- /dev/null +++ b/deploy/slim/act/configs/ResNet50/data_reader.yaml @@ -0,0 +1,56 @@ +# data loader for train and eval +DataLoader: + Train: + dataset: + name: ImageNetDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/train_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 256 + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + + sampler: + name: DistributedBatchSampler + batch_size: 16 + drop_last: False + shuffle: True + loader: + num_workers: 4 + use_shared_memory: True + + Eval: + dataset: + name: ImageNetDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/val_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 256 + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + sampler: + name: DistributedBatchSampler + batch_size: 64 + drop_last: False + shuffle: False + loader: + num_workers: 4 + use_shared_memory: True \ No newline at end of file diff --git a/deploy/slim/act/configs/ResNet50/qat_dis.yaml b/deploy/slim/act/configs/ResNet50/qat_dis.yaml new file mode 100644 index 0000000000..2b3e0acad9 --- /dev/null +++ b/deploy/slim/act/configs/ResNet50/qat_dis.yaml @@ -0,0 +1,41 @@ +Global: + model_dir: ResNet50_infer + model_filename: inference.pdmodel + params_filename: inference.pdiparams + batch_size: 32 + input_name: inputs + input_name: inputs + +Distillation: + alpha: 1.0 + loss: l2 + node: + - softmax_0.tmp_0 + +QuantAware: + use_pact: true + activation_bits: 8 + is_full_quantize: false + onnx_format: True + activation_quantize_type: moving_average_abs_max + weight_quantize_type: channel_wise_abs_max + not_quant_pattern: + - skip_quant + quantize_op_types: + - conv2d + - depthwise_conv2d + - matmul + - matmul_v2 + weight_bits: 8 + +TrainConfig: + epochs: 1 + eval_iter: 500 + learning_rate: + type: CosineAnnealingDecay + learning_rate: 0.015 + optimizer_builder: + optimizer: + type: Momentum + weight_decay: 0.00002 + origin_metric: 0.7634556941778631 diff --git a/deploy/slim/act/configs/SwinTransformer_base/data_reader.yaml b/deploy/slim/act/configs/SwinTransformer_base/data_reader.yaml new file mode 100644 index 0000000000..a376d5a731 --- /dev/null +++ b/deploy/slim/act/configs/SwinTransformer_base/data_reader.yaml @@ -0,0 +1,59 @@ +# data loader for train and eval +DataLoader: + Train: + dataset: + name: ImageNetDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/val_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 256 + interpolation: bicubic + backend: pil + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + sampler: + name: DistributedBatchSampler + batch_size: 16 + drop_last: False + shuffle: True + loader: + num_workers: 16 + use_shared_memory: True + + Eval: + dataset: + name: ImageNetDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/val_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 256 + interpolation: bicubic + backend: pil + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + sampler: + name: DistributedBatchSampler + batch_size: 16 + drop_last: False + shuffle: False + loader: + num_workers: 12 + use_shared_memory: True \ No newline at end of file diff --git a/deploy/slim/act/configs/SwinTransformer_base/qat_dis.yaml b/deploy/slim/act/configs/SwinTransformer_base/qat_dis.yaml new file mode 100644 index 0000000000..588855c11f --- /dev/null +++ b/deploy/slim/act/configs/SwinTransformer_base/qat_dis.yaml @@ -0,0 +1,38 @@ +Global: + model_dir: SwinTransformer_base_patch4_window7_224_infer + model_filename: inference.pdmodel + params_filename: inference.pdiparams + batch_size: 32 + input_name: inputs + +Distillation: + alpha: 1.0 + loss: l2 + node: + - softmax_48.tmp_0 +QuantAware: + use_pact: true + activation_bits: 8 + is_full_quantize: false + onnx_format: True + activation_quantize_type: moving_average_abs_max + weight_quantize_type: channel_wise_abs_max + not_quant_pattern: + - skip_quant + quantize_op_types: + - conv2d + - depthwise_conv2d + - matmul + - matmul_v2 + weight_bits: 8 +TrainConfig: + epochs: 1 + eval_iter: 500 + learning_rate: + type: CosineAnnealingDecay + learning_rate: 0.015 + optimizer_builder: + optimizer: + type: Momentum + weight_decay: 0.00002 + origin_metric: 0.83 diff --git a/deploy/slim/act/configs/VIT/data_reader.yaml b/deploy/slim/act/configs/VIT/data_reader.yaml new file mode 100644 index 0000000000..370dde6440 --- /dev/null +++ b/deploy/slim/act/configs/VIT/data_reader.yaml @@ -0,0 +1,56 @@ +# data loader for train and eval +DataLoader: + Train: + dataset: + name: ImageNetDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/train_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - RandCropImage: + size: 224 + - RandFlipImage: + flip_code: 1 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.5, 0.5, 0.5] + std: [0.5, 0.5, 0.5] + order: '' + + sampler: + name: DistributedBatchSampler + batch_size: 16 + drop_last: False + shuffle: True + loader: + num_workers: 4 + use_shared_memory: True + + Eval: + dataset: + name: ImageNetDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/val_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 256 + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.5, 0.5, 0.5] + std: [0.5, 0.5, 0.5] + order: '' + sampler: + name: DistributedBatchSampler + batch_size: 64 + drop_last: False + shuffle: False + loader: + num_workers: 4 + use_shared_memory: True diff --git a/deploy/slim/act/configs/VIT/qat_dis.yaml b/deploy/slim/act/configs/VIT/qat_dis.yaml new file mode 100644 index 0000000000..d5ef60708a --- /dev/null +++ b/deploy/slim/act/configs/VIT/qat_dis.yaml @@ -0,0 +1,27 @@ +Global: + model_dir: ViT_base_patch16_224_infer + model_filename: inference.pdmodel + params_filename: inference.pdiparams + batch_size: 16 + input_name: inputs + +Distillation: + node: + - softmax_12.tmp_0 + +QuantAware: + use_pact: true + onnx_format: true + +TrainConfig: + epochs: 1 + eval_iter: 500 + learning_rate: + type: CosineAnnealingDecay + learning_rate: 0.015 + optimizer_builder: + optimizer: + type: Momentum + weight_decay: 0.00002 + origin_metric: 0.8189 + diff --git a/deploy/slim/act/configs/eval.yaml b/deploy/slim/act/configs/eval.yaml new file mode 100644 index 0000000000..f7e61515c6 --- /dev/null +++ b/deploy/slim/act/configs/eval.yaml @@ -0,0 +1,8 @@ +Global: + model_dir: './mobilenet_dbb_inference' + model_filename: 'inference.pdmodel' + params_filename: "inference.pdiparams" + batch_size: 128 + data_dir: './ILSVRC2012/' + img_size: 224 + resize_size: 256 diff --git a/deploy/slim/act/run_ppclas.py b/deploy/slim/act/run_ppclas.py new file mode 100644 index 0000000000..6323139940 --- /dev/null +++ b/deploy/slim/act/run_ppclas.py @@ -0,0 +1,177 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import argparse +import math +from tqdm import tqdm + +import numpy as np +import paddle +from paddleslim.common import load_config as load_slim_config +from paddleslim.auto_compression import AutoCompression + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(os.path.abspath(os.path.join(__dir__, '../../../'))) +from ppcls.data import build_dataloader +from ppcls.utils import config +from ppcls.utils.logger import init_logger + + +def argsparser(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + '--compression_config_path', + type=str, + default=None, + help="path of compression strategy config.", + required=True) + parser.add_argument( + '--reader_config_path', + type=str, + default=None, + help="path of compression strategy config.", + required=True) + parser.add_argument( + '--save_dir', + type=str, + default='output', + help="directory to save compressed model.") + parser.add_argument( + '--total_images', + type=int, + default=1281167, + help="the number of total training images.") + parser.add_argument( + '--devices', + type=str, + default='gpu', + help="which device used to compress.") + return parser + + +# yapf: enable +def reader_wrapper(reader, input_name): + if isinstance(input_name, list) and len(input_name) == 1: + input_name = input_name[0] + + def gen(): + for i, (imgs, label) in enumerate(reader()): + yield {input_name: imgs} + + return gen + + +def eval_function(exe, compiled_test_program, test_feed_names, + test_fetch_list): + + results = [] + with tqdm( + total=len(eval_loader), + bar_format='Evaluation stage, Run batch:|{bar}| {n_fmt}/{total_fmt}', + ncols=80) as t: + for batch_id, (image, label) in enumerate(eval_loader): + # top1_acc, top5_acc + if len(test_feed_names) == 1: + image = np.array(image) + label = np.array(label).astype('int64') + if len(label.shape) == 1: + label = label.reshape([label.shape[0], -1]) + pred = exe.run(compiled_test_program, + feed={test_feed_names[0]: image}, + fetch_list=test_fetch_list) + pred = np.array(pred[0]) + sort_array = pred.argsort(axis=1) + top_1_pred = sort_array[:, -1:][:, ::-1] + top_1 = np.mean(label == top_1_pred) + top_5_pred = sort_array[:, -5:][:, ::-1] + acc_num = 0 + for i in range(len(label)): + if label[i][0] in top_5_pred[i]: + acc_num += 1 + top_5 = float(acc_num) / len(label) + results.append([top_1, top_5]) + else: + # eval "eval model", which inputs are image and label, output is top1 and top5 accuracy + image = np.array(image) + label = np.array(label).astype('int64') + result = exe.run(compiled_test_program, + feed={ + test_feed_names[0]: image, + test_feed_names[1]: label + }, + fetch_list=test_fetch_list) + result = [np.mean(r) for r in result] + results.append(result) + t.update() + result = np.mean(np.array(results), axis=0) + return result[0] + + +def main(): + rank_id = paddle.distributed.get_rank() + if args.devices == 'gpu': + paddle.CUDAPlace(rank_id) + device = paddle.set_device('gpu') + else: + paddle.CPUPlace() + device = paddle.set_device('cpu') + global global_config + all_config = load_slim_config(args.compression_config_path) + + assert "Global" in all_config, f"Key 'Global' not found in config file. \n{all_config}" + global_config = all_config["Global"] + + gpu_num = paddle.distributed.get_world_size() + if isinstance(all_config['TrainConfig']['learning_rate'], + dict) and all_config['TrainConfig']['learning_rate'][ + 'type'] == 'CosineAnnealingDecay': + step = int( + math.ceil( + float(args.total_images) / (global_config['batch_size'] * + gpu_num))) + all_config['TrainConfig']['learning_rate']['T_max'] = step + print('total training steps:', step) + + init_logger() + data_config = config.get_config(args.reader_config_path, show=False) + train_loader = build_dataloader(data_config["DataLoader"], "Train", device, + False) + train_dataloader = reader_wrapper(train_loader, + global_config['input_name']) + + global eval_loader + eval_loader = build_dataloader(data_config["DataLoader"], "Eval", device, + False) + eval_dataloader = reader_wrapper(eval_loader, global_config['input_name']) + + ac = AutoCompression( + model_dir=global_config['model_dir'], + model_filename=global_config['model_filename'], + params_filename=global_config['params_filename'], + save_dir=args.save_dir, + config=all_config, + train_dataloader=train_dataloader, + eval_callback=eval_function if rank_id == 0 else None, + eval_dataloader=eval_dataloader) + + ac.compress() + + +if __name__ == '__main__': + paddle.enable_static() + parser = argsparser() + args = parser.parse_args() + main() diff --git a/deploy/slim/act/test_ppclas.py b/deploy/slim/act/test_ppclas.py new file mode 100644 index 0000000000..e733a36911 --- /dev/null +++ b/deploy/slim/act/test_ppclas.py @@ -0,0 +1,255 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import time +import sys +import argparse +import numpy as np +import cv2 +import yaml + +import paddle +from paddle.inference import create_predictor +from paddle.io import DataLoader +from imagenet_reader import ImageNetDataset + + +def argsparser(): + """ + argsparser func + """ + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--model_path", + type=str, + default="./MobileNetV1_infer", + help="model directory") + parser.add_argument( + "--model_filename", + type=str, + default="inference.pdmodel", + help="model file name") + parser.add_argument( + "--params_filename", + type=str, + default="inference.pdiparams", + help="params file name") + parser.add_argument("--batch_size", type=int, default=1) + parser.add_argument("--img_size", type=int, default=224) + parser.add_argument("--resize_size", type=int, default=256) + parser.add_argument( + "--data_path", type=str, default="./dataset/ILSVRC2012/") + parser.add_argument( + "--use_gpu", type=bool, default=False, help="Whether to use gpu") + parser.add_argument( + "--use_trt", type=bool, default=False, help="Whether to use tensorrt") + parser.add_argument( + "--use_mkldnn", type=bool, default=False, help="Whether to use mkldnn") + parser.add_argument( + "--cpu_num_threads", + type=int, + default=10, + help="Number of cpu threads") + parser.add_argument( + "--use_fp16", type=bool, default=False, help="Whether to use fp16") + parser.add_argument( + "--use_int8", type=bool, default=False, help="Whether to use int8") + parser.add_argument("--gpu_mem", type=int, default=8000, help="GPU memory") + parser.add_argument("--ir_optim", type=bool, default=True) + parser.add_argument( + "--use_dynamic_shape", + type=bool, + default=True, + help="Whether use dynamic shape or not.") + return parser + + +def eval_reader(data_dir, batch_size, crop_size, resize_size): + """ + eval reader func + """ + val_reader = ImageNetDataset( + mode="val", + data_dir=data_dir, + crop_size=crop_size, + resize_size=resize_size) + val_loader = DataLoader( + val_reader, + batch_size=args.batch_size, + shuffle=False, + drop_last=False, + num_workers=0) + return val_loader + + +class Predictor(object): + """ + Paddle Inference Predictor class + """ + + def __init__(self): + # HALF precission predict only work when using tensorrt + if args.use_fp16 is True: + assert args.use_trt is True + + self.rerun_flag = False + self.paddle_predictor = self._create_paddle_predictor() + input_names = self.paddle_predictor.get_input_names() + self.input_tensor = self.paddle_predictor.get_input_handle(input_names[ + 0]) + + output_names = self.paddle_predictor.get_output_names() + self.output_tensor = self.paddle_predictor.get_output_handle( + output_names[0]) + + def _create_paddle_predictor(self): + inference_model_dir = args.model_path + model_file = os.path.join(inference_model_dir, args.model_filename) + params_file = os.path.join(inference_model_dir, args.params_filename) + config = paddle.inference.Config(model_file, params_file) + precision = paddle.inference.Config.Precision.Float32 + if args.use_int8: + precision = paddle.inference.Config.Precision.Int8 + elif args.use_fp16: + precision = paddle.inference.Config.Precision.Half + + if args.use_gpu: + config.enable_use_gpu(args.gpu_mem, 0) + else: + config.disable_gpu() + config.set_cpu_math_library_num_threads(args.cpu_num_threads) + config.switch_ir_optim() + if args.use_mkldnn: + config.enable_mkldnn() + if args.use_int8: + config.enable_mkldnn_int8({ + "conv2d", "depthwise_conv2d", "transpose2", "pool2d" + }) + + config.switch_ir_optim(args.ir_optim) # default true + if args.use_trt: + config.enable_tensorrt_engine( + precision_mode=precision, + max_batch_size=args.batch_size, + workspace_size=1 << 30, + min_subgraph_size=30, + use_static=True, + use_calib_mode=False, ) + + if args.use_dynamic_shape: + dynamic_shape_file = os.path.join(inference_model_dir, + "dynamic_shape.txt") + if os.path.exists(dynamic_shape_file): + config.enable_tuned_tensorrt_dynamic_shape( + dynamic_shape_file, True) + print("trt set dynamic shape done!") + else: + config.collect_shape_range_info(dynamic_shape_file) + print("Start collect dynamic shape...") + self.rerun_flag = True + + config.enable_memory_optim() + predictor = create_predictor(config) + + return predictor + + def eval(self): + """ + eval func + """ + if os.path.exists(args.data_path): + val_loader = eval_reader( + args.data_path, + batch_size=args.batch_size, + crop_size=args.img_size, + resize_size=args.resize_size) + else: + image = np.ones((args.batch_size, 3, args.img_size, + args.img_size)).astype(np.float32) + label = [[None]] * args.batch_size + val_loader = [[image, label]] + results = [] + input_names = self.paddle_predictor.get_input_names() + input_tensor = self.paddle_predictor.get_input_handle(input_names[0]) + output_names = self.paddle_predictor.get_output_names() + output_tensor = self.paddle_predictor.get_output_handle(output_names[ + 0]) + predict_time = 0.0 + time_min = float("inf") + time_max = float("-inf") + sample_nums = len(val_loader) + for batch_id, (image, label) in enumerate(val_loader): + image = np.array(image) + + input_tensor.copy_from_cpu(image) + start_time = time.time() + self.paddle_predictor.run() + batch_output = output_tensor.copy_to_cpu() + end_time = time.time() + timed = end_time - start_time + time_min = min(time_min, timed) + time_max = max(time_max, timed) + predict_time += timed + if self.rerun_flag: + return + sort_array = batch_output.argsort(axis=1) + top_1_pred = sort_array[:, -1:][:, ::-1] + if label is None: + results.append(top_1_pred) + break + label = np.array(label) + top_1 = np.mean(label == top_1_pred) + top_5_pred = sort_array[:, -5:][:, ::-1] + acc_num = 0 + for i, _ in enumerate(label): + if label[i][0] in top_5_pred[i]: + acc_num += 1 + top_5 = float(acc_num) / len(label) + results.append([top_1, top_5]) + if batch_id % 100 == 0: + print("Eval iter:", batch_id) + sys.stdout.flush() + + result = np.mean(np.array(results), axis=0) + fp_message = "FP16" if args.use_fp16 else "FP32" + fp_message = "INT8" if args.use_int8 else fp_message + print_msg = "Paddle" + if args.use_trt: + print_msg = "using TensorRT" + elif args.use_mkldnn: + print_msg = "using MKLDNN" + time_avg = predict_time / sample_nums + print( + "[Benchmark]{}\t{}\tbatch size: {}.Inference time(ms): min={}, max={}, avg={}". + format( + print_msg, + fp_message, + args.batch_size, + round(time_min * 1000, 2), + round(time_max * 1000, 1), + round(time_avg * 1000, 1), )) + print("[Benchmark] Evaluation acc result: {}".format(result[0])) + sys.stdout.flush() + + +if __name__ == "__main__": + parser = argsparser() + args = parser.parse_args() + predictor = Predictor() + predictor.eval() + if predictor.rerun_flag: + print( + "***** Collect dynamic shape done, Please rerun the program to get correct results. *****" + )