add cumprod_grad composite #64432

Merged
merged 11 commits on May 24, 2024
Contributor


The implementation here seems to differ somewhat from the one in details.h. Could you double-check?

Contributor Author


Done

@@ -1084,14 +1084,24 @@ void cumprod_grad(const Tensor& x,
     std::vector<int64_t> x_dim = common::vectorize<int64_t>(x.dims());
     auto zero_tensor = full<T>(x_dim, 0.0, x.dtype());
     auto zero_mask = cast<T>(equal<T>(x, zero_tensor), x.dtype());
-    auto common_dx =
-        cumsum<T>(out * out_grad, dim, false, exclusive, !reverse) / x;
+    auto zero_mask_cumsum1 = cumsum<T>(zero_mask, dim, false, false, reverse);
+    auto zero_mask_cumsum2 = cumsum<T>(zero_mask, dim, false, true, reverse);
Contributor


Variable names should match their semantics as closely as possible:
zero_mask_cumsum1 --> zero_mask_cumsum_left
zero_mask_cumsum2 --> zero_mask_cumsum_right

+    auto zero_mask_cumsum =
+        zero_mask_cumsum1 +
+        zero_mask_cumsum2;  // determine the index of the first zero
     auto ones_tensor = full<T>(x_dim, 1.0, x.dtype());
-    auto replace_one = (1 - zero_mask) * x + zero_mask * ones_tensor;
-    auto cumprod_recompute = cumprod<T>(replace_one, dim, exclusive, reverse);
+    auto first_zero_mask =
+        cast<T>(equal<T>(zero_mask_cumsum, ones_tensor), x.dtype());
+    auto common_dx = cumsum<T>(out * out_grad, dim, false, exclusive, !reverse);
+    auto replace_one = (1 - zero_mask) * x + zero_mask;
+    auto replace_first_one = (1 - first_zero_mask) * x + first_zero_mask;
+    auto cumprod_recompute =
+        cumprod<T>(replace_first_one, dim, exclusive, reverse);
     auto zeros_dx = cumsum<T>(
         cumprod_recompute * out_grad, dim, false, exclusive, !reverse);
-    auto x_grad_res = (1 - zero_mask) * common_dx + zero_mask * zeros_dx;
+    auto x_grad_res =
+        ((1 - first_zero_mask) * common_dx + first_zero_mask * zeros_dx) /
+        replace_one;
     set_output<T>(x_grad_res, x_grad);
   }
 }
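
For readers tracing the composite rule above, here is a minimal NumPy sketch of the same computation for the 1-D, inclusive, forward-scan case. The helper name cumprod_grad_ref and its layout are illustrative assumptions, not Paddle code. The key trick: the inclusive and exclusive cumulative sums of the zero mask add up to exactly 1 only at the first zero along the scan; that position takes the recomputed-product branch, every other position the common branch.

# Illustrative NumPy reference (assumed sketch, not Paddle code) of the
# composite cumprod_grad for a 1-D input, inclusive scan, forward direction.
import numpy as np

def cumprod_grad_ref(x, out_grad):
    out = np.cumprod(x)
    zero_mask = (x == 0).astype(x.dtype)
    inc = np.cumsum(zero_mask)   # inclusive cumsum of the zero mask
    exc = inc - zero_mask        # exclusive cumsum of the zero mask
    # inc + exc == 1 exactly at the first zero along the scan
    first_zero_mask = ((inc + exc) == 1).astype(x.dtype)
    # x with every zero replaced by 1 (safe denominator) ...
    replace_one = (1 - zero_mask) * x + zero_mask
    # ... and with only the first zero replaced by 1
    replace_first_one = (1 - first_zero_mask) * x + first_zero_mask
    cumprod_recompute = np.cumprod(replace_first_one)

    def rev_cumsum(t):  # suffix sums: sum over j >= i
        return np.cumsum(t[::-1])[::-1]

    common_dx = rev_cumsum(out * out_grad)
    zeros_dx = rev_cumsum(cumprod_recompute * out_grad)
    return ((1 - first_zero_mask) * common_dx
            + first_zero_mask * zeros_dx) / replace_one

# Sanity check: for x = [2, 0, 3] and an all-ones cotangent, the gradient is [1, 8, 0].
print(cumprod_grad_ref(np.array([2.0, 0.0, 3.0]), np.ones(3)))
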
41 changes: 34 additions & 7 deletions test/prim/prim/vjp/eager/test_comp_eager_cumprod_grad.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import random
 import unittest
 
 import numpy as np
@@ -24,6 +24,16 @@
 @param.parameterized_class(
     ('primal', 'dtype'),
     [
Contributor


Add a single-element test: np.array(np.random.rand(), dtype="float32")

Contributor Author


Done

+        (
+            np.random.rand(
+                100,
+            ),
+            np.float32,
+        ),
+        (
+            np.random.rand(10, 10),
+            np.float32,
+        ),
         (
             np.random.rand(2, 3, 4),
             np.float32,
@@ -32,12 +43,21 @@
             np.random.rand(2, 3, 3, 4),
             np.float32,
         ),
+        (
+            np.random.rand(2, 3, 3, 4, 5),
+            np.float32,
+        ),
+        (
+            np.random.randint(1, 100, (2, 3, 4)),
+            np.int64,
+        ),
     ],
 )
 class TestCumprodGradComp(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
         cls.primal = cls.primal.astype(cls.dtype)
+        cls.zero_nums = [0, 1, 10, int(np.prod(cls.primal.shape))]
 
     def test_cumprod_grad_comp(self):
         def actual(primal, dim):
@@ -62,13 +82,20 @@ def desired(primal, dim):
             )
             return x_cotangent[0]
 
-        for i in range(len(self.primal.shape)):
-            np.testing.assert_allclose(
-                actual=actual(self.primal, i),
-                desired=desired(self.primal, i),
-                rtol=1e-6,
-                atol=0,
-            )
+        for zero_num in self.zero_nums:
+            shape = self.primal.shape
+            x = self.primal.flatten()
+            indices = random.sample(range(x.size), zero_num)
+            for i in indices:
+                x[i] = 0
+            x = np.reshape(x, shape)
+            for i in range(len(self.primal.shape)):
+                np.testing.assert_allclose(
+                    actual=actual(x, i),
+                    desired=desired(x, i),
+                    rtol=1e-6,
+                    atol=0,
+                )
         core.set_prim_eager_enabled(False)


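As a usage illustration of what the eager test above automates, here is a hedged end-to-end sketch. It assumes a Paddle build that includes this PR and that core comes from paddle.base, as in current Paddle tests; the expected values follow the hand-worked [2, 0, 3] example earlier.

# Assumed usage sketch (not part of this PR's diff): run the composite
# cumprod gradient in eager mode and check it on an input containing a zero.
import numpy as np
import paddle
from paddle.base import core  # assumed import path for the `core` used above

core.set_prim_eager_enabled(True)
x = paddle.to_tensor(
    np.array([2.0, 0.0, 3.0], dtype="float32"), stop_gradient=False
)
y = paddle.cumprod(x, dim=0)
(dx,) = paddle.grad(y, [x], grad_outputs=paddle.ones_like(y))
print(dx.numpy())  # expected [1., 8., 0.] for this input
core.set_prim_eager_enabled(False)
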
41 changes: 34 additions & 7 deletions test/prim/prim/vjp/static/test_comp_cumprod_grad.py
Contributor


Add more unit-test cases, matching test/prim/prim/vjp/eager/test_comp_eager_cumprod_grad.py.

Contributor Author


Done

@@ -18,6 +18,8 @@

 core._set_prim_backward_enabled(True)
 
+import random
+
 import numpy as np
 import parameterized as param
 
@@ -46,14 +48,32 @@ def forward(self, x):
 @param.parameterized_class(
Contributor

@HydrogenSulfate May 21, 2024


Same as above: add a single-element 0-D case.

Contributor Author


Done

     ('primal', 'cotangent', 'dtype'),
     [
+        (
+            np.random.rand(
+                100,
+            ),
+            np.random.rand(
+                100,
+            ),
+            np.float32,
+        ),
         (np.random.rand(10, 10), np.random.rand(10, 10), np.float32),
         (np.random.rand(3, 4, 5), np.random.rand(3, 4, 5), np.float32),
         (np.random.rand(2, 3, 4, 5), np.random.rand(2, 3, 4, 5), np.float32),
+        (
+            np.random.rand(2, 3, 2, 4, 5),
+            np.random.rand(2, 3, 2, 4, 5),
+            np.float32,
+        ),
+        (np.random.randint(1, 20, (10, 10)), np.random.rand(10, 10), np.int64),
     ],
 )
 class TestCumprodGradComp(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
         cls.primal = cls.primal.astype(cls.dtype)
         cls.cotangent = cls.cotangent.astype(cls.dtype)
+        cls.zero_nums = [0, 1, 10, int(np.prod(cls.primal.shape))]
 
     def train(self, use_prim, use_cinn):
         paddle.seed(2022)
@@ -108,13 +128,20 @@ def desired(primal, cotangent, dim):
                 fetch_list=[x_cotangent[0]],
             )[0]
 
-        for i in range(len(self.primal.shape)):
-            np.testing.assert_allclose(
-                actual=actual(self.primal, self.cotangent, i),
-                desired=desired(self.primal, self.cotangent, i),
-                rtol=1e-6,
-                atol=0,
-            )
+        for zero_num in self.zero_nums:
+            shape = self.primal.shape
+            x = self.primal.flatten()
+            indices = random.sample(range(x.size), zero_num)
+            for i in indices:
+                x[i] = 0
+            x = np.reshape(x, shape)
+            for i in range(len(self.primal.shape)):
+                np.testing.assert_allclose(
+                    actual=actual(x, self.cotangent, i),
+                    desired=desired(x, self.cotangent, i),
+                    rtol=1e-6,
+                    atol=0,
+                )
         core._set_prim_backward_enabled(False)
         paddle.disable_static()

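The static-graph path exercised by this file can be sketched end to end as follows. This is an assumed usage example built from the public paddle.static API, not code from this PR; it mirrors how the test builds a program, takes gradients with the composite backward rule enabled, and fetches the result.

# Assumed static-mode usage sketch (not part of this PR's diff).
import numpy as np
import paddle
from paddle.base import core  # assumed import path, as in the test above

paddle.enable_static()
core._set_prim_backward_enabled(True)

main, startup = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.static.data('x', shape=[3], dtype='float32')
    x.stop_gradient = False
    y = paddle.cumprod(x, dim=0)
    x_grad = paddle.static.gradients([y], [x])  # composite backward applies here

exe = paddle.static.Executor()
exe.run(startup)
(dx,) = exe.run(
    main,
    feed={'x': np.array([2.0, 0.0, 3.0], dtype='float32')},
    fetch_list=[x_grad[0]],
)
print(dx)  # expected [1., 8., 0.]

core._set_prim_backward_enabled(False)
paddle.disable_static()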