-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Asp movement #36525
Asp movement #36525
Changes from 7 commits
da1ed48
f6dac96
806d642
0e071cb
22336af
8dca054
5ee159f
550e926
b24c964
cd95bd5
66081ce
175eba7
1d3fc96
34c56f4
6ae246c
c39c5dc
225ce60
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -66,7 +66,7 @@ def decorate(optimizer): | |
|
||
import paddle | ||
import paddle.fluid as fluid | ||
from paddle.fluid.contrib import sparsity | ||
from paddle.static import sparsity | ||
|
||
main_program = fluid.Program() | ||
startup_program = fluid.Program() | ||
|
@@ -128,17 +128,13 @@ def prune_model(place, | |
import paddle | ||
import paddle.fluid as fluid | ||
import paddle.fluid.core as core | ||
from paddle.fluid.contrib import sparsity | ||
from paddle.static import sparsity | ||
|
||
paddle.enable_static() | ||
|
||
main_program = fluid.Program() | ||
startup_program = fluid.Program() | ||
|
||
place = paddle.CPUPlace() | ||
if core.is_compiled_with_cuda(): | ||
place = paddle.CUDAPlace(0) | ||
|
||
with fluid.program_guard(main_program, startup_program): | ||
input_data = fluid.layers.data(name='data', shape=[None, 128]) | ||
label = fluid.layers.data(name='label', shape=[None, 10]) | ||
|
@@ -158,6 +154,8 @@ def prune_model(place, | |
optimizer = sparsity.decorate(optimizer) | ||
optimizer.minimize(loss, startup_program) | ||
|
||
place = paddle.CPUPlace() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 为什么要设置 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 修改成使用paddle.get_device來判斷Place的裝置 |
||
|
||
exe = fluid.Executor(place) | ||
exe.run(startup_program) | ||
|
||
|
@@ -348,7 +346,7 @@ def _is_supported_layer(cls, main_program, param_name): | |
.. code-block:: python | ||
|
||
import paddle.fluid as fluid | ||
from paddle.fluid.contrib.sparsity.asp import ASPHelper | ||
from paddle.static.sparsity.asp import ASPHelper | ||
|
||
main_program = fluid.Program() | ||
startup_program = fluid.Program() | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
# Copyright (c) 2021 NVIDIA Corporation. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from ...fluid.contrib.sparsity import calculate_density #noqa: F401 | ||
from ...fluid.contrib.sparsity import check_mask_1d #noqa: F401 | ||
from ...fluid.contrib.sparsity import get_mask_1d #noqa: F401 | ||
from ...fluid.contrib.sparsity import check_mask_2d #noqa: F401 | ||
from ...fluid.contrib.sparsity import get_mask_2d_greedy #noqa: F401 | ||
from ...fluid.contrib.sparsity import get_mask_2d_best #noqa: F401 | ||
from ...fluid.contrib.sparsity import create_mask #noqa: F401 | ||
from ...fluid.contrib.sparsity import check_sparsity #noqa: F401 | ||
from ...fluid.contrib.sparsity import MaskAlgo #noqa: F401 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. MaskAlgo这个类型可以作为内部使用的数据类型使用,不建议对外公开。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 採用建議使用小寫string作為輸入選項,並隱藏 MaskAlgo |
||
from ...fluid.contrib.sparsity import CheckMethod #noqa: F401 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 再确认一下,以上API都需要向用户暴露吗? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 一定需要暴露給用戶的API僅為MaskAlgo,目的是要讓用戶能夠決定Pruning的方式 (1D, 2D_greedy, 2D_best)等。 |
||
from ...fluid.contrib.sparsity import decorate #noqa: F401 | ||
from ...fluid.contrib.sparsity import prune_model #noqa: F401 | ||
from ...fluid.contrib.sparsity import set_excluded_layers #noqa: F401 | ||
from ...fluid.contrib.sparsity import reset_excluded_layers #noqa: F401 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 将需要对外公开的API 放到__all__列表,paddle会根据__all__列表区别公开API和内部使用API, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
需要对用户暴露的接口的示例代码中,不要再
import fluid
,也不要用任何fluid
的API。There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的 已經全面修正