[PIR][DynamicShape] Add an example for broadcast in dynamic shape infer (#60608)

* Add an example for broadcast in dynamic shape infer
lanxianghit authored Jan 8, 2024
1 parent 311c0ea commit 21f0c78
Showing 5 changed files with 112 additions and 13 deletions.
@@ -51,6 +51,44 @@ bool InferSymbolicShapeAllEqualBinary(
  return true;
}

bool InferSymbolicShapeElementWiseBinary(
    pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
  pir::Value operand_source_0 = op->operand_source(0);
  std::string operand_source_0_id = pir::GetValueId(&operand_source_0);
  std::vector<symbol::DimExpr> shape_0{
      shape_analysis->value_id_to_shapeordata_[operand_source_0_id].shape()};

  pir::Value operand_source_1 = op->operand_source(1);
  std::string operand_source_1_id = pir::GetValueId(&operand_source_1);
  std::vector<symbol::DimExpr> shape_1{
      shape_analysis->value_id_to_shapeordata_[operand_source_1_id].shape()};

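  // Align ranks: left-pad the shorter shape with 1s (numpy-style broadcasting).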
  if (shape_0.size() > shape_1.size()) {
    for (size_t i = 0; i < shape_0.size() - shape_1.size(); i++) {
      shape_1.emplace(shape_1.begin(), 1);
    }
  } else {
    for (size_t i = 0; i < shape_1.size() - shape_0.size(); i++) {
      shape_0.emplace(shape_0.begin(), 1);
    }
  }

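  // Broadcast each aligned pair of dimensions into a symbolic DimExpr.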
  std::vector<symbol::DimExpr> shapes;
  symbol::DimExprBuilder builder{nullptr};
  for (size_t i = 0; i < shape_0.size(); i++) {
    shapes.emplace_back(builder.Broadcast(shape_0[i], shape_1[i]));
  }

  // TODO(lanxianghit): fill data when the operation is on shape computation
  std::vector<symbol::DimExpr> data;

  pir::OpResult res = op->result(0);
  std::string res_id = pir::GetValueId(&res);
  symbol::ShapeOrDataDimExprs shape_data{shapes, data};
  shape_analysis->value_id_to_shapeordata_[res_id] = shape_data;
  return true;
}

} // namespace

bool AbsOpInferSymbolicShape(pir::Operation *op,
@@ -63,6 +63,16 @@ bool Abs_OpInferSymbolicShape(pir::Operation *op,
  return InferSymbolicShapeAllEqualUnary(op, shape_analysis);
}

bool AddOpInferSymbolicShape(pir::Operation *op,
                             pir::ShapeConstraintIRAnalysis *shape_analysis) {
  return InferSymbolicShapeElementWiseBinary(op, shape_analysis);
}

bool Add_OpInferSymbolicShape(pir::Operation *op,
                              pir::ShapeConstraintIRAnalysis *shape_analysis) {
  return InferSymbolicShapeElementWiseBinary(op, shape_analysis);
}

bool CastOpInferSymbolicShape(pir::Operation *op,
                              pir::ShapeConstraintIRAnalysis *shape_analysis) {
  return InferSymbolicShapeAllEqualUnary(op, shape_analysis);
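The new InferSymbolicShapeElementWiseBinary helper above applies the numpy-style broadcasting rule to symbolic dimensions: align ranks by left-padding the shorter shape with 1s, then broadcast each dimension pair. A minimal Python sketch of that rule, with plain ints and strings standing in for symbol::DimExpr (the helper names here are illustrative only, not Paddle API):

def broadcast_dim(a, b):
    # Mirrors DimExprBuilder::Broadcast in spirit: a dim of 1 broadcasts to
    # the other dim, equal dims pass through, anything else stays symbolic.
    if a == 1:
        return b
    if b == 1 or a == b:
        return a
    return f"Broadcast({a}, {b})"


def infer_elementwise_binary(shape_0, shape_1):
    # Rank-align by left-padding the shorter shape with 1s, then broadcast
    # the dims pairwise -- the same two steps as the C++ helper.
    rank = max(len(shape_0), len(shape_1))
    shape_0 = [1] * (rank - len(shape_0)) + list(shape_0)
    shape_1 = [1] * (rank - len(shape_1)) + list(shape_1)
    return [broadcast_dim(a, b) for a, b in zip(shape_0, shape_1)]


# [2, 4, 1] vs ["S0", 5]: pad to [1, "S0", 5], then broadcast pairwise.
print(infer_elementwise_binary([2, 4, 1], ["S0", 5]))
# -> [2, 'Broadcast(4, S0)', 5]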
@@ -68,6 +68,12 @@ bool AbsOpInferSymbolicShape(pir::Operation *op,
bool Abs_OpInferSymbolicShape(pir::Operation *op,
                              pir::ShapeConstraintIRAnalysis *shape_analysis);

bool AddOpInferSymbolicShape(pir::Operation *op,
                             pir::ShapeConstraintIRAnalysis *shape_analysis);

bool Add_OpInferSymbolicShape(pir::Operation *op,
                              pir::ShapeConstraintIRAnalysis *shape_analysis);

bool CastOpInferSymbolicShape(pir::Operation *op,
                              pir::ShapeConstraintIRAnalysis *shape_analysis);

1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops.yaml
@@ -24,6 +24,7 @@
  func : add
  inplace : (x -> out)
  backward : add_grad
  interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : add_n
  args : (Tensor[] inputs)
19 changes: 11 additions & 8 deletions paddle/fluid/pir/transforms/shape_optimization_pass.cc
@@ -365,14 +365,17 @@ void DebugPrintOpInfo(
    auto shape_data = shape_analysis->value_id_to_shapeordata_[value_id];
    print_stream << ", ShapeOrData.shape: [";

    for (auto str : shape_data.shape()) {
      int64_t* i = std::get_if<int64_t>(&str);
      std::string* s = std::get_if<std::string>(&str);
      if (i) {
        print_stream << *i << ", ";
      } else if (s) {
        print_stream << *s << ", ";
      }
    // for (auto str : shape_data.shape()) {
    //   int64_t* i = std::get_if<int64_t>(&str);
    //   std::string* s = std::get_if<std::string>(&str);
    //   if (i) {
    //     print_stream << *i << ", ";
    //   } else if (s) {
    //     print_stream << *s << ", ";
    //   }
    // }
    for (auto dim : shape_data.shape()) {
      print_stream << dim << ", ";
    }

    print_stream << "], ShapeOrData.data: [";
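The DebugPrintOpInfo change drops the hand-written int64_t/std::string variant unpacking because symbol::DimExpr can be streamed directly. A rough Python analogy of that design move (the DimExpr class below is a hypothetical stand-in, not Paddle API):

class DimExpr:
    # Stand-in for symbol::DimExpr: holds either a constant or a symbol name.
    def __init__(self, value):
        self.value = value  # int or str

    def __str__(self):
        # One printing path for both cases, like DimExpr's operator<<.
        return str(self.value)


shape = [DimExpr(2), DimExpr("S0"), DimExpr(5)]
print(", ".join(str(d) for d in shape))  # 2, S0, 5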
51 changes: 46 additions & 5 deletions test/ir/pir/cinn/symbolic/test_cinn_sub_graph_symbolic.py
@@ -47,6 +47,11 @@ def reshape(x):
    return out


def broadcast(x, y):
    z = x + y
    return z


class CINNSubGraphNet(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
@@ -67,6 +72,16 @@ def forward(self, x):
        return out


class CINNBroadcastSubGraphNet(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.fn = broadcast

    def forward(self, x, y):
        out = self.fn(x, y)
        return out


class TestCinnSubGraphBase(unittest.TestCase):
    """
    Test Pir API + @to_static + CINN.
@@ -117,12 +132,38 @@ def test_eval_symolic(self):
        import os

        is_debug = os.getenv('IS_DEBUG_DY_SHAPE')
        if is_debug:
            cinn_out = self.eval_symbolic(use_cinn=True)
            # print("cinn_out:", cinn_out)
        # if is_debug:
        #     cinn_out = self.eval_symbolic(use_cinn=True)

        dy_out = self.eval_symbolic(use_cinn=False)
        # np.testing.assert_allclose(cinn_out.numpy(), dy_out.numpy(), atol=1e-8)


        # dy_out = self.eval_symbolic(use_cinn=False)
        # print("dy_out:", dy_out)
class TestCinnDyShapeBC(TestCinnDyShapeBase):
    def prepare_data(self):
        self.x_shape = [2, 4, 1]
        self.x = paddle.randn(self.x_shape, dtype="float32")
        self.x.stop_gradient = False

        self.y_shape = [4, 5]
        self.y = paddle.randn(self.y_shape, dtype="float32")
        self.y.stop_gradient = False

    def eval_symbolic(self, use_cinn):
        paddle.seed(2022)
        net = CINNBroadcastSubGraphNet()
        input_spec = [
            InputSpec(shape=[None, None, None], dtype='float32'),
            InputSpec(shape=[None, None], dtype='float32'),
        ]
        net = apply_to_static(net, use_cinn, input_spec)
        net.eval()
        out = net(self.x, self.y)
        return out

    def test_eval_symolic(self):
        # cinn_out = self.eval_symbolic(use_cinn=True)
        dy_out = self.eval_symbolic(use_cinn=False)
        # np.testing.assert_allclose(cinn_out.numpy(), dy_out.numpy(), atol=1e-8)


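As a sanity check on the shapes used in TestCinnDyShapeBC, plain numpy broadcasting produces the result shape the symbolic inference is expected to report (numpy is used here only for illustration):

import numpy as np

x = np.random.randn(2, 4, 1).astype("float32")
y = np.random.randn(4, 5).astype("float32")

# [2, 4, 1] + [4, 5]: [4, 5] is left-padded to [1, 4, 5],
# then per-dim broadcast gives [2, 4, 5].
assert (x + y).shape == (2, 4, 5)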
