Skip to content

Commit

Permalink
【Hackathon 7th No.31】NO.31 为 paddle.sparse.sparse_csr_tensor进行功能增强 (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
monster1015 authored Sep 24, 2024
1 parent 189cd71 commit 5f90eef
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 14 deletions.
4 changes: 2 additions & 2 deletions docs/api/paddle/sparse/sparse_csr_tensor_cn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
sparse_csr_tensor
-------------------------------

.. py:function:: paddle.sparse.sparse_csr_tensor(crows, cols, values, shape, dtype=None, place=None, stop_gradient=True)
.. py:function:: paddle.sparse.sparse_csr_tensor(crows, cols, values, shape=None, dtype=None, place=None, stop_gradient=True)
该 API 通过已知的非零元素的 ``crows`` , ``cols`` 和 ``values`` 来创建一个 CSR(Compressed Sparse Row) 格式的稀疏 tensor,tensor 类型为 ``paddle.Tensor`` 。

Expand All @@ -26,7 +26,7 @@ sparse_csr_tensor
list,tuple,numpy\.ndarray,paddle\.Tensor 类型。
- **values** (list|tuple|ndarray|Tensor) - 一维数组,存储非零元素,可以是
list,tuple,numpy\.ndarray,paddle\.Tensor 类型。
- **shape** (list|tuple) - 稀疏 Tensor 的形状,也是 Tensor 的形状,如果没有提供,将自动推测出最小的形状。
- **shape** (list|tuple,可选) - 稀疏 Tensor 的形状,也是 Tensor 的形状,如果没有提供,将自动推测出最小的形状。
- **dtype** (str|np.dtype,可选) - 创建 tensor 的数据类型,可以是 'bool' ,'float16','float32',
'float64' ,'int8','int16','int32','int64','uint8','complex64','complex128'。
默认值为 None,如果 ``values`` 为 python 浮点类型,则从
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,24 +19,24 @@ torch.sparse_csr_tensor(crow_indices,
### [paddle.sparse.sparse_csr_tensor](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/sparse/sparse_csr_tensor_cn.html#sparse-csr-tensor)

```python
paddle.sparse.sparse_csr_tensor(crows, cols, values, shape, dtype=None, place=None, stop_gradient=True)
paddle.sparse.sparse_csr_tensor(crows, cols, values, shape=None, dtype=None, place=None, stop_gradient=True)
```

PyTorch 相比 Paddle 支持更多其他参数,具体如下:

### 参数映射

| PyTorch | PaddlePaddle | 备注 |
| ---------------- | ------------- | -------------------------------------------------------------- |
| crow_indices | crows | 每行第一个非零元素在 values 的起始位置,仅参数名不一致。 |
| col_indices | cols | 一维数组,存储每个非零元素的列信息,仅参数名不一致。 |
| values | values | 一维数组,存储非零元素。 |
| size | shape | 稀疏 Tensor 的形状,仅参数名不一致。 |
| dtype | dtype | 创建 tensor 的数据类型。 |
| layout |- |表示布局方式,Paddle 无此参数,一般对网络训练结果影响不大,可直接删除。|
| device | place | 创建 tensor 的设备位置,仅参数名不一致。 |
| pin_memory | - | 表示是否使用锁页内存, Paddle 无此参数,一般对网络训练结果影响不大,可直接删除。|
| requires_grad | stop_gradient | 是否阻断 Autograd 的梯度传导,两者参数功能相反,需要转写。 |
| PyTorch | PaddlePaddle | 备注 |
| ---------------- | ------------- | ----------------------------------------------------------------------------------- |
| crow_indices | crows | 每行第一个非零元素在 values 的起始位置,仅参数名不一致。 |
| col_indices | cols | 一维数组,存储每个非零元素的列信息,仅参数名不一致。 |
| values | values | 一维数组,存储非零元素。 |
| size | shape | 稀疏 Tensor 的形状,仅参数名不一致。 |
| dtype | dtype | 创建 tensor 的数据类型。 |
| layout | - | 表示布局方式,Paddle 无此参数,一般对网络训练结果影响不大,可直接删除。 |
| device | place | 创建 tensor 的设备位置,仅参数名不一致。 |
| pin_memory | - | 表示是否使用锁页内存, Paddle 无此参数,一般对网络训练结果影响不大,可直接删除。 |
| requires_grad | stop_gradient | 是否阻断 Autograd 的梯度传导,两者参数功能相反,需要转写。 |
| check_invariants | - | 是否检查稀疏 Tensor 变量,Paddle 无此参数,一般对网络训练结果影响不大,可直接删除。 |

### 转写示例
Expand Down

0 comments on commit 5f90eef

Please sign in to comment.