diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc index a82fa1f771ff0..8a74027aa9eb7 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc @@ -2840,11 +2840,60 @@ bool WarprnntOpInferSymbolicShape( // return true; // } -// bool WeightedSampleNeighborsOpInferSymbolicShape( -// pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { -// // pass -// return true; -// } +bool WeightedSampleNeighborsOpInferSymbolicShape( + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + auto GSNShapeCheck = [](const ExprVec &input_shape, + std::string tensor_name, + pir::InferSymbolicShapeContext *infer_context) { + if (input_shape.size() == 2) { + infer_context->AddEqualCstr(input_shape[1], symbol::DimExpr(1)); + } else { + PADDLE_ENFORCE_EQ( + input_shape.size(), + 1, + phi::errors::InvalidArgument( + "The %s should be 1D, when it is not 2D, but we get %d", + tensor_name, + input_shape.size())); + } + }; + + const auto &row_shape = + infer_context->GetShapeOrDataForValue(op->operand_source(0)).shape(); + const auto &col_ptr_shape = + infer_context->GetShapeOrDataForValue(op->operand_source(1)).shape(); + const auto &edge_weight_shape = + infer_context->GetShapeOrDataForValue(op->operand_source(2)).shape(); + const auto &x_shape = + infer_context->GetShapeOrDataForValue(op->operand_source(3)).shape(); + const auto &eids_shape = + infer_context->GetShapeOrDataForValue(op->operand_source(4)).shape(); + bool return_eids = op->attribute("return_eids").data(); + + GSNShapeCheck(row_shape, "row", infer_context); + GSNShapeCheck(col_ptr_shape, "col_ptr", infer_context); + GSNShapeCheck(edge_weight_shape, "edge_weight", infer_context); + GSNShapeCheck(x_shape, "input_nodes", infer_context); + + infer_context->SetShapeOrDataForValue( + op->result(0), + symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs( + {infer_context->GetNextSymName()})}); + infer_context->SetShapeOrDataForValue( + op->result(1), + symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs( + {infer_context->GetNextSymName()})}); + if (return_eids) { + GSNShapeCheck(eids_shape, "eids", infer_context); + infer_context->SetShapeOrDataForValue( + op->result(2), + symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs( + {infer_context->GetNextSymName()})}); + } else { + infer_context->SetSymbolForValueByStaticShape(op->result(2)); + } + return true; +} bool WhereOpInferSymbolicShape(pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h index 633ce6bfcf556..32eb6ae5b57e8 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h @@ -119,7 +119,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(ViterbiDecode) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Warpctc) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Warprnnt) // OP_DECLARE_INFER_SYMBOLIC_SHAPE(WeightOnlyLinear) -// OP_DECLARE_INFER_SYMBOLIC_SHAPE(WeightedSampleNeighbors) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(WeightedSampleNeighbors) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Where) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Where_) OP_DECLARE_INFER_SYMBOLIC_SHAPE(YoloLoss) diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index 49e93ed77d3c9..e34c7e9f8064a 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -5123,7 +5123,7 @@ kernel : func : weighted_sample_neighbors optional : eids - # interfaces : paddle::dialect::InferSymbolicShapeInterface + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : where args : (Tensor condition, Tensor x, Tensor y)