From 21f0c78a407bae1a5232c5db40a7d5427aae92f3 Mon Sep 17 00:00:00 2001
From: lanxianghit <47554610+lanxianghit@users.noreply.github.com>
Date: Mon, 8 Jan 2024 20:20:19 +0800
Subject: [PATCH] [PIR][DynamicShape] Add an example for broadcast in dynamic shape infer (#60608)

* Add an example for broadcast in dynamic shape infer
---
 .../interface/infer_symbolic_shape.cc         | 48 +++++++++++++++++
 .../operator/interface/infer_symbolic_shape.h |  6 +++
 paddle/fluid/pir/dialect/operator/ir/ops.yaml |  1 +
 .../pir/transforms/shape_optimization_pass.cc | 19 ++++---
 .../symbolic/test_cinn_sub_graph_symbolic.py  | 51 +++++++++++++++++--
 5 files changed, 112 insertions(+), 13 deletions(-)

diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc
index 79c8e703e1184..99ab424279a7d 100644
--- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc
+++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc
@@ -51,6 +51,44 @@ bool InferSymbolicShapeAllEqualBinary(
   return true;
 }
 
+bool InferSymbolicShapeElementWiseBinary(
+    pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
+  pir::Value operand_source_0 = op->operand_source(0);
+  std::string operand_source_0_id = pir::GetValueId(&operand_source_0);
+  std::vector<symbol::DimExpr> shape_0{
+      shape_analysis->value_id_to_shapeordata_[operand_source_0_id].shape()};
+
+  pir::Value operand_source_1 = op->operand_source(1);
+  std::string operand_source_1_id = pir::GetValueId(&operand_source_1);
+  std::vector<symbol::DimExpr> shape_1{
+      shape_analysis->value_id_to_shapeordata_[operand_source_1_id].shape()};
+
+  if (shape_0.size() > shape_1.size()) {
+    for (size_t i = 0; i < shape_0.size() - shape_1.size(); i++) {
+      shape_1.emplace(shape_1.begin(), 1);
+    }
+  } else {
+    for (size_t i = 0; i < shape_1.size() - shape_0.size(); i++) {
+      shape_0.emplace(shape_0.begin(), 1);
+    }
+  }
+
+  std::vector<symbol::DimExpr> shapes;
+  symbol::DimExprBuilder builder{nullptr};
+  for (size_t i = 0; i < shape_0.size(); i++) {
+    shapes.emplace_back(builder.Broadcast(shape_0[i], shape_1[i]));
+  }
+
+  // TODO(lanxianghit): fill data when the operation is on shape computation
+  std::vector<symbol::DimExpr> data;
+
+  pir::OpResult res = op->result(0);
+  std::string res_id = pir::GetValueId(&res);
+  symbol::ShapeOrDataDimExprs shape_data{shapes, data};
+  shape_analysis->value_id_to_shapeordata_[res_id] = shape_data;
+  return true;
+}
+
 }  // namespace
 
 bool AbsOpInferSymbolicShape(pir::Operation *op,
@@ -63,6 +101,16 @@ bool Abs_OpInferSymbolicShape(pir::Operation *op,
   return InferSymbolicShapeAllEqualUnary(op, shape_analysis);
 }
 
+bool AddOpInferSymbolicShape(pir::Operation *op,
+                             pir::ShapeConstraintIRAnalysis *shape_analysis) {
+  return InferSymbolicShapeElementWiseBinary(op, shape_analysis);
+}
+
+bool Add_OpInferSymbolicShape(pir::Operation *op,
+                              pir::ShapeConstraintIRAnalysis *shape_analysis) {
+  return InferSymbolicShapeElementWiseBinary(op, shape_analysis);
+}
+
 bool CastOpInferSymbolicShape(pir::Operation *op,
                               pir::ShapeConstraintIRAnalysis *shape_analysis) {
   return InferSymbolicShapeAllEqualUnary(op, shape_analysis);
diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h
index fc96df40596af..4e7c70fd386ef 100644
--- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h
+++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h
@@ -68,6 +68,12 @@ bool AbsOpInferSymbolicShape(pir::Operation *op,
 bool Abs_OpInferSymbolicShape(pir::Operation *op,
                               pir::ShapeConstraintIRAnalysis *shape_analysis);
 
+bool AddOpInferSymbolicShape(pir::Operation *op,
+                             pir::ShapeConstraintIRAnalysis *shape_analysis);
+
+bool Add_OpInferSymbolicShape(pir::Operation *op,
+                              pir::ShapeConstraintIRAnalysis *shape_analysis);
+
 bool CastOpInferSymbolicShape(pir::Operation *op,
                               pir::ShapeConstraintIRAnalysis *shape_analysis);
 
diff --git a/paddle/fluid/pir/dialect/operator/ir/ops.yaml b/paddle/fluid/pir/dialect/operator/ir/ops.yaml
index 97fa1a6879e0a..d9c5b1c873609 100644
--- a/paddle/fluid/pir/dialect/operator/ir/ops.yaml
+++ b/paddle/fluid/pir/dialect/operator/ir/ops.yaml
@@ -24,6 +24,7 @@
   func : add
   inplace : (x -> out)
   backward : add_grad
+  interfaces : paddle::dialect::InferSymbolicShapeInterface
 
 - op : add_n
   args : (Tensor[] inputs)
diff --git a/paddle/fluid/pir/transforms/shape_optimization_pass.cc b/paddle/fluid/pir/transforms/shape_optimization_pass.cc
index 1ad2700684186..b42584fcf0953 100644
--- a/paddle/fluid/pir/transforms/shape_optimization_pass.cc
+++ b/paddle/fluid/pir/transforms/shape_optimization_pass.cc
@@ -365,14 +365,17 @@ void DebugPrintOpInfo(
     auto shape_data = shape_analysis->value_id_to_shapeordata_[value_id];
     print_stream << ", ShapeOrData.shape: [";
 
-    for (auto str : shape_data.shape()) {
-      int64_t* i = std::get_if<int64_t>(&str);
-      std::string* s = std::get_if<std::string>(&str);
-      if (i) {
-        print_stream << *i << ", ";
-      } else if (s) {
-        print_stream << *s << ", ";
-      }
+    // for (auto str : shape_data.shape()) {
+    //   int64_t* i = std::get_if<int64_t>(&str);
+    //   std::string* s = std::get_if<std::string>(&str);
+    //   if (i) {
+    //     print_stream << *i << ", ";
+    //   } else if (s) {
+    //     print_stream << *s << ", ";
+    //   }
+    // }
+
+    for (auto dim : shape_data.shape()) {
+      print_stream << dim << ", ";
     }
 
     print_stream << "], ShapeOrData.data: [";
diff --git a/test/ir/pir/cinn/symbolic/test_cinn_sub_graph_symbolic.py b/test/ir/pir/cinn/symbolic/test_cinn_sub_graph_symbolic.py
index 98968a18e228f..5eb9e6af6da85 100644
--- a/test/ir/pir/cinn/symbolic/test_cinn_sub_graph_symbolic.py
+++ b/test/ir/pir/cinn/symbolic/test_cinn_sub_graph_symbolic.py
@@ -47,6 +47,11 @@ def reshape(x):
     return out
 
 
+def broadcast(x, y):
+    z = x + y
+    return z
+
+
 class CINNSubGraphNet(paddle.nn.Layer):
     def __init__(self):
         super().__init__()
@@ -67,6 +72,16 @@ def forward(self, x):
         return out
 
 
+class CINNBroadcastSubGraphNet(paddle.nn.Layer):
+    def __init__(self):
+        super().__init__()
+        self.fn = broadcast
+
+    def forward(self, x, y):
+        out = self.fn(x, y)
+        return out
+
+
 class TestCinnSubGraphBase(unittest.TestCase):
     """
     Test Pir API + @to_static + CINN.
@@ -117,12 +132,38 @@ def test_eval_symolic(self):
         import os
 
         is_debug = os.getenv('IS_DEBUG_DY_SHAPE')
-        if is_debug:
-            cinn_out = self.eval_symbolic(use_cinn=True)
-            # print("cinn_out:", cinn_out)
+        # if is_debug:
+        #     cinn_out = self.eval_symbolic(use_cinn=True)
+
+        dy_out = self.eval_symbolic(use_cinn=False)
+        # np.testing.assert_allclose(cinn_out.numpy(), dy_out.numpy(), atol=1e-8)
+
 
-        # dy_out = self.eval_symbolic(use_cinn=False)
-        # print("dy_out:", dy_out)
+class TestCinnDyShapeBC(TestCinnDyShapeBase):
+    def prepare_data(self):
+        self.x_shape = [2, 4, 1]
+        self.x = paddle.randn(self.x_shape, dtype="float32")
+        self.x.stop_gradient = False
+
+        self.y_shape = [4, 5]
+        self.y = paddle.randn(self.y_shape, dtype="float32")
+        self.y.stop_gradient = False
+
+    def eval_symbolic(self, use_cinn):
+        paddle.seed(2022)
+        net = CINNBroadcastSubGraphNet()
+        input_spec = [
+            InputSpec(shape=[None, None, None], dtype='float32'),
+            InputSpec(shape=[None, None], dtype='float32'),
+        ]
+        net = apply_to_static(net, use_cinn, input_spec)
+        net.eval()
+        out = net(self.x, self.y)
+        return out
+
+    def test_eval_symolic(self):
+        # cinn_out = self.eval_symbolic(use_cinn=True)
+        dy_out = self.eval_symbolic(use_cinn=False)
+        # np.testing.assert_allclose(cinn_out.numpy(), dy_out.numpy(), atol=1e-8)
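
Note on the rule implemented by InferSymbolicShapeElementWiseBinary above: the shorter operand shape is padded with leading 1s until both ranks match, and each dimension pair is then combined with symbol::DimExprBuilder::Broadcast. A minimal, illustrative Python sketch of that rule follows; it is not part of the patch, the helper names (broadcast_dim, infer_elementwise_binary_shape) are made up for this example, and symbolic dimensions are modelled as plain strings rather than symbol::DimExpr:

# Illustrative only: mirrors the rank alignment + per-dimension broadcast
# performed by InferSymbolicShapeElementWiseBinary.
def broadcast_dim(a, b):
    # A literal 1 broadcasts against anything; equal dims pass through;
    # otherwise keep a symbolic Broadcast(a, b) marker, standing in for
    # symbol::DimExprBuilder::Broadcast.
    if a == 1:
        return b
    if b == 1 or a == b:
        return a
    return f"Broadcast({a}, {b})"


def infer_elementwise_binary_shape(shape_0, shape_1):
    # Pad the shorter shape with leading 1s, then broadcast dimension-wise.
    rank = max(len(shape_0), len(shape_1))
    shape_0 = [1] * (rank - len(shape_0)) + list(shape_0)
    shape_1 = [1] * (rank - len(shape_1)) + list(shape_1)
    return [broadcast_dim(a, b) for a, b in zip(shape_0, shape_1)]


# Roughly the TestCinnDyShapeBC case: x has symbolic shape [S0, S1, S2]
# (concrete [2, 4, 1]) and y has symbolic shape [S3, S4] (concrete [4, 5]).
print(infer_elementwise_binary_shape(["S0", "S1", "S2"], ["S3", "S4"]))
# -> ['S0', 'Broadcast(S1, S3)', 'Broadcast(S2, S4)']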