add cumprod_grad composite #64432
@@ -1084,14 +1084,24 @@ void cumprod_grad(const Tensor& x,
   std::vector<int64_t> x_dim = common::vectorize<int64_t>(x.dims());
   auto zero_tensor = full<T>(x_dim, 0.0, x.dtype());
   auto zero_mask = cast<T>(equal<T>(x, zero_tensor), x.dtype());
-  auto common_dx =
-      cumsum<T>(out * out_grad, dim, false, exclusive, !reverse) / x;
+  auto zero_mask_cumsum1 = cumsum<T>(zero_mask, dim, false, false, reverse);
+  auto zero_mask_cumsum2 = cumsum<T>(zero_mask, dim, false, true, reverse);
+  auto zero_mask_cumsum =
+      zero_mask_cumsum1 +
+      zero_mask_cumsum2;  // determine the index of first zero
   auto ones_tensor = full<T>(x_dim, 1.0, x.dtype());
-  auto replace_one = (1 - zero_mask) * x + zero_mask * ones_tensor;
-  auto cumprod_recompute = cumprod<T>(replace_one, dim, exclusive, reverse);
+  auto first_zero_mask =
+      cast<T>(equal<T>(zero_mask_cumsum, ones_tensor), x.dtype());
+  auto common_dx = cumsum<T>(out * out_grad, dim, false, exclusive, !reverse);
+  auto replace_one = (1 - zero_mask) * x + zero_mask;
+  auto replace_first_one = (1 - first_zero_mask) * x + first_zero_mask;
+  auto cumprod_recompute =
+      cumprod<T>(replace_first_one, dim, exclusive, reverse);
   auto zeros_dx = cumsum<T>(
       cumprod_recompute * out_grad, dim, false, exclusive, !reverse);
-  auto x_grad_res = (1 - zero_mask) * common_dx + zero_mask * zeros_dx;
+  auto x_grad_res =
+      ((1 - first_zero_mask) * common_dx + first_zero_mask * zeros_dx) /
+      replace_one;
   set_output<T>(x_grad_res, x_grad);
 }
 }

Review comment (on zero_mask_cumsum1 / zero_mask_cumsum2): Try to name variables so they accurately reflect their semantics.
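For intuition, here is a minimal NumPy sketch of the same first-zero rule, specialized to the 1-D, non-exclusive, non-reversed case. The function cumprod_grad_ref and all intermediate names are illustrative restatements of the diff above, not part of the PR:

```python
import numpy as np

def cumprod_grad_ref(x, out_grad):
    # Reference VJP of y = np.cumprod(x) for a 1-D float array.
    out = np.cumprod(x)
    zero_mask = (x == 0).astype(x.dtype)
    # The sum of the inclusive and exclusive running zero counts equals
    # exactly 1 only at the first zero along the scan direction.
    incl = np.cumsum(zero_mask)
    excl = incl - zero_mask
    first_zero_mask = (incl + excl == 1).astype(x.dtype)
    # Gradient away from zeros: suffix_sum(y * g) / x_i.
    common_dx = np.cumsum((out * out_grad)[::-1])[::-1]
    # Replace the first zero with 1 and recompute the cumprod, so the
    # suffix sum recovers prod_{k<=j, k!=i} x_k at the first-zero index i.
    replace_first_one = np.where(first_zero_mask == 1, 1.0, x)
    cumprod_recompute = np.cumprod(replace_first_one)
    zeros_dx = np.cumsum((cumprod_recompute * out_grad)[::-1])[::-1]
    # Divide by x with all zeros replaced by 1: the division stays safe,
    # and non-first zeros correctly receive gradient 0.
    replace_one = np.where(zero_mask == 1, 1.0, x)
    return ((1 - first_zero_mask) * common_dx
            + first_zero_mask * zeros_dx) / replace_one
```

For example, cumprod_grad_ref(np.array([2.0, 0.0, 3.0]), np.ones(3)) returns [1.0, 8.0, 0.0]: the count sum incl + excl is [0, 1, 2], so only index 1 is flagged as the first zero, and its gradient 8 = g1*x0 + g2*x0*x2 is recovered from the recomputed cumprod.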
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import random
 import unittest

 import numpy as np

@@ -24,6 +25,16 @@
 @param.parameterized_class(
     ('primal', 'dtype'),
     [
+        (
+            np.random.rand(
+                100,
+            ),
+            np.float32,
+        ),
+        (
+            np.random.rand(10, 10),
+            np.float32,
+        ),
         (
             np.random.rand(2, 3, 4),
             np.float32,

Review comment: please also add a single-element test case.
Reply: Done

@@ -32,12 +43,21 @@
             np.random.rand(2, 3, 3, 4),
             np.float32,
         ),
+        (
+            np.random.rand(2, 3, 3, 4, 5),
+            np.float32,
+        ),
+        (
+            np.random.randint(1, 100, (2, 3, 4)),
+            np.int64,
+        ),
     ],
 )
 class TestCumprodGradComp(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
         cls.primal = cls.primal.astype(cls.dtype)
+        cls.zero_nums = [0, 1, 10, int(np.prod(cls.primal.shape))]

     def test_cumprod_grad_comp(self):
         def actual(primal, dim):
@@ -62,13 +82,20 @@ def desired(primal, dim):
             )
             return x_cotangent[0]

-        for i in range(len(self.primal.shape)):
-            np.testing.assert_allclose(
-                actual=actual(self.primal, i),
-                desired=desired(self.primal, i),
-                rtol=1e-6,
-                atol=0,
-            )
+        for zero_num in self.zero_nums:
+            shape = self.primal.shape
+            x = self.primal.flatten()
+            indices = random.sample(range(x.size), zero_num)
+            for i in indices:
+                x[i] = 0
+            x = np.reshape(x, shape)
+            for i in range(len(self.primal.shape)):
+                np.testing.assert_allclose(
+                    actual=actual(x, i),
+                    desired=desired(x, i),
+                    rtol=1e-6,
+                    atol=0,
+                )
         core.set_prim_eager_enabled(False)

Review comment: please add more unit-test cases, same as the comment above.
Reply: Done
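The zero-injection loop that both test files share can be read as a standalone helper. This restatement (inject_zeros and its seed argument are hypothetical, not in the PR) makes the flatten / sample / reshape pattern explicit and reproducible:

```python
import random

import numpy as np

def inject_zeros(arr, zero_num, seed=None):
    # Copy arr and zero out zero_num randomly chosen entries,
    # mirroring the flatten / sample / reshape loop in the tests above.
    rng = random.Random(seed)
    flat = arr.flatten()  # flatten() always returns a copy
    for i in rng.sample(range(flat.size), zero_num):
        flat[i] = 0
    return flat.reshape(arr.shape)

# e.g. x = inject_zeros(np.random.rand(2, 3, 4), zero_num=10, seed=2022)
```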
@@ -18,6 +18,8 @@

 core._set_prim_backward_enabled(True)

+import random
+
 import numpy as np
 import parameterized as param

@@ -46,14 +48,32 @@ def forward(self, x):
 @param.parameterized_class(
     ('primal', 'cotangent', 'dtype'),
     [
+        (
+            np.random.rand(
+                100,
+            ),
+            np.random.rand(
+                100,
+            ),
+            np.float32,
+        ),
+        (np.random.rand(10, 10), np.random.rand(10, 10), np.float32),
         (np.random.rand(3, 4, 5), np.random.rand(3, 4, 5), np.float32),
         (np.random.rand(2, 3, 4, 5), np.random.rand(2, 3, 4, 5), np.float32),
+        (
+            np.random.rand(2, 3, 2, 4, 5),
+            np.random.rand(2, 3, 2, 4, 5),
+            np.float32,
+        ),
+        (np.random.randint(1, 20, (10, 10)), np.random.rand(10, 10), np.int64),
     ],
 )
 class TestCumprodGradComp(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
         cls.primal = cls.primal.astype(cls.dtype)
         cls.cotangent = cls.cotangent.astype(cls.dtype)
+        cls.zero_nums = [0, 1, 10, int(np.prod(cls.primal.shape))]

     def train(self, use_prim, use_cinn):
         paddle.seed(2022)

Review comment (on @param.parameterized_class): same as above, please add a single-element / 0-D case.
Reply: Done
@@ -108,13 +128,20 @@ def desired(primal, cotangent, dim):
                 fetch_list=[x_cotangent[0]],
             )[0]

-        for i in range(len(self.primal.shape)):
-            np.testing.assert_allclose(
-                actual=actual(self.primal, self.cotangent, i),
-                desired=desired(self.primal, self.cotangent, i),
-                rtol=1e-6,
-                atol=0,
-            )
+        for zero_num in self.zero_nums:
+            shape = self.primal.shape
+            x = self.primal.flatten()
+            indices = random.sample(range(x.size), zero_num)
+            for i in indices:
+                x[i] = 0
+            x = np.reshape(x, shape)
+            for i in range(len(self.primal.shape)):
+                np.testing.assert_allclose(
+                    actual=actual(x, self.cotangent, i),
+                    desired=desired(x, self.cotangent, i),
+                    rtol=1e-6,
+                    atol=0,
+                )
         core._set_prim_backward_enabled(False)
         paddle.disable_static()
Review comment: the implementation here seems to differ somewhat from the one in details.h; could you double-check?
Reply: Done
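One way to run that kind of confirmation, in addition to the unit tests above, is a central-difference check: cumprod is a polynomial in each input, so the check stays valid even at exact zeros. A minimal 1-D sketch follows; numerical_grad is hypothetical, not part of the PR, and its output can be compared against cumprod_grad_ref from the earlier sketch:

```python
import numpy as np

def numerical_grad(x, out_grad, eps=1e-6):
    # Central-difference VJP of y = np.cumprod(x); use float64 inputs
    # so the O(eps**2) truncation error stays below the tolerance.
    g = np.zeros_like(x)
    for i in range(x.size):
        xp, xm = x.copy(), x.copy()
        xp[i] += eps
        xm[i] -= eps
        g[i] = out_grad @ (np.cumprod(xp) - np.cumprod(xm)) / (2 * eps)
    return g

# Check the analytic rule at a point that contains a zero.
x = np.array([2.0, 0.0, 3.0])
g = np.ones_like(x)
np.testing.assert_allclose(numerical_grad(x, g), [1.0, 8.0, 0.0], atol=1e-6)
```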