Skip to content

Commit

Permalink
[Sparse] Unified api args name (PaddlePaddle#47529)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangkaihuo authored Nov 3, 2022
1 parent 1a40465 commit f9a0605
Show file tree
Hide file tree
Showing 6 changed files with 8 additions and 9 deletions.
2 changes: 1 addition & 1 deletion paddle/phi/api/yaml/sparse_backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@
func : softmax_csr_grad{sparse_csr, sparse_csr -> sparse_csr}

- backward_op : sparse_coo_tensor_grad
forward : sparse_coo_tensor(Tensor values, Tensor indices, IntArray dense_shape) -> Tensor(out)
forward : sparse_coo_tensor(Tensor values, Tensor indices, int64_t[] shape) -> Tensor(out)
args : (Tensor indices, Tensor out_grad)
output : Tensor(values_grad)
infer_meta :
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/api/yaml/sparse_ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@
backward : softmax_grad

- op : sparse_coo_tensor
args : (Tensor values, Tensor indices, IntArray dense_shape)
args : (Tensor values, Tensor indices, int64_t[] shape={})
output : Tensor(out)
infer_meta :
func : sparse::SparseCooTensorInferMeta
Expand Down
4 changes: 2 additions & 2 deletions paddle/phi/infermeta/sparse/binary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,9 @@ void Pool3dInferMeta(const MetaTensor& x,

void SparseCooTensorInferMeta(const MetaTensor& values,
const MetaTensor& indices,
const IntArray& dense_shape,
const std::vector<int64_t>& shape,
MetaTensor* out) {
out->set_dims(phi::make_ddim(dense_shape.GetData()));
out->set_dims(phi::make_ddim(shape));
out->set_dtype(values.dtype());
out->set_layout(values.layout());
}
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/infermeta/sparse/binary.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ void Pool3dInferMeta(const MetaTensor& x,

void SparseCooTensorInferMeta(const MetaTensor& values,
const MetaTensor& indices,
const IntArray& dense_shape,
const std::vector<int64_t>& shape,
MetaTensor* out);

} // namespace sparse
Expand Down
5 changes: 2 additions & 3 deletions paddle/phi/kernels/sparse/sparse_utils_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,9 @@ template <typename T, typename Context>
void SparseCooTensorKernel(const Context& dev_ctx,
const DenseTensor& values,
const DenseTensor& indices,
const IntArray& dense_shape,
const std::vector<int64_t>& shape,
SparseCooTensor* out) {
*out =
SparseCooTensor(indices, values, phi::make_ddim(dense_shape.GetData()));
*out = SparseCooTensor(indices, values, phi::make_ddim(shape));
}

} // namespace sparse
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/sparse/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def sparse_coo_tensor(
inputs = {'values': values, 'indices': indices}
if shape[0] is None:
shape[0] = -1
attrs = {'dense_shape': shape}
attrs = {'shape': shape}
helper = LayerHelper(op_type)
out = helper.create_sparse_variable_for_type_inference(dtype)
helper.append_op(
Expand Down

0 comments on commit f9a0605

Please sign in to comment.