
Improve pruning module #2354

Merged: NHZlX merged 28 commits into PaddlePaddle:develop from improve_pruning on Jun 23, 2017

Conversation

@NHZlX NHZlX (Contributor) commented Jun 2, 2017

Resolves #2284.
Adds a DynamicPruningHook in ParameterUpdaterHook.cpp.
Improves the Python v2 API.

@@ -131,6 +134,73 @@ class StaticPruningHook : public IParameterUpdaterHook {
std::vector<bool> mask_;
};

class DynamicPruningHook : public IParameterUpdaterHook {
Contributor Author:

Adds the DynamicPruningHook, which calculates the mask according to sparsity_ratio.


/**
* The factory method for ParameterUpdaterHook.
*/
static IParameterUpdaterHook* createImpl(
const ParameterUpdaterHookConfig& config) {
auto& type = config.type();
if (type == "pruning") {
if (config.has_purning_mask_filename()) {
if (type == "pruning_static") {
Contributor Author:

The plain 'pruning' type is the dynamic one, and 'pruning_static' is the static one, which reads the mask from a file.

@NHZlX NHZlX requested review from reyoung, hedaoyuan and Xreki June 2, 2017 06:45
@Xreki Xreki (Contributor) left a comment:

The differences between the new DynamicPruningHook and the original StaticPruningHook are:

  • DynamicPruningHook generates a mask from the sparsity_ratio in the config
  • StaticPruningHook reads the mask in from a file

The two operate on the Parameter in exactly the same way; only the way the mask is initialized differs. I don't think this implementation is dynamic pruning. It is still static pruning, and it could replace the original static pruning implementation.

Dynamic pruning is a method like the one in the paper Dynamic Network Surgery for Efficient DNNs: it does not require a sparsity_ratio to be specified, but sets and adjusts the sparsity ratio automatically until the maximum compression rate is reached.
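For reference, a heavily simplified sketch of that paper's mask update; the two thresholds are hypothetical parameters, and this is an illustration of the idea rather than Paddle code:

#include <cmath>
#include <cstddef>
#include <vector>

// Dynamic Network Surgery keeps updating the mask during training: weights
// whose magnitude falls below pruneThresh are pruned, while pruned weights
// whose magnitude grows past spliceThresh are spliced back in.
void updateMask(const std::vector<float>& weights, std::vector<float>& mask,
                float pruneThresh, float spliceThresh) {
  for (std::size_t i = 0; i < weights.size(); ++i) {
    float magnitude = std::fabs(weights[i]);
    if (magnitude < pruneThresh) {
      mask[i] = 0.0f;  // prune weights that became small
    } else if (magnitude > spliceThresh) {
      mask[i] = 1.0f;  // recover weights that grew back
    }
    // magnitudes in [pruneThresh, spliceThresh] keep their previous mask entry
  }
}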

@@ -25,6 +25,9 @@ limitations under the License. */
#include "paddle/utils/Flags.h"
#include "paddle/utils/Util.h"

using std::vector;
using std::pair;
Contributor:

Do not use using statements.

@NHZlX NHZlX (Contributor Author) commented Jun 2, 2017

@Xreki OK, then let's keep only the sparsity_ratio variant.

@Xreki Xreki (Contributor) commented Jun 2, 2017

@NHZlX

then let's keep only the sparsity_ratio variant

Keeping only that one is fine; the original StaticPruningHook was too cumbersome to use.

* define which link/weight between neural is disabled.
* Static means user specific a sparsity_ratio map before training started. The
* network will
* hold the sparsity_ratio maximum numbers of parameters, and cut off the rest.
Contributor:

Line 33: the word "map" is superfluous. Line 35: the sentence doesn't read smoothly.

SameThreadChecker updateThreadChecker_;
std::atomic<size_t> initCount_;
VectorPtr maskVec_;
std::vector<bool> mask_;
VectorPtr maskTemp_;
Contributor:

See whether maskTemp can avoid being a member variable.

}

LOG(FATAL) << "Unknown Hook type: " << type;
Contributor:

Just to confirm: when no hook is configured, this function is never called, right?

Contributor Author:

This is to guarantee that an error is reported if the specified hook type is not among those we provide. On the Python v2 side there is a hook-type check at ParameterHook in /python/paddle/trainer/config_parser.py, but in the future this may also be invoked through channels other than Python.
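For illustration, a minimal self-contained sketch of this fail-loudly dispatch (hypothetical and much simplified; the real factory lives in ParameterUpdaterHook.cpp):

#include <string>
#include <glog/logging.h>  // LOG(FATAL), as wrapped by paddle/utils/Logging.h

// Simplified sketch: a known hook type constructs the hook, while an unknown
// type aborts via LOG(FATAL), so a misconfigured type still fails loudly even
// if the Python-side check in config_parser.py is bypassed.
class IParameterUpdaterHook {
public:
  virtual ~IParameterUpdaterHook() {}
};

class StaticPruningHook : public IParameterUpdaterHook {
public:
  explicit StaticPruningHook(double sparsityRatio)
      : sparsityRatio_(sparsityRatio) {}

private:
  double sparsityRatio_;
};

static IParameterUpdaterHook* createImpl(const std::string& type,
                                         double sparsityRatio) {
  if (type == "pruning") {
    return new StaticPruningHook(sparsityRatio);
  }
  LOG(FATAL) << "Unknown Hook type: " << type;
  return nullptr;  // unreachable: LOG(FATAL) aborts the process
}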

@@ -26,7 +26,8 @@ enum ParameterInitStrategy {

message ParameterUpdaterHookConfig {
required string type = 1;
optional string purning_mask_filename = 2;
//hook type such as 'pruning'
optional double sparsity_ratio = 3;
Contributor:

Since the purning_mask_filename field has been removed, sparsity_ratio should be given field number 2. Also, if sparsity_ratio is optional, it should be given a default value.

@@ -26,7 +26,8 @@ enum ParameterInitStrategy {

message ParameterUpdaterHookConfig {
required string type = 1;
optional string purning_mask_filename = 2;
//hook type such as 'pruning'
Contributor:

Comments usually go before the field they describe, and there should be a space after //.

Hook Attribute object. The hook is an auxiliary operation that occurs
during network propagation. Such as pruning operation, It will cut off
redundant parameters in the network before training. More detail can see
here paddle/parameter/ParameterUpdaterHook.cpp
Contributor:

This could be changed to cite the paper instead. Also, check the grammar of the last sentence.


:param sparsity_ratio: Must be specified if hook type is 'pruning',
the network will hold the sparsity_ratio maximum parameters, and cut off the rest.
:type sparsity_ratio: float number between 0 and 1
Contributor:

:type xxx: should be followed by the type only; the range constraint should go after :param xxx:.


for (size_t i = 0; i < para->getSize(); i++)
param.push_back(std::make_pair(fabs(vecCpu->getData()[i]), i));
std::sort(param.begin(), param.end(), sortPairAscend);
Contributor:

std::partial_sort could be used here.

Contributor Author:

@hedaoyuan Yes, using that sort here would be better; I'll change it.
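As a minimal sketch of the mask generation with std::partial_sort (a standalone illustration over a plain std::vector, not the PR's Parameter-based code):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <utility>
#include <vector>

// Build a 0/1 mask that keeps the (1 - sparsityRatio) fraction of weights
// with the largest absolute values and prunes the rest. Only the kept prefix
// needs to be ordered, so std::partial_sort is cheaper than a full std::sort
// over all (|weight|, index) pairs.
std::vector<float> generateMask(const std::vector<float>& weights,
                                double sparsityRatio) {
  std::vector<std::pair<float, std::size_t>> param;
  param.reserve(weights.size());
  for (std::size_t i = 0; i < weights.size(); ++i) {
    param.push_back(std::make_pair(std::fabs(weights[i]), i));
  }
  std::size_t nonZeroNum =
      static_cast<std::size_t>(weights.size() * (1.0 - sparsityRatio));
  std::partial_sort(param.begin(), param.begin() + nonZeroNum, param.end(),
                    [](const std::pair<float, std::size_t>& a,
                       const std::pair<float, std::size_t>& b) {
                      return a.first > b.first;  // largest magnitudes first
                    });
  std::vector<float> mask(weights.size(), 0.0f);
  for (std::size_t i = 0; i < nonZeroNum; ++i) {
    mask[param[i].second] = 1.0f;  // keep this weight
  }
  return mask;
}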

dataPtr[i++] = m ? 1.0 : 0.0;
}
}

// Currently just use a mask vector for hack.
// @TODO(yuyang18): Implemented the mask operation in vector.
if (para->useGpu()) {
Contributor:

Couldn't the code at lines 86-91 be moved into generateMask?

Contributor Author:

@hedaoyuan That code has already been deleted.

Contributor:

Oh, I mean lines 86-91 of the file after the changes, the if (para->useGpu()) logic.

Contributor Author:

OK, it can be moved in there.

if (config.has_purning_mask_filename()) {
return new StaticPruningHook(config.purning_mask_filename());
}
if (config.has_sparsity_ratio())
Contributor:

Move the config.has_sparsity_ratio() check into StaticPruningHook's constructor. Also, I see the Python side has a default value; why report an error here instead of falling back to the default?

Contributor:

StaticPruningHook only does the assignment, without any check. What I mean is that the logic behind "There must be sparsity_ratio parameter for pruning Hook" itself belongs to StaticPruningHook.

Contributor Author:

The Python side will be changed: if sparsity_ratio is not specified, the default value is used automatically, and this check on the C++ side will be removed.

@@ -60,17 +61,28 @@ class StaticPruningHook : public IParameterUpdaterHook {
maskTemp_ = Vector::create(para->getSize(), false);
maskTemp_->zeroMem();
real* dataPtr = maskTemp_->getData();
size_t sparsityNum = para->getSize() * (1 - sparsityRatio_);
Contributor:

Does sparsityRatio_ refer to the ratio of non-zero elements or of zero elements? Looking at line 72 on the left, the original definition appears to be the ratio of non-zero elements; why was it changed here?

for (size_t i = 0; i < para->getSize() * sparsityRatio_; i++)

Contributor Author:

It is the ratio of zero elements. The mask is initialized to all zeros, and what this does is set the mask of the non-zero elements to 1, hence (1 - sparsityRatio_). The name sparsityNum is indeed not great; I'll change it.
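A tiny self-contained illustration of that convention, with hypothetical numbers chosen so the arithmetic is exact:

#include <cstddef>
#include <iostream>

int main() {
  // sparsity_ratio is the fraction of *zero* elements, so the number of
  // entries kept non-zero is size * (1 - sparsity_ratio).
  const std::size_t paraSize = 1000;  // hypothetical parameter size
  const double sparsityRatio = 0.75;  // 75% of entries pruned to zero
  const std::size_t nonZeroNum =
      static_cast<std::size_t>(paraSize * (1 - sparsityRatio));
  std::cout << nonZeroNum << std::endl;  // prints 250
  return 0;
}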

@hedaoyuan hedaoyuan (Contributor) left a comment:

required string type = 1;
optional string purning_mask_filename = 2;
optional double sparsity_ratio = 2 [default = 0.8];
Contributor:

The comment here needs to make clear whether sparsity_ratio counts the non-zero or the zero elements. For instance, does default=0.8 mean 80% zero elements? My first reaction on seeing sparsity_ratio was that it meant the proportion of non-zero elements.

Contributor Author:

@hedaoyuan OK, will do.

* Static means user load a mask map before training started. This map will
* define which link/weight between neural is disabled.
* Static means user specific a sparsity_ratio before training start, and the
* network will prune the parameters based on the sparsity_ratio. More deatils
Contributor:

typo: specific -> specify, start -> started
More deatils can see -> More details can be found

SetDevice device(para->getDeviceId());
void generateMask(Parameter* para) {
VectorPtr vec = para->getBuf(PARAMETER_VALUE);
maskTemp_ = Vector::create(para->getSize(), false);
Contributor:

Change maskTemp_ into a local variable.


std::partial_sort(
param.begin(), param.begin() + nonZeroNum, param.end(), sortPairAscend);
for (size_t i = 0; i < nonZeroNum; i++) dataPtr[param[i].second] = 1.0;
Contributor:

See whether some variable names can be improved. With dataPtr, for example, I had to go up to line 63 to find out what the variable refers to; names that convey their meaning would be better. The same goes for vec and vecCpu.

during network propagation.
NOTE: IT IS A HIGH LEVEL USER INTERFACE.

:param type: Hook type, eg: 'pruning'
Contributor:

These docstrings are used to generate the API documentation, so it's best to write them in more detail: for example, all the supported type values and the papers the work is based on.

"""
Hook Attribute object. The hook is an auxiliary operation that occurs
during network propagation.
NOTE: IT IS A HIGH LEVEL USER INTERFACE.
Contributor:

It would be better to explain what a Hook is used for and what it acts on. Also, I don't think the NOTE is necessary.

assert is_compatible_with(
self.sparsity_ratio,
float), 'sparisity_ratio must be float type'
assert self.sparsity_ratio <= 1 and self.sparsity_ratio >= 0, 'sparisity must be a flaot between [0, 1] '
Contributor:

'sparisity must be a flaot between [0, 1] ' -> 'sparsity_ratio must be a float between [0, 1]': keep the error message consistent with the variable name, and there are typos as well.

Contributor Author:

ok

@Xreki Xreki (Contributor) commented Jun 16, 2017

Git commit messages should reflect what was actually changed; from "Update ParameterUpdaterHook.cpp" you can't tell what was modified.

@NHZlX NHZlX removed the request for review from reyoung June 17, 2017 08:03
@NHZlX NHZlX merged commit 8b86624 into PaddlePaddle:develop Jun 23, 2017
@NHZlX NHZlX deleted the improve_pruning branch June 23, 2017 05:57
Successfully merging this pull request may close these issues:

Improve the pruning module of paddle

3 participants