add cumprod_grad composite #64432
Changes from 2 commits
@@ -29,6 +29,7 @@
- maximum
- minimum
- prod
- cumprod
- roll
- scatter
- scatter_nd_add
@@ -1071,6 +1071,31 @@ void gather_nd_grad(const Tensor& x,
  }
}

template <typename T>
void cumprod_grad(const Tensor& x,
                  const Tensor& out,
                  const Tensor& out_grad,
                  int dim,
                  bool exclusive,
                  bool reverse,
                  Tensor* x_grad) {
  if (x_grad) {
    // dx = cumsum(out * out_grad, dim, false, exclusive, !reverse) / x
    std::vector<int64_t> x_dim = common::vectorize<int64_t>(x.dims());
    auto zero_tensor = full<T>(x_dim, 0.0, x.dtype());
    auto zero_mask = cast<T>(equal<T>(x, zero_tensor), x.dtype());
    auto common_dx =
        cumsum<T>(out * out_grad, dim, false, exclusive, !reverse) / x;
    auto ones_tensor = full<T>(x_dim, 1.0, x.dtype());
    auto replace_one = (1 - zero_mask) * x + zero_mask * ones_tensor;

    auto cumprod_recompute = cumprod<T>(replace_one, dim, exclusive, reverse);
    auto zeros_dx = cumsum<T>(
        cumprod_recompute * out_grad, dim, false, exclusive, !reverse);
    auto x_grad_res = (1 - zero_mask) * common_dx + zero_mask * zeros_dx;
    set_output<T>(x_grad_res, x_grad);
  }
}

template <typename T>
void prod_grad(const Tensor& x,
               const Tensor& out,
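The comment at the top of cumprod_grad states the identity the composite relies on: for out = cumprod(x), the vector-Jacobian product is a reversed cumulative sum of out * out_grad divided elementwise by x. Below is a minimal NumPy sketch (illustration only, not part of the PR; the function names are ad hoc) that checks this identity against a brute-force Jacobian for a 1-D input with no zeros:

```python
# Illustrative sketch only: check dx = reverse_cumsum(out * out_grad) / x
# against a brute-force Jacobian for a 1-D input with no zero entries.
import numpy as np

def cumprod_vjp_bruteforce(x, out_grad):
    # dL/dx[k] = sum_{i >= k} out_grad[i] * d(out[i])/d(x[k]),
    # where d(out[i])/d(x[k]) = prod_{j <= i, j != k} x[j].
    n = x.size
    jac = np.zeros((n, n))
    for i in range(n):
        for k in range(i + 1):
            jac[i, k] = np.prod(np.delete(x[: i + 1], k))
    return out_grad @ jac

def cumprod_vjp_composite(x, out_grad):
    out = np.cumprod(x)
    # the composite's cumsum(..., !reverse) corresponds to a reversed cumsum here
    rev_cumsum = np.cumsum((out * out_grad)[::-1])[::-1]
    return rev_cumsum / x

x = np.array([0.5, 2.0, -3.0, 4.0])
g = np.array([1.0, 0.5, 2.0, -1.0])
assert np.allclose(cumprod_vjp_bruteforce(x, g), cumprod_vjp_composite(x, g))
```

The division by x is exactly what breaks when x contains zeros, which is what the zero_mask branch above addresses.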
@@ -59,6 +59,31 @@ void cumsum_grad(const Tensor& x,
  }
}

template <typename T>
void cumprod_grad(const Tensor& x,
                  const Tensor& out,
                  const Tensor& out_grad,
                  int dim,
                  bool exclusive,
                  bool reverse,
                  Tensor* x_grad) {
  if (x_grad) {
    // dx = cumsum(out * out_grad, dim, false, exclusive, !reverse) / x
    std::vector<int64_t> x_dim = common::vectorize<int64_t>(x.dims());
    auto zero_tensor = full<T>(x_dim, 0.0, x.dtype());
    auto zero_mask = cast<T>(equal<T>(x, zero_tensor), x.dtype());
    auto common_dx =
        cumsum<T>(out * out_grad, dim, false, exclusive, !reverse) / x;
    auto ones_tensor = full<T>(x_dim, 1.0, x.dtype());
    auto replace_one = (1 - zero_mask) * x + zero_mask * ones_tensor;

Reviewer: The ones_tensor here can probably be removed.
Suggested change
Reviewer: It helps to add comments between logically consecutive blocks of code; for example, note here that the positions where x is 0 are filled with 1.
Author: Done

    auto cumprod_recompute = cumprod<T>(replace_one, dim, exclusive, reverse);
    auto zeros_dx = cumsum<T>(
        cumprod_recompute * out_grad, dim, false, exclusive, !reverse);
    auto x_grad_res = (1 - zero_mask) * common_dx + zero_mask * zeros_dx;
    set_output<T>(x_grad_res, x_grad);
  }
}

template <typename T>
void divide_grad(const Tensor& x,
                 const Tensor& y,
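The review thread above concerns the zero-handling branch: common_dx divides by x and is therefore ill-defined where x is 0, so those positions instead take their gradient from a cumulative product recomputed with the zeros filled in as 1 (replace_one, cumprod_recompute, zeros_dx). A small NumPy illustration of that step, assuming a single zero along the scanned axis (the numbers are made up for the example):

```python
# Illustration only: why filling the zero with 1, recomputing cumprod and
# skipping the division gives the right gradient at the zero position.
import numpy as np

x = np.array([2.0, 0.0, 3.0, 4.0])  # single zero at index 1
g = np.ones_like(x)                 # upstream gradient (out_grad)

# Analytic value at the zero position:
# L = x0 + x0*x1 + x0*x1*x2 + x0*x1*x2*x3, so dL/dx1 = x0 + x0*x2 + x0*x2*x3 = 32.
zero_mask = (x == 0).astype(x.dtype)
replace_one = (1 - zero_mask) * x + zero_mask   # zero positions filled with 1
cumprod_recompute = np.cumprod(replace_one)
zeros_dx = np.cumsum((cumprod_recompute * g)[::-1])[::-1]  # reversed cumsum, no division

print(zeros_dx[1])  # 32.0, matching the analytic gradient at the zero position
```

In the composite above, zeros_dx is only consumed where zero_mask is 1; every other position keeps common_dx.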
@@ -0,0 +1,76 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import parameterized as param

import paddle
from paddle.base import core


@param.parameterized_class(
    ('primal', 'dtype'),
    [

Reviewer: Please add a single-element test case.
Author: Done

        (
            np.random.rand(2, 3, 4),
            np.float32,
        ),
        (
            np.random.rand(2, 3, 3, 4),
            np.float32,
        ),
    ],
)
class TestCumprodGradComp(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.primal = cls.primal.astype(cls.dtype)

    def test_cumprod_grad_comp(self):
        def actual(primal, dim):
            paddle.disable_static()
            core.set_prim_eager_enabled(True)
            x = paddle.to_tensor(primal, dtype='float32', stop_gradient=False)
            x.stop_gradient = False
            y = paddle.cumprod(x, dim=dim)
            x_cotangent = paddle.grad(
                y, x, create_graph=True, retain_graph=True
            )
            return x_cotangent[0]

        def desired(primal, dim):
            paddle.disable_static()
            core.set_prim_eager_enabled(False)
            x = paddle.to_tensor(primal, dtype='float32', stop_gradient=False)
            x.stop_gradient = False
            y = paddle.cumprod(x, dim=dim)
            x_cotangent = paddle.grad(
                y, x, create_graph=False, retain_graph=True
            )
            return x_cotangent[0]

        for i in range(len(self.primal.shape)):
            np.testing.assert_allclose(
                actual=actual(self.primal, i),
                desired=desired(self.primal, i),
                rtol=1e-6,
                atol=0,
            )
        core.set_prim_eager_enabled(False)


if __name__ == '__main__':
    unittest.main()

Reviewer: Please add more unit-test cases (same as the comment above).
Author: Done
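The reviewers asked for a single-element case here and a 0-D case in the static-graph test below; the cases actually added in the follow-up commits are not visible in this revision of the diff. As a rough, hypothetical sketch of what a single-element check could look like in eager mode, reusing only APIs already used by this test file:

```python
# Hypothetical sketch, not the committed test: compare a single-element input
# between the composite backward (prim enabled) and the original cumprod kernel.
import numpy as np
import paddle
from paddle.base import core

primal = np.random.rand(1).astype(np.float32)  # single-element input

core.set_prim_eager_enabled(True)
x = paddle.to_tensor(primal, stop_gradient=False)
(dx_prim,) = paddle.grad(paddle.cumprod(x, dim=0), x)

core.set_prim_eager_enabled(False)
x = paddle.to_tensor(primal, stop_gradient=False)
(dx_ref,) = paddle.grad(paddle.cumprod(x, dim=0), x)

np.testing.assert_allclose(dx_prim.numpy(), dx_ref.numpy(), rtol=1e-6, atol=0)
```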
@@ -0,0 +1,123 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from paddle.base import core

core._set_prim_backward_enabled(True)

import numpy as np
import parameterized as param

import paddle


def apply_to_static(net, use_cinn):
    build_strategy = paddle.static.BuildStrategy()
    build_strategy.build_cinn_pass = use_cinn
    return paddle.jit.to_static(
        net, build_strategy=build_strategy, full_graph=True
    )


class PrimeNet(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.fc = paddle.nn.Linear(4, 4)

    def forward(self, x):
        tmp = self.fc(x)
        out = paddle.cumprod(tmp, -1)
        return out


@param.parameterized_class(

Reviewer: Same as above, please also add a single-element 0-D case.
Author: Done

    ('primal', 'cotangent', 'dtype'),
    [
        (np.random.rand(10, 10), np.random.rand(10, 10), np.float32),
    ],
)
class TestCumprodGradComp(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.primal = cls.primal.astype(cls.dtype)
        cls.cotangent = cls.cotangent.astype(cls.dtype)

    def train(self, use_prim, use_cinn):
        paddle.seed(2022)
        self.x = paddle.randn([2, 4])
        self.x.stop_gradient = False
        net = PrimeNet()
        core._set_prim_backward_enabled(use_prim)
        net = apply_to_static(net, use_cinn)
        out = net(self.x)
        res = paddle.autograd.grad(out, [self.x])

        return res

    def test_tanh_grad_comp(self):
        paddle.enable_static()

        def actual(primal, cotangent, dim):
            core._set_prim_backward_enabled(True)
            mp, sp = paddle.static.Program(), paddle.static.Program()
            with paddle.static.program_guard(mp, sp):
                x = paddle.static.data('primal', primal.shape, primal.dtype)
                x.stop_gradient = False
                v = paddle.static.data(
                    'cotangent', cotangent.shape, cotangent.dtype
                )
                y = paddle.cumprod(x, dim)
                x_cotangent = paddle.static.gradients(y, x, v)
            exe = paddle.static.Executor()
            exe.run(sp)
            return exe.run(
                program=mp,
                feed={'primal': primal, 'cotangent': cotangent},
                fetch_list=[x_cotangent[0]],
            )[0]

        def desired(primal, cotangent, dim):
            core._set_prim_backward_enabled(False)
            mp, sp = paddle.static.Program(), paddle.static.Program()
            with paddle.static.program_guard(mp, sp):
                x = paddle.static.data('primal', primal.shape, primal.dtype)
                x.stop_gradient = False
                v = paddle.static.data(
                    'cotangent', cotangent.shape, cotangent.dtype
                )
                y = paddle.cumprod(x, dim)
                x_cotangent = paddle.static.gradients(y, x, v)
            exe = paddle.static.Executor()
            exe.run(sp)
            return exe.run(
                program=mp,
                feed={'primal': primal, 'cotangent': cotangent},
                fetch_list=[x_cotangent[0]],
            )[0]

        for i in range(len(self.primal.shape)):
            np.testing.assert_allclose(
                actual=actual(self.primal, self.cotangent, i),
                desired=desired(self.primal, self.cotangent, i),
                rtol=1e-6,
                atol=0,
            )
        core._set_prim_backward_enabled(False)
        paddle.disable_static()


if __name__ == '__main__':
    unittest.main()
Reviewer: The implementation here seems to differ a little from the one in details.h; could you double-check?
Author: Done