From a784e4fe14d6f44cf41649d62fa402942a09a2c9 Mon Sep 17 00:00:00 2001 From: whs Date: Mon, 17 Feb 2020 17:37:17 +0800 Subject: [PATCH] Fix demo of pruning to load pretrained model. (#115) --- demo/prune/README.md | 25 +++++++++++++++++++------ demo/prune/eval.py | 2 +- demo/prune/train.py | 4 ++++ docs/zh_cn/api_cn/prune_api.rst | 2 +- 4 files changed, 25 insertions(+), 8 deletions(-) diff --git a/demo/prune/README.md b/demo/prune/README.md index f3feac019ae01..138a315bd7c28 100755 --- a/demo/prune/README.md +++ b/demo/prune/README.md @@ -17,7 +17,20 @@ 1). 根据分类模型中[ImageNet数据准备文档](https://github.com/PaddlePaddle/models/tree/develop/PaddleCV/image_classification#%E6%95%B0%E6%8D%AE%E5%87%86%E5%A4%87)下载数据到`PaddleSlim/demo/data/ILSVRC2012`路径下。 2). 使用`train.py`脚本时,指定`--data`选项为`imagenet`. -## 2. 启动剪裁任务 +## 2. 下载预训练模型 + +如果使用`ImageNet`数据,建议在预训练模型的基础上进行剪裁,请从[分类库](https://github.com/PaddlePaddle/models/tree/develop/PaddleCV/image_classification#%E5%B7%B2%E5%8F%91%E5%B8%83%E6%A8%A1%E5%9E%8B%E5%8F%8A%E5%85%B6%E6%80%A7%E8%83%BD)中下载合适的预训练模型。 + +这里以`MobileNetV1`为例,下载并解压预训练模型到当前路径: + +``` +wget http://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV1_pretrained.tar +tar -xf MobileNetV1_pretrained.tar +``` + +使用`train.py`脚本时,指定`--pretrained_model`加载预训练模型。 + +## 3. 启动剪裁任务 通过以下命令启动裁剪任务: @@ -25,8 +38,8 @@ export CUDA_VISIBLE_DEVICES=0 python train.py \ --model "MobileNet" \ ---pruned_ratio 0.33 \ ---data "imagenet" +--pruned_ratio 0.31 \ +--data "mnist" ``` 其中,`model`用于指定待裁剪的模型。`pruned_ratio`用于指定各个卷积层通道数被裁剪的比例。`data`选项用于指定使用的数据集。 @@ -35,7 +48,7 @@ python train.py \ 在本示例中,会在日志中输出剪裁前后的`FLOPs`,并且每训练一轮就会保存一个模型到文件系统。 -## 3. 加载和评估模型 +## 4. 加载和评估模型 本节介绍如何加载训练过程中保存的模型。 @@ -43,14 +56,14 @@ python train.py \ ``` python eval.py \ ---model "mobilenet" \ +--model "MobileNet" \ --data "mnist" \ --model_path "./models/0" ``` 在脚本`eval.py`中,使用`paddleslim.prune.load_model`接口加载剪裁得到的模型。 -## 4. 接口介绍 +## 5. 接口介绍 该示例使用了`paddleslim.Pruner`工具类,用户接口使用介绍请参考:[API文档](https://paddlepaddle.github.io/PaddleSlim/api/prune_api/) diff --git a/demo/prune/eval.py b/demo/prune/eval.py index 596764e6fc42b..28bc24ad984cf 100644 --- a/demo/prune/eval.py +++ b/demo/prune/eval.py @@ -68,7 +68,7 @@ def eval(args): val_feeder = feeder = fluid.DataFeeder( [image, label], place, program=val_program) - load_model(val_program, "./model/mobilenetv1_prune_50") + load_model(exe, val_program, args.model_path) batch_id = 0 acc_top1_ns = [] diff --git a/demo/prune/train.py b/demo/prune/train.py index 1a25f3f1e4dcc..2fd535399a716 100644 --- a/demo/prune/train.py +++ b/demo/prune/train.py @@ -136,6 +136,8 @@ def if_exist(var): return os.path.exists( os.path.join(args.pretrained_model, var.name)) + _logger.info("Load pretrained model from {}".format( + args.pretrained_model)) fluid.io.load_vars(exe, args.pretrained_model, predicate=if_exist) val_reader = paddle.batch(val_reader, batch_size=args.batch_size) @@ -200,6 +202,8 @@ def train(epoch, program): end_time - start_time)) batch_id += 1 + test(0, val_program) + params = get_pruned_params(args, fluid.default_main_program()) _logger.info("FLOPs before pruning: {}".format( flops(fluid.default_main_program()))) diff --git a/docs/zh_cn/api_cn/prune_api.rst b/docs/zh_cn/api_cn/prune_api.rst index d38480d6055e5..38a76acdaed38 100644 --- a/docs/zh_cn/api_cn/prune_api.rst +++ b/docs/zh_cn/api_cn/prune_api.rst @@ -378,7 +378,7 @@ load_sensitivities } } sensitivities_file = "sensitive_api_demo.data" - with open(sensitivities_file, 'w') as f: + with open(sensitivities_file, 'wb') as f: pickle.dump(sen, f) sensitivities = load_sensitivities(sensitivities_file) print(sensitivities)