Skip to content

Commit

Permalink
【Infer Symbolic Shape No.137、143、159】【BUAA】Add elementwise ops (Paddl…
Browse files Browse the repository at this point in the history
…ePaddle#67120)

* Update element_wise_binary.cc

* Update element_wise_binary.h

* delete unnecesary placeholders

* delete unnecesary placeholders

* Update ops.yaml
  • Loading branch information
Guanhuachen2003 authored and Jeff114514 committed Aug 14, 2024
1 parent 005d2a1 commit d006b4d
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -998,27 +998,6 @@ bool TopPSamplingOpInferSymbolicShape(
// return true;
// }

// bool GammainccOpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext
// *infer_context) {
// // pass
// return true;
// }

// bool HeavisideOpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext
// *infer_context) {
// // pass
// return true;
// }

// bool NextafterOpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext
// *infer_context) {
// // pass
// return true;
// }

} // namespace paddle::dialect

namespace cinn::dialect {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,10 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(FillDiagonalTensor)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(FillDiagonalTensor_)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(FusedSoftmaxMask)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(GridSample)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Gammaincc)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Gather)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(GatherNd)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(HuberLoss)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Histogram)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Heaviside)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Isclose)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(AccuracyCheck)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(IndexSample)
Expand All @@ -56,7 +54,6 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Kron)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(MaskedSelect)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Matmul)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Mv)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Nextafter)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(PullBoxSparse)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(PullGpuPsSparse)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(PullSparseV2)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,10 +148,13 @@ OP_ELEMENT_WISE_BINARY(Equal)
OP_ELEMENT_WISE_BINARY(Equal_)
OP_ELEMENT_WISE_BINARY(Fmax)
OP_ELEMENT_WISE_BINARY(Fmin)
OP_ELEMENT_WISE_BINARY(Gammaincc)
OP_ELEMENT_WISE_BINARY(Gammaincc_)
OP_ELEMENT_WISE_BINARY(GreaterEqual)
OP_ELEMENT_WISE_BINARY(GreaterEqual_)
OP_ELEMENT_WISE_BINARY(GreaterThan)
OP_ELEMENT_WISE_BINARY(GreaterThan_)
OP_ELEMENT_WISE_BINARY(Heaviside)
OP_ELEMENT_WISE_BINARY(LessEqual)
OP_ELEMENT_WISE_BINARY(LessEqual_)
OP_ELEMENT_WISE_BINARY(LessThan)
Expand All @@ -167,6 +170,7 @@ OP_ELEMENT_WISE_BINARY(Minimum)
OP_ELEMENT_WISE_BINARY(MultiplySr)
OP_ELEMENT_WISE_BINARY(MultiplySr_)
OP_ELEMENT_WISE_BINARY(Multiply_)
OP_ELEMENT_WISE_BINARY(Nextafter)
OP_ELEMENT_WISE_BINARY(NotEqual)
OP_ELEMENT_WISE_BINARY(NotEqual_)
OP_ELEMENT_WISE_BINARY(Remainder)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,13 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Equal)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Equal_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Fmax)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Fmin)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Gammaincc)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Gammaincc_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(GreaterEqual)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(GreaterEqual_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(GreaterThan)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(GreaterThan_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Heaviside)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(LessEqual)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(LessEqual_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(LessThan)
Expand All @@ -53,6 +56,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Multiply)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(MultiplySr)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(MultiplySr_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Multiply_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Nextafter)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(NotEqual)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(NotEqual_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Remainder)
Expand Down
3 changes: 3 additions & 0 deletions paddle/phi/ops/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2142,6 +2142,7 @@
func : gammaincc
inplace: (x -> out)
backward : gammaincc_grad
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : gammaln
args : (Tensor x)
Expand Down Expand Up @@ -2354,6 +2355,7 @@
kernel :
func : heaviside
backward : heaviside_grad
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : hinge_loss
args: (Tensor logits, Tensor labels)
Expand Down Expand Up @@ -3347,6 +3349,7 @@
func : nextafter
data_type : x
traits : paddle::dialect::ForwardOnlyTrait
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : nll_loss
args : (Tensor input, Tensor label, Tensor weight, int64_t ignore_index = -100, str reduction = "mean")
Expand Down

0 comments on commit d006b4d

Please sign in to comment.