From ca9c42295ff963fe9edb8229081764b6111e4992 Mon Sep 17 00:00:00 2001
From: epiphanyer <19307130192@fudan.edu.cn>
Date: Wed, 31 Jul 2024 11:54:00 +0800
Subject: [PATCH 1/2] update softmax unittest

---
 .../test_sparse_fused_attention_op.py | 80 +++++++++++++++++++
 1 file changed, 80 insertions(+)

diff --git a/test/legacy_test/test_sparse_fused_attention_op.py b/test/legacy_test/test_sparse_fused_attention_op.py
index 098f4815b85f3..e0b2a8bd58026 100644
--- a/test/legacy_test/test_sparse_fused_attention_op.py
+++ b/test/legacy_test/test_sparse_fused_attention_op.py
@@ -21,7 +21,9 @@
 import numpy as np
 
 import paddle
+import paddle.sparse
 from paddle.base import core
+from paddle.base.framework import in_pir_mode
 
 
 def get_cuda_version():
@@ -177,5 +179,83 @@ def setUp(self):
         self.use_mask = True
 
 
+devices = []
+if paddle.device.get_device() != "cpu":
+    devices.append(paddle.device.get_device())
+else:
+    devices.append('cpu')
+
+
+class TestSparseSoftmaxStaticAPI(unittest.TestCase):
+    '''
+    Test paddle.sparse.nn.functional.softmax on sparse COO tensors in
+    static graph (PIR) mode.
+    '''
+
+    def check_result_coo(self, x_shape):
+        '''
+        Generate a sparse COO tensor with shape `x_shape`, compute
+        paddle.sparse.nn.functional.softmax on it in static graph mode,
+        and compare the result with paddle.nn.functional.softmax applied
+        to the equivalent dense tensor.
+        '''
+        for device in devices:
+            paddle.device.set_device(device)
+            x = paddle.rand(x_shape, dtype='float32')
+            indices_data, values_data = (
+                x.detach().to_sparse_coo(sparse_dim=len(x_shape)).indices(),
+                x.detach().to_sparse_coo(sparse_dim=len(x_shape)).values(),
+            )
+
+            x.stop_gradient = False
+            out = paddle.nn.functional.softmax(x)
+
+            paddle.enable_static()
+            with paddle.static.program_guard(
+                paddle.static.Program(), paddle.static.Program()
+            ):
+                indices = paddle.static.data(
+                    name="indices",
+                    shape=indices_data.shape,
+                    dtype=indices_data.dtype,
+                )
+                values = paddle.static.data(
+                    name="values",
+                    shape=values_data.shape,
+                    dtype=values_data.dtype,
+                )
+
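+                # Rebuild the sparse COO tensor inside the static
+                # program from the dense components fed at run time.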
+                sp_x = paddle.sparse.sparse_coo_tensor(
+                    indices,
+                    values,
+                    shape=x.shape,
+                    dtype=x.dtype,
+                )
+                sp_out = paddle.sparse.nn.functional.softmax(sp_x)
+                sp_dense_out = sp_out.to_dense()
+
+                sp_exe = paddle.static.Executor()
+                sp_fetch = sp_exe.run(
+                    feed={
+                        "indices": indices_data.numpy(),
+                        "values": values_data.numpy(),
+                    },
+                    fetch_list=[sp_dense_out],
+                    return_numpy=True,
+                )
+                np.testing.assert_allclose(
+                    out.numpy(), sp_fetch[0], rtol=1e-05
+                )
+            paddle.disable_static()
+
+    def test_softmax_2d(self):
+        if in_pir_mode():
+            self.check_result_coo([3, 4])
+
+    def test_softmax_3d(self):
+        if in_pir_mode():
+            self.check_result_coo([3, 4, 5])
+
+    def test_softmax_4d(self):
+        if in_pir_mode():
+            self.check_result_coo([3, 4, 5, 6])
+
+
 if __name__ == '__main__':
     unittest.main()

From 80fd664b34b4b55e9730f32379f328c44d539830 Mon Sep 17 00:00:00 2001
From: epiphanyer <19307130192@fudan.edu.cn>
Date: Wed, 31 Jul 2024 15:58:13 +0800
Subject: [PATCH 2/2] update subm_conv2d and subm_conv3d unittest

---
 test/legacy_test/test_sparse_conv_op.py | 196 ++++++++++++++++++++++++
 1 file changed, 196 insertions(+)

diff --git a/test/legacy_test/test_sparse_conv_op.py b/test/legacy_test/test_sparse_conv_op.py
index 755f792215834..e560ff19ed0b1 100644
--- a/test/legacy_test/test_sparse_conv_op.py
+++ b/test/legacy_test/test_sparse_conv_op.py
@@ -18,8 +18,10 @@
 import numpy as np
 
 import paddle
+import paddle.device
 from paddle import sparse
 from paddle.base import core
+from paddle.base.framework import in_pir_mode
 
 logging.basicConfig(
     format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO
 )
@@ -579,5 +581,199 @@ def test2D_cpu(self):
         paddle.disable_static()
 
 
+devices = []
+if paddle.device.get_device() != "cpu":
+    devices.append(paddle.device.get_device())
+else:
+    devices.append('cpu')
+
+
+class TestSparseSubmConvStatic(unittest.TestCase):
+    '''
+    Test subm_conv2d and subm_conv3d in static graph (PIR) mode.
+    Compare the static graph results with the dynamic graph results,
+    using the dynamic graph output as the reference.
+    '''
+
+    def check_result_subm_conv2d(self, x_shape, weight_shape):
+        '''
+        x_shape: the shape of the input tensor x, [N, H, W, C]
+        weight_shape: the shape of the conv kernel, [kH, kW, C/g, M]
+        Compare the output of paddle.sparse.nn.functional.subm_conv2d
+        in static graph mode with the dynamic graph output.
+        '''
+        for device in devices:
+            paddle.device.set_device(device)
+            x = paddle.rand(x_shape, dtype='float32')
+            weight = paddle.randn(weight_shape, dtype='float32')
+            x_indices_data, x_values_data = (
+                x.detach().to_sparse_coo(sparse_dim=len(x_shape)).indices(),
+                x.detach().to_sparse_coo(sparse_dim=len(x_shape)).values(),
+            )
+            w_indices_data, w_values_data = (
+                weight.detach()
+                .to_sparse_coo(sparse_dim=len(weight_shape))
+                .indices(),
+                weight.detach()
+                .to_sparse_coo(sparse_dim=len(weight_shape))
+                .values(),
+            )
+            x.stop_gradient = False
+            weight.stop_gradient = False
+
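+            # Submanifold convolution keeps the output sparsity pattern
+            # identical to the input's, so the dynamic graph result can
+            # serve directly as the reference.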
+            dynamic_out = paddle.sparse.nn.functional.subm_conv2d(x, weight)
+            dynamic_out_dense = dynamic_out.to_dense()
+
+            paddle.enable_static()
+            with paddle.static.program_guard(
+                paddle.static.Program(), paddle.static.Program()
+            ):
+                x_indices = paddle.static.data(
+                    name="x_indices",
+                    shape=x_indices_data.shape,
+                    dtype=x_indices_data.dtype,
+                )
+                x_values = paddle.static.data(
+                    name="x_values",
+                    shape=x_values_data.shape,
+                    dtype=x_values_data.dtype,
+                )
+                w_indices = paddle.static.data(
+                    name="w_indices",
+                    shape=w_indices_data.shape,
+                    dtype=w_indices_data.dtype,
+                )
+                w_values = paddle.static.data(
+                    name="w_values",
+                    shape=w_values_data.shape,
+                    dtype=w_values_data.dtype,
+                )
+
+                static_x = paddle.sparse.sparse_coo_tensor(
+                    x_indices,
+                    x_values,
+                    shape=x_shape,
+                    dtype=x.dtype,
+                )
+                static_w = paddle.sparse.sparse_coo_tensor(
+                    w_indices,
+                    w_values,
+                    shape=weight_shape,
+                    dtype=weight.dtype,
+                )
+                static_out = paddle.sparse.nn.functional.subm_conv2d(
+                    static_x, static_w
+                )
+                static_dense_out = static_out.to_dense()
+
+                st_exe = paddle.static.Executor()
+                st_fetch = st_exe.run(
+                    feed={
+                        "x_indices": x_indices_data.numpy(),
+                        "x_values": x_values_data.numpy(),
+                        "w_indices": w_indices_data.numpy(),
+                        "w_values": w_values_data.numpy(),
+                    },
+                    fetch_list=[static_dense_out],
+                    return_numpy=True,
+                )
+                np.testing.assert_allclose(
+                    dynamic_out_dense.numpy(), st_fetch[0], rtol=1e-05
+                )
+            paddle.disable_static()
+
+    def check_result_subm_conv3d(self, x_shape, weight_shape):
+        '''
+        x_shape: the shape of the input tensor x, [N, D, H, W, C]
+        weight_shape: the shape of the conv kernel, [kD, kH, kW, C/g, M]
+        Compare the output of paddle.sparse.nn.functional.subm_conv3d
+        in static graph mode with the dynamic graph output.
+        '''
+        for device in devices:
+            paddle.device.set_device(device)
+            x = paddle.rand(x_shape, dtype='float32')
+            weight = paddle.randn(weight_shape, dtype='float32')
+            x_indices_data, x_values_data = (
+                x.detach().to_sparse_coo(sparse_dim=len(x_shape)).indices(),
+                x.detach().to_sparse_coo(sparse_dim=len(x_shape)).values(),
+            )
+            w_indices_data, w_values_data = (
+                weight.detach()
+                .to_sparse_coo(sparse_dim=len(weight_shape))
+                .indices(),
+                weight.detach()
+                .to_sparse_coo(sparse_dim=len(weight_shape))
+                .values(),
+            )
+            x.stop_gradient = False
+            weight.stop_gradient = False
+
+            dynamic_out = paddle.sparse.nn.functional.subm_conv3d(x, weight)
+            dynamic_out_dense = dynamic_out.to_dense()
+
+            paddle.enable_static()
+            with paddle.static.program_guard(
+                paddle.static.Program(), paddle.static.Program()
+            ):
+                x_indices = paddle.static.data(
+                    name="x_indices",
+                    shape=x_indices_data.shape,
+                    dtype=x_indices_data.dtype,
+                )
+                x_values = paddle.static.data(
+                    name="x_values",
+                    shape=x_values_data.shape,
+                    dtype=x_values_data.dtype,
+                )
+                w_indices = paddle.static.data(
+                    name="w_indices",
+                    shape=w_indices_data.shape,
+                    dtype=w_indices_data.dtype,
+                )
+                w_values = paddle.static.data(
+                    name="w_values",
+                    shape=w_values_data.shape,
+                    dtype=w_values_data.dtype,
+                )
+
+                static_x = paddle.sparse.sparse_coo_tensor(
+                    x_indices,
+                    x_values,
+                    shape=x_shape,
+                    dtype=x.dtype,
+                )
+                static_w = paddle.sparse.sparse_coo_tensor(
+                    w_indices,
+                    w_values,
+                    shape=weight_shape,
+                    dtype=weight.dtype,
+                )
+                static_out = paddle.sparse.nn.functional.subm_conv3d(
+                    static_x, static_w
+                )
+                static_dense_out = static_out.to_dense()
+
+                st_exe = paddle.static.Executor()
+                st_fetch = st_exe.run(
+                    feed={
+                        "x_indices": x_indices_data.numpy(),
+                        "x_values": x_values_data.numpy(),
+                        "w_indices": w_indices_data.numpy(),
+                        "w_values": w_values_data.numpy(),
+                    },
+                    fetch_list=[static_dense_out],
+                    return_numpy=True,
+                )
+                np.testing.assert_allclose(
+                    dynamic_out_dense.numpy(), st_fetch[0], rtol=1e-05
+                )
+            paddle.disable_static()
+
+    def test_subm_conv2d(self):
+        if in_pir_mode():
+            self.check_result_subm_conv2d([1, 3, 4, 1], [3, 3, 1, 1])
+
+    def test_subm_conv3d(self):
+        if in_pir_mode():
+            self.check_result_subm_conv3d([1, 1, 3, 4, 1], [1, 3, 3, 1, 1])
+
+
 if __name__ == "__main__":
     unittest.main()