prepare_gradient_aggregation for non-leaf output of PartialProgramLayer #44893

Merged
Changes from 1 commit
59 changes: 59 additions & 0 deletions python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py
@@ -287,6 +287,63 @@ def _verify_program(self, main_program):

        return main_program

    def prepare_gradient_aggregation(self, main_program, target_program):
        # Why do we need to add a reverse gradient aggregation operation?
Contributor:
It would be better to write the function comment as a docstring, i.e. in the
"""
xxxx
"""
format.

        # In some cases, if a non-leaf node is used as an output, its gradient will be
        # overwritten rather than accumulated, e.g.:
        # def forward(self, in):
        #     x = 2 * in   # <---- x is a non-leaf node in the program.
        #     y = x + 3
        #     return x, y
        #
        # loss = forward(in)[0].sum()
        # loss.backward()  # <----- x@grad will be overwritten by the elementwise_add_grad op
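
For reference, a small eager-mode snippet (illustrative only, not part of this PR) showing the accumulation behaviour the static program has to match:

import paddle

x_in = paddle.to_tensor([1.0, 1.0], stop_gradient=False)
x = 2 * x_in      # x is an intermediate (non-leaf) tensor
y = x + 3         # x is also consumed by another op
loss = x.sum()    # only x participates in the loss
loss.backward()
print(x_in.grad)  # eager mode correctly yields [2., 2.]; dy2static must produce the same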
        def _need_aggregation(var):
            """
            Return True if there exists an op that takes `var` as an input.
            """
            for op in main_program.block(0).ops:
                for in_arg in op.input_arg_names:
                    if in_arg == var.name:
                        return True
            return False
Contributor:
return var.name in op.input_arg_names

Contributor Author:
That doesn't seem to be logically equivalent.

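For what it's worth, a minimal sketch (not part of the PR) of how the reviewer's one-liner could be folded in while keeping the original semantics, i.e. returning early only on a match and still scanning the remaining ops otherwise:

        def _need_aggregation(var):
            """
            Return True if there exists an op that takes `var` as an input.
            """
            for op in main_program.block(0).ops:
                # early-return as soon as some op consumes `var`
                if var.name in op.input_arg_names:
                    return True
            return False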

        def _insert_aggregation_ops_for_var(target_program, var):
            var_grad_name = var.name + "@GRAD"
Contributor:
Better not to hard-code + "@Grad" here; the framework has a unified API for the grad suffix.

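As a rough sketch of what the reviewer's suggestion could look like (assuming paddle.fluid.core.grad_var_suffix() is the unified grad-suffix API meant here; that is an assumption, not something stated in the thread):

            from paddle.fluid import core

            # assumed API: core.grad_var_suffix() returns the framework-wide grad suffix ("@GRAD")
            var_grad_name = var.name + core.grad_var_suffix()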
            finded_ops = list(
                filter(
                    lambda x: any([
                        out_arg == var_grad_name
                        for out_arg in x[1].output_arg_names
                    ]), enumerate(target_program.block(0).ops)))
Contributor:
Why use enumerate here?

Contributor Author:
The value obtained from enumerate is the insertion idx, which is needed later when inserting the op.

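A tiny standalone illustration of that point (the op names below are made up, not taken from the program):

ops = ["fill_constant", "elementwise_add_grad", "scale"]  # hypothetical op type list
matches = [(idx, op) for idx, op in enumerate(ops) if op.endswith("_grad")]
# matches == [(1, 'elementwise_add_grad')]; a new aggregation op would be inserted at index 1 + 1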

            # len(finded_ops) may equal zero when stop_gradient takes effect.
            # len(finded_ops) may be > 1, because there may also be a fill_constant op.
            if len(finded_ops) == 0:
                return None
            suffix = "@dy2static"
Contributor (@Aurelius84, Aug 4, 2022):
Wouldn't it be better to use var_name + _dy2static + grad_suffix here?

            # step1: create a new var named var.name@GRAD@dy2static
            target_program.block(0).create_var(name=var_grad_name + suffix,
                                               type=var.type,
                                               dtype=var.dtype,
                                               shape=var.shape)
            # step2: rename var.name@GRAD to var.name@GRAD@dy2static in the found grad ops
            for idx, op in finded_ops:
                op._rename_input(var_grad_name, var_grad_name + suffix)
                op._rename_output(var_grad_name, var_grad_name + suffix)
            # step3: insert a sum op to aggregate the gradient:
            #        var.name@GRAD = sum(var.name@GRAD@dy2static, var.name@GRAD)
            target_program.block(0)._insert_op(
                finded_ops[-1][0] + 1,
                type='sum',
                inputs={'X': [var_grad_name, var_grad_name + suffix]},
                outputs={"Out": var_grad_name})
            return None

        to_processed_vars = list(
            filter(_need_aggregation, self._outputs.tolist()))
        for _var in to_processed_vars:
            _insert_aggregation_ops_for_var(target_program, _var)

    @switch_to_static_graph
    def _append_backward_desc(self, main_program):
        # make sure the is_test attribute of all ops is False in train mode.
@@ -299,6 +356,8 @@ def _append_backward_desc(self, main_program):
        if targets and self._params:
            backward.gradients(targets=targets, inputs=[])

        self.prepare_gradient_aggregation(main_program, program)

        return program

    def _prune_unused_params(self, program):
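
For anyone verifying the change locally, a hypothetical debugging helper (not part of the PR) that checks whether the appended backward program now contains the aggregating sum op for a given output:

def has_grad_sum_op(program, var_name, grad_suffix="@GRAD"):
    # returns True if some `sum` op in block 0 writes var_name@GRAD
    grad_name = var_name + grad_suffix
    for op in program.block(0).ops:
        if op.type == "sum" and grad_name in op.output_arg_names:
            return True
    return False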
@@ -0,0 +1,60 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest

import paddle
import numpy as np

SEED = 2020
np.random.seed(SEED)


class SimpleNet(paddle.nn.Layer):

    def __init__(self):
        super(SimpleNet, self).__init__()
        self.linear1 = paddle.nn.Linear(10, 3)
        self.linear2 = paddle.nn.Linear(3, 1)

    def forward(self, x):
        out1 = self.linear1(x)
        out2 = self.linear2(out1)
        return [out1, out2]    # gradient is 0 without this fix
        # return [out1]        # gradient is correct
        # return [out2, out1]  # gradient is correct


class TestGradientAggregationInDy2Static(unittest.TestCase):

    def test_to_static(self):

        def simplenet_grad(inp, to_static=False):
            net = SimpleNet()
            if to_static: net = paddle.jit.to_static(net)
            loss = net(inp)
            loss[0].backward()
            return net.linear1.weight.grad

        inp = paddle.to_tensor(np.random.randn(10, )).astype("float32")
        self.assertTrue(
            np.allclose(
                simplenet_grad(inp, True).numpy(),
                simplenet_grad(inp, False).numpy()))


if __name__ == '__main__':
unittest.main()