forked from PaddlePaddle/Paddle
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add save and load API for pruned model (PaddlePaddle#38)
- Loading branch information
1 parent
6cde8f0
commit 50d69ec
Showing
5 changed files
with
242 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
# 卷积通道剪裁示例 | ||
|
||
本示例将演示如何按指定的剪裁率对每个卷积层的通道数进行剪裁。该示例默认会自动下载并使用mnist数据。 | ||
|
||
当前示例支持以下分类模型: | ||
|
||
- MobileNetV1 | ||
- MobileNetV2 | ||
- ResNet50 | ||
- PVANet | ||
|
||
## 接口介绍 | ||
|
||
该示例使用了`paddleslim.Pruner`工具类,用户接口使用介绍请参考:[API文档](https://paddlepaddle.github.io/PaddleSlim/api/prune_api/) | ||
|
||
## 确定待裁参数 | ||
|
||
不同模型的参数命名不同,在剪裁前需要确定待裁卷积层的参数名称。可通过以下方法列出所有参数名: | ||
|
||
``` | ||
for param in program.global_block().all_parameters(): | ||
print("param name: {}; shape: {}".format(param.name, param.shape)) | ||
``` | ||
|
||
在`train.py`脚本中,提供了`get_pruned_params`方法,根据用户设置的选项`--model`确定要裁剪的参数。 | ||
|
||
## 启动裁剪任务 | ||
|
||
通过以下命令启动裁剪任务: | ||
|
||
``` | ||
export CUDA_VISIBLE_DEVICES=0 | ||
python train.py | ||
``` | ||
|
||
在本示例中,每训练一轮就会保存一个模型到文件系统。 | ||
|
||
执行`python train.py --help`查看更多选项。 | ||
|
||
## 注意 | ||
|
||
1. 在接口`paddleslim.Pruner.prune`的参数中,`params`和`ratios`的长度需要一样。 | ||
|
||
## 加载和评估模型 | ||
|
||
本节介绍如何加载训练过程中保存的模型。 | ||
|
||
执行以下代码加载模型并评估模型在测试集上的指标。 | ||
|
||
``` | ||
python eval.py \ | ||
--model "MobileNet" \ | ||
--data "mnist" \ | ||
--model_path "./models/0" | ||
``` | ||
|
||
在脚本`eval.py`中,使用`paddleslim.prune.load_model`接口加载剪裁得到的模型。 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
import os | ||
import sys | ||
import logging | ||
import paddle | ||
import argparse | ||
import functools | ||
import math | ||
import time | ||
import numpy as np | ||
import paddle.fluid as fluid | ||
from paddleslim.prune import load_model | ||
from paddleslim.common import get_logger | ||
from paddleslim.analysis import flops | ||
sys.path.append(sys.path[0] + "/../") | ||
import models | ||
from utility import add_arguments, print_arguments | ||
|
||
# Module-level logger used by the evaluation loop.
_logger = get_logger(__name__, level=logging.INFO)

# Command-line interface. `add_arg` is `add_arguments` with the parser bound,
# called below as (name, type, default, help text).
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size',       int,  64 * 4,                 "Minibatch size.")
add_arg('use_gpu',          bool, True,                "Whether to use GPU or not.")
add_arg('model',            str,  "MobileNet",                "The target model.")
add_arg('model_path',       str,  "./models/0",                "The path of model used to evaluate.")
add_arg('data',             str, "mnist",                 "Which data to use. 'mnist' or 'imagenet'")
add_arg('log_period',       int, 10,                 "Log period in batches.")
# yapf: enable

# Names of the model classes exported by the local `models` package; `--model`
# is validated against this list.
model_list = models.__all__
|
||
|
||
def eval(args):
    """Evaluate a saved pruned model on the validation data and log accuracy.

    Builds the network selected by ``args.model``, restores the pruned
    shapes and weights from ``args.model_path`` via ``load_model``, then runs
    the validation reader batch by batch, logging top-1/top-5 accuracy every
    ``args.log_period`` batches and the mean over all batches at the end.

    NOTE: the name shadows the builtin ``eval``; it is kept because callers
    (``main``) refer to it by this name.

    Args:
        args: Parsed command-line arguments. Reads ``data``, ``model``,
            ``model_path``, ``batch_size``, ``use_gpu`` and ``log_period``.

    Raises:
        ValueError: If ``args.data`` is neither 'mnist' nor 'imagenet'.
        AssertionError: If ``args.model`` is not in ``model_list``.
    """
    # Only the validation split is needed for evaluation; the training reader
    # is deliberately not constructed.
    if args.data == "mnist":
        import paddle.dataset.mnist as reader
        val_reader = reader.test()
        class_dim = 10
        image_shape = [1, 28, 28]
    elif args.data == "imagenet":
        import imagenet_reader as reader
        val_reader = reader.val()
        class_dim = 1000
        image_shape = [3, 224, 224]
    else:
        raise ValueError("{} is not supported.".format(args.data))

    assert args.model in model_list, "{} is not in lists: {}".format(
        args.model, model_list)

    image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
    # model definition
    model = models.__dict__[args.model]()
    out = model.net(input=image, class_dim=class_dim)
    acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
    acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
    val_program = fluid.default_main_program().clone(for_test=True)

    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    val_reader = paddle.batch(val_reader, batch_size=args.batch_size)
    val_feeder = fluid.DataFeeder([image, label], place, program=val_program)

    # Restore the pruned shapes and weights from the user-supplied directory
    # (previously a hard-coded path was used and --model_path was ignored).
    load_model(val_program, args.model_path)

    acc_top1_ns = []
    acc_top5_ns = []
    for batch_id, data in enumerate(val_reader()):
        start_time = time.time()
        acc_top1_n, acc_top5_n = exe.run(
            val_program,
            feed=val_feeder.feed(data),
            fetch_list=[acc_top1.name, acc_top5.name])
        end_time = time.time()
        if batch_id % args.log_period == 0:
            _logger.info(
                "Eval batch[{}] - acc_top1: {}; acc_top5: {}; time: {}".format(
                    batch_id,
                    np.mean(acc_top1_n),
                    np.mean(acc_top5_n), end_time - start_time))
        acc_top1_ns.append(np.mean(acc_top1_n))
        acc_top5_ns.append(np.mean(acc_top5_n))

    _logger.info("Final eval - acc_top1: {}; acc_top5: {}".format(
        np.mean(np.array(acc_top1_ns)), np.mean(np.array(acc_top5_ns))))
|
||
|
||
def main():
    """Parse the command line, echo the arguments, and run the evaluation."""
    parsed = parser.parse_args()
    print_arguments(parsed)
    eval(parsed)
|
||
|
||
# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
import os | ||
import paddle.fluid as fluid | ||
from paddle.fluid import Program | ||
from ..core import GraphWrapper | ||
from ..common import get_logger | ||
import json | ||
import logging | ||
|
||
# Public API of this module.
__all__ = ["save_model", "load_model"]

# File names used inside a saved-model directory: the combined parameter file
# and the JSON file recording each parameter's shape.
PARAMS_FILE = "__params__"
SHAPES_FILE = "__shapes__"

_logger = get_logger(__name__, level=logging.INFO)
|
||
|
||
def save_model(graph, dirname):
    """
    Save weights of model and information of shapes into filesystem.

    The weights are written with ``fluid.io.save_params`` into a single file
    named ``PARAMS_FILE`` under ``dirname``; the shape of every parameter is
    written as JSON (parameter name -> shape) into ``SHAPES_FILE`` under the
    same directory.

    Args:
        - graph(Program|Graph): The graph to be saved.
        - dirname(str): The directory that the model saved into.
    """
    assert graph is not None and dirname is not None
    graph = GraphWrapper(graph) if isinstance(graph, Program) else graph
    exe = fluid.Executor(fluid.CPUPlace())
    fluid.io.save_params(
        executor=exe,
        dirname=dirname,
        main_program=graph.program,
        filename=PARAMS_FILE)
    weights_file = os.path.join(dirname, PARAMS_FILE)
    _logger.info("Save model weights into {}".format(weights_file))

    shapes = {}
    for var in graph.all_parameters():
        shapes[var.name()] = var.shape()
    # Use a distinct local name: assigning to SHAPES_FILE here would make it a
    # function-local variable and raise UnboundLocalError when it is read.
    shapes_file = os.path.join(dirname, SHAPES_FILE)
    with open(shapes_file, "w") as f:
        json.dump(shapes, f)
    _logger.info("Save shapes of weights into {}".format(shapes_file))
|
||
|
||
def load_model(graph, dirname):
    """
    Load weights of model and information of shapes from filesystem.

    First reads the JSON file ``SHAPES_FILE`` under ``dirname`` and sets the
    recorded shape on each named variable of the graph, then loads the
    weights from ``PARAMS_FILE`` with ``fluid.io.load_params``, and finally
    calls ``update_groups_of_conv`` and ``infer_shape`` on the graph.

    Args:
        - graph(Program|Graph): The graph that the weights are loaded into.
        - dirname(str): The directory that the model was saved into.
    """
    assert graph is not None and dirname is not None
    graph = GraphWrapper(graph) if isinstance(graph, Program) else graph

    # Use a distinct local name: assigning to SHAPES_FILE here would make it a
    # function-local variable and raise UnboundLocalError when it is read.
    shapes_file = os.path.join(dirname, SHAPES_FILE)
    _logger.info("Load shapes of weights from {}".format(shapes_file))
    with open(shapes_file, "r") as f:
        shapes = json.load(f)
    # Shapes must be applied before the weights are loaded.
    for param, shape in shapes.items():
        graph.var(param).set_shape(shape)

    exe = fluid.Executor(fluid.CPUPlace())
    fluid.io.load_params(
        executor=exe,
        dirname=dirname,
        main_program=graph.program,
        filename=PARAMS_FILE)
    graph.update_groups_of_conv()
    graph.infer_shape()
    _logger.info("Load weights from {}".format(
        os.path.join(dirname, PARAMS_FILE)))