Commit b53ce8a

add tests
1 parent de2afce commit b53ce8a

File tree

3 files changed: +830 -466 lines changed
Lines changed: 340 additions & 0 deletions
@@ -0,0 +1,340 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
import unittest

import numpy as np
from op_test import convert_float_to_uint16, get_places

import paddle
from paddle.device import get_device


def cumprod_wrapper(x, dim=-1, exclusive=False, reverse=False):
    return paddle._C_ops.cumprod(x, dim, exclusive, reverse)


# Reference implementation of the cumprod gradient, computed element by
# element with numpy. compute_expected_grad below uses it to verify the
# gradients produced by paddle.cumprod's backward pass.
def cumprod_grad(x, y, dy, dx, shape, dim, exclusive=False, reverse=False):
    if dim < 0:
        dim += len(shape)
    mid_dim = shape[dim]
    outer_dim = 1
    inner_dim = 1
    for i in range(0, dim):
        outer_dim *= shape[i]
    for i in range(dim + 1, len(shape)):
        inner_dim *= shape[i]
    if not reverse:
        for i in range(outer_dim):
            for k in range(inner_dim):
                for j in range(mid_dim):
                    index = i * mid_dim * inner_dim + j * inner_dim + k
                    for n in range(mid_dim):
                        pos = i * mid_dim * inner_dim + n * inner_dim + k
                        elem = 0
                        if exclusive:
                            if pos > index:
                                elem = dy[pos] * y[index]
                                for m in range(
                                    index + inner_dim, pos, inner_dim
                                ):
                                    elem *= x[m]
                            else:
                                elem = 0
                        else:
                            if j == 0:
                                elem = dy[pos]
                            else:
                                elem = dy[pos] * y[index - inner_dim]
                            if pos > index:
                                for m in range(
                                    index + inner_dim,
                                    pos + inner_dim,
                                    inner_dim,
                                ):
                                    elem *= x[m]
                            elif pos < index:
                                elem = 0
                        dx[index] += elem
    else:
        for i in range(outer_dim):
            for k in range(inner_dim):
                for j in range(mid_dim - 1, -1, -1):
                    index = i * mid_dim * inner_dim + j * inner_dim + k
                    for n in range(mid_dim - 1, -1, -1):
                        pos = i * mid_dim * inner_dim + n * inner_dim + k
                        elem = 0
                        if exclusive:
                            if pos < index:
                                elem = dy[pos] * y[index]
                                for m in range(
                                    index - inner_dim, pos, -inner_dim
                                ):
                                    elem *= x[m]
                        else:
                            if j == mid_dim - 1:
                                elem = dy[pos]
                            else:
                                elem = dy[pos] * y[index + inner_dim]
                            if pos < index:
                                for m in range(
                                    index - inner_dim,
                                    pos - inner_dim,
                                    -inner_dim,
                                ):
                                    elem *= x[m]
                            elif pos > index:
                                elem = 0
                        dx[index] += elem


# Decorator: skip a test unless the current device is a CPU or GPU.
def skip_if_not_cpu_or_gpu(func):
    def wrapper(self):
        device = get_device()
        if not (device == 'cpu' or device.startswith('gpu:')):
            self.skipTest(f"Test skipped on device: {device}")
        return func(self)

    return wrapper


class TestCumprod(unittest.TestCase):
    def init_params(self):
        self.shape = (2, 3, 4, 5)
        self.zero_nums = [0, 10, 20, 30, int(np.prod(self.shape))]

    def init_dtype(self):
        self.dtype = np.float64
        self.val_dtype = np.float64

    def setUp(self):
        paddle.disable_static()
        self.init_params()
        self.init_dtype()

    def tearDown(self):
        paddle.enable_static()

    def prepare_test_data(self, dim, zero_num):
        self.x = (
            np.random.uniform(0.0, 0.5, self.shape).astype(self.val_dtype) + 0.5
        )
        if zero_num > 0:
            zero_num = min(zero_num, self.x.size)
            shape = self.x.shape
            self.x = self.x.flatten()
            indices = random.sample(range(self.x.size), zero_num)
            for i in indices:
                self.x[i] = 0
            self.x = np.reshape(self.x, self.shape)
        self.expected_out = np.cumprod(self.x, axis=dim)

    def compute_expected_grad(self, dim):
        reshape_x = self.x.reshape(self.x.size)
        grad_out = np.ones(self.x.size, self.val_dtype)
        grad_x = np.zeros(self.x.size, self.val_dtype)
        out_data = self.expected_out.reshape(self.x.size)

        if self.dtype == np.complex128 or self.dtype == np.complex64:
            reshape_x = np.conj(reshape_x)
            out_data = np.conj(out_data)

        cumprod_grad(reshape_x, out_data, grad_out, grad_x, self.shape, dim)

        return grad_x.reshape(self.shape)

    def test_forward_computation(self):
        for dim in range(-len(self.shape), len(self.shape)):
            for zero_num in self.zero_nums:
                with self.subTest(dim=dim, zero_num=zero_num):
                    self._test_forward_for_case(dim, zero_num)

    def _test_forward_for_case(self, dim, zero_num):
        self.prepare_test_data(dim, zero_num)

        x_tensor = paddle.to_tensor(self.x, dtype=self.val_dtype)
        out = paddle.cumprod(x_tensor, dim=dim)

        np.testing.assert_allclose(
            out.numpy(), self.expected_out, rtol=1e-05, atol=1e-06
        )

    def test_gradient_computation(self):
        for dim in range(-len(self.shape), len(self.shape)):
            for zero_num in [0, 10]:
                with self.subTest(dim=dim, zero_num=zero_num):
                    self._test_gradient_for_case(dim, zero_num)

    def _test_gradient_for_case(self, dim, zero_num):
        self.prepare_test_data(dim, zero_num)

        x_tensor = paddle.to_tensor(
            self.x, dtype=self.val_dtype, stop_gradient=False
        )
        out = paddle.cumprod(x_tensor, dim=dim)

        np.testing.assert_allclose(
            out.numpy(), self.expected_out, rtol=1e-05, atol=1e-06
        )

        loss = paddle.sum(out)
        loss.backward()

        expected_grad = self.compute_expected_grad(dim)

        if self.dtype == np.float64:
            np.testing.assert_allclose(
                x_tensor.grad.numpy(), expected_grad, rtol=1e-05, atol=1e-06
            )
        else:
            if self.dtype == np.uint16:
                expected_grad_converted = convert_float_to_uint16(expected_grad)
                np.testing.assert_allclose(
                    x_tensor.grad.numpy(),
                    expected_grad_converted,
                    rtol=1e-03,
                    atol=1e-04,
                )
            else:
                np.testing.assert_allclose(
                    x_tensor.grad.numpy(), expected_grad, rtol=1e-04, atol=1e-05
                )


class TestCumprodDtypeFloat32(TestCumprod):
    def init_dtype(self):
        self.dtype = np.float32
        self.val_dtype = np.float32

    @skip_if_not_cpu_or_gpu
    def test_dtype_float32(self):
        self.prepare_test_data(dim=1, zero_num=0)

        x = paddle.to_tensor(self.x, dtype='float32')
        x.stop_gradient = False
        out = paddle.cumprod(x, dim=1, dtype='float32')
        self.assertEqual(out.dtype, paddle.float32)

        out_ref = np.cumprod(self.x.astype(np.float32), axis=1).astype(
            np.float32
        )
        np.testing.assert_allclose(out_ref, out.numpy(), rtol=1e-05)

        loss = paddle.sum(out)
        loss.backward()
        self.assertEqual(x.grad.dtype, paddle.float32)

        expected_grad = self.compute_expected_grad(1)
        np.testing.assert_allclose(
            x.grad.numpy(), expected_grad, rtol=1e-04, atol=1e-05
        )


class TestCumprodDtypeFloat64(TestCumprod):
    def init_dtype(self):
        self.dtype = np.float32
        self.val_dtype = np.float32

    @skip_if_not_cpu_or_gpu
    def test_dtype_float64(self):
        self.prepare_test_data(dim=1, zero_num=0)

        x = paddle.to_tensor(self.x, dtype='float32')
        x.stop_gradient = False
        out = paddle.cumprod(x, dim=1, dtype='float64')
        self.assertEqual(out.dtype, paddle.float64)

        out_ref = np.cumprod(self.x.astype(np.float32), axis=1).astype(
            np.float64
        )
        np.testing.assert_allclose(out_ref, out.numpy(), rtol=1e-05)

        loss = paddle.sum(out)
        loss.backward()
        self.assertEqual(x.grad.dtype, paddle.float32)

        self.assertIsNotNone(x.grad)
        self.assertEqual(x.grad.shape, x.shape)


class TestCumprodDtypeStatic(unittest.TestCase):
    def setUp(self):
        self.shape = [2, 3, 4]
        self.x = (np.random.rand(*self.shape) + 0.5).astype(np.float32)
        self.places = get_places()

    @skip_if_not_cpu_or_gpu
    def test_static_dtype_float32(self):
        paddle.enable_static()
        for place in self.places:
            with paddle.static.program_guard(paddle.static.Program()):
                x = paddle.static.data('X', self.shape, dtype='float32')
                out = paddle.cumprod(x, dim=1, dtype='float32')
                exe = paddle.static.Executor(place)
                (out_res,) = exe.run(feed={'X': self.x}, fetch_list=[out])

                out_ref = np.cumprod(self.x, axis=1).astype(np.float32)
                np.testing.assert_allclose(out_ref, out_res, rtol=1e-05)


class TestCumprodBoundaryConditions(unittest.TestCase):
    def setUp(self):
        paddle.disable_static()

    def tearDown(self):
        paddle.enable_static()

    @skip_if_not_cpu_or_gpu
    def test_single_element_tensor(self):
        x = paddle.to_tensor([5.0], dtype='float32', stop_gradient=False)
        out = paddle.cumprod(x, dim=0)

        self.assertEqual(out.shape, [1])
        np.testing.assert_allclose(out.numpy(), [5.0], rtol=1e-05)

        out.backward()
        np.testing.assert_allclose(x.grad.numpy(), [1.0], rtol=1e-05)

    @skip_if_not_cpu_or_gpu
    def test_zero_values_gradient(self):
        x_data = np.array([[1.0, 0.0, 3.0], [2.0, 4.0, 0.0]], dtype=np.float32)
        x = paddle.to_tensor(x_data, stop_gradient=False)

        out = paddle.cumprod(x, dim=1)
        loss = paddle.sum(out)
        loss.backward()

        self.assertIsNotNone(x.grad)
        self.assertEqual(x.grad.shape, x.shape)

    @skip_if_not_cpu_or_gpu
    def test_negative_dim(self):
        x_data = np.random.rand(2, 3, 4).astype(np.float32) + 0.5
        x = paddle.to_tensor(x_data, stop_gradient=False)

        out1 = paddle.cumprod(x, dim=-1)
        out2 = paddle.cumprod(x, dim=2)

        np.testing.assert_allclose(out1.numpy(), out2.numpy(), rtol=1e-05)

        loss1 = paddle.sum(out1)
        loss1.backward()

        self.assertIsNotNone(x.grad)


if __name__ == "__main__":
    unittest.main()
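
A minimal, NumPy-only sketch of how the cumprod_grad reference above can be cross-checked by hand for a tiny 1-D input. It assumes cumprod_grad from the file above is in scope; the example values are illustrative and not part of the test file:

    import numpy as np

    # gradient of sum(cumprod(x)) along dim 0, non-exclusive, non-reverse
    x = np.array([2.0, 3.0, 4.0])
    y = np.cumprod(x)                 # forward result: [2., 6., 24.]
    dy = np.ones_like(x)              # upstream gradient of sum()
    dx = np.zeros_like(x)
    cumprod_grad(x, y, dy, dx, x.shape, 0)

    # closed form for three elements:
    #   dx0 = 1 + x1 + x1*x2, dx1 = x0 + x0*x2, dx2 = x0*x1
    expected = np.array(
        [1 + x[1] + x[1] * x[2], x[0] + x[0] * x[2], x[0] * x[1]]
    )
    np.testing.assert_allclose(dx, expected)  # both are [16., 10., 6.]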

0 commit comments
