Skip to content

Commit

Permalink
【Hackathon 7th No.31】NO.31 为 paddle.sparse.sparse_csr_tensor进行功能增强 (#…
Browse files Browse the repository at this point in the history
…68281)

* 【Hackathon 7th No.31】NO.31 为 paddle.sparse.sparse_csr_tensor进行功能增强

* update

* update
  • Loading branch information
monster1015 authored Sep 27, 2024
1 parent e8fcf98 commit d2744a0
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 5 deletions.
30 changes: 25 additions & 5 deletions python/paddle/sparse/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,13 +199,29 @@ def sparse_coo_tensor(
return out


def _infer_dense_csr_shape(crows, cols):
crows_numpy = crows.numpy()
cols_numpy = cols.numpy()
batchs = np.sum(crows_numpy[:-1] > crows_numpy[1:]) + 1
if (int(len(crows_numpy) / batchs) * batchs) != len(crows_numpy):
raise ValueError(
f"The calculated original matrix batch size is {batchs}, but it cannot correctly split the row data. Please carefully check the data or the input shape."
)
col = cols_numpy.max() + 1
row = int(len(crows_numpy) / batchs) - 1
if batchs == 1:
return [row, col]
else:
return [batchs, row, col]


# TODO: need to support shape is None
@dygraph_only
def sparse_csr_tensor(
crows: list[int] | tuple[int, ...] | npt.NDArray[np.int_] | Tensor,
cols: list[int] | tuple[int, ...] | npt.NDArray[np.int_] | Tensor,
values: NumbericSequence | npt.NDArray[Any] | Tensor,
shape: ShapeLike,
shape: ShapeLike | None = None,
dtype: DTypeLike | None = None,
place: CPUPlace | CUDAPinnedPlace | CUDAPlace | str | None = None,
stop_gradient: bool = True,
Expand Down Expand Up @@ -268,10 +284,14 @@ def sparse_csr_tensor(
_check_indices_dtype(crows.dtype)
_check_indices_dtype(cols.dtype)

if len(shape) != 2 and len(shape) != 3:
raise ValueError(
f"SparseCsrTensor only support 2-D or 3-D matrix. but get shape {shape}"
)
if shape is not None:
if len(shape) != 2 and len(shape) != 3:
raise ValueError(
f"SparseCsrTensor only support 2-D or 3-D matrix. but get shape {shape}"
)
else:
shape = _infer_dense_csr_shape(crows, cols)

rows = shape[len(shape) - 2]

if not crows.place._equals(place):
Expand Down
41 changes: 41 additions & 0 deletions test/legacy_test/test_sparse_utils_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,37 @@ def test_create_coo_no_shape(self):
coo = paddle.sparse.sparse_coo_tensor(indices, values)
assert [2, 2] == coo.shape

def test_create_csr_no_shape(self):
# 2D sparse tensor
crows = [0, 2, 3, 5]
clos = [1, 3, 2, 0, 1]
values = [1.0, 2.0, 3.0, 4.0, 5.0]
crows = paddle.to_tensor(crows, dtype='int32')
clos = paddle.to_tensor(clos, dtype='int32')
values = paddle.to_tensor(values, dtype='float32')
csr = paddle.sparse.sparse_csr_tensor(crows, clos, values)
assert [3, 4] == csr.shape

# 3D sparse tensor
crows = [0, 2, 2, 0, 1, 1, 0, 0, 0]
clos = [0, 1, 1]
values = [1, 2, 5]
crows = paddle.to_tensor(crows, dtype='int32')
clos = paddle.to_tensor(clos, dtype='int32')
values = paddle.to_tensor(values, dtype='float32')
csr = paddle.sparse.sparse_csr_tensor(crows, clos, values)
assert [3, 2, 2] == csr.shape

# 3D sparse tensor
crows = [0, 1, 2, 0, 1, 1, 0, 1, 2]
clos = [0, 2, 1, 0, 1]
values = [1, 2, 3, 4, 5]
crows = paddle.to_tensor(crows, dtype='int32')
clos = paddle.to_tensor(clos, dtype='int32')
values = paddle.to_tensor(values, dtype='float32')
csr = paddle.sparse.sparse_csr_tensor(crows, clos, values)
assert [3, 2, 3] == csr.shape


class TestSparseConvert(unittest.TestCase):
def test_to_sparse_coo(self):
Expand Down Expand Up @@ -489,6 +520,16 @@ def test_dtype(self):
crows, cols, values, shape
)

def test_error_crows(self):
with self.assertRaises(ValueError):
crows = [0, 2, 2, 0, 1, 1, 0, 0, 0, 0]
clos = [0, 1, 1]
values = [1, 2, 5]
crows = paddle.to_tensor(crows, dtype='int32')
clos = paddle.to_tensor(clos, dtype='int32')
values = paddle.to_tensor(values, dtype='float32')
coo = paddle.sparse.sparse_csr_tensor(crows, clos, values)


devices = []
if paddle.device.get_device() != "cpu":
Expand Down

0 comments on commit d2744a0

Please sign in to comment.