diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc old mode 100755 new mode 100644 index c972fde2dec80..0b309506a88c3 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc @@ -528,6 +528,15 @@ bool DotOpInferSymbolicShape(pir::Operation *op, return true; } +bool DistOpInferSymbolicShape(pir::Operation *op, + pir::InferSymbolicShapeContext *infer_context) { + infer_context->SetShapeOrDataForValue( + op->result(0), + symbol::ShapeOrDataDimExprs{ + symbol::TensorShapeOrDataDimExprs(std::vector{})}); + return true; +} + bool DropoutOpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { const symbol::ShapeOrDataDimExprs &x_shape_or_data = diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h index 6795319c966f8..3d6596fbe384b 100755 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h @@ -36,6 +36,8 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Conv3d) // OP_DECLARE_INFER_SYMBOLIC_SHAPE(ConvTranspose) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Cross) // OP_DECLARE_INFER_SYMBOLIC_SHAPE(Correlation) +// OP_DECLARE_INFER_SYMBOLIC_SHAPE(DepthwiseConv) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Dist) OP_DECLARE_INFER_SYMBOLIC_SHAPE(DepthwiseConv2d) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Dot) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Dropout) diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index 80c8c039bb2a5..0821946bda918 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -1451,6 +1451,7 @@ kernel : func : dist backward : dist_grad + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : dot args : (Tensor x, Tensor y)