Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#21 from gongshaotian/drr
Browse files Browse the repository at this point in the history
Add fusion testing code for fullOp and expandOp
  • Loading branch information
yuanlehome authored Aug 25, 2023
2 parents fed2f8b + da03ecb commit c251ca9
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 26 deletions.
15 changes: 13 additions & 2 deletions paddle/fluid/ir/drr/ir_operation_creator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ Operation* CreateOperation(const OpCall& op_call,
op_call.outputs()[1]->name(),
std::make_shared<IrValue>(reshape_op->result(1)));
return reshape_op;

} else if (op_call.name() == "pd.transpose") {
const auto& inputs = op_call.inputs();
std::vector<Value> ir_values =
Expand All @@ -97,6 +98,7 @@ Operation* CreateOperation(const OpCall& op_call,
op_call.outputs()[0]->name(),
std::make_shared<IrValue>(transpose_op->result(0)));
return transpose_op;

} else if (op_call.name() == "pd.cast") {
const auto& inputs = op_call.inputs();
std::vector<Value> ir_values =
Expand All @@ -107,10 +109,19 @@ Operation* CreateOperation(const OpCall& op_call,
res_match_ctx->BindIrValue(op_call.outputs()[0]->name(),
std::make_shared<IrValue>(cast_op->result(0)));
return cast_op;

} else if (op_call.name() == "pd.full") {
const auto& inputs = op_call.inputs();
std::vector<Value> ir_values =
GetIrValuesByDrrTensors(inputs, *res_match_ctx);
Operation* full_op = rewriter.Build<paddle::dialect::FullOp>(
CreateAttributeMap(op_call, src_match_ctx));
res_match_ctx->BindIrValue(op_call.outputs()[0]->name(),
std::make_shared<IrValue>(full_op->result(0)));
return full_op;
}

LOG(ERROR) << "Unknown op " << op_call.name();
return nullptr;
PADDLE_THROW("Unknown op :" + op_call.name());
}

} // namespace drr
Expand Down
21 changes: 21 additions & 0 deletions paddle/fluid/ir/drr/match_context_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ PD_SPECIALIZE_CppTypeToIrAttribute(int64_t, Int64Attribute);
PD_SPECIALIZE_CppTypeToIrAttribute(float, FloatAttribute);
PD_SPECIALIZE_CppTypeToIrAttribute(phi::DataType,
paddle::dialect::DataTypeAttribute);
PD_SPECIALIZE_CppTypeToIrAttribute(phi::Place, paddle::dialect::PlaceAttribute);

template <typename T>
struct IrAttrTypeCast {
Expand All @@ -62,6 +63,26 @@ struct IrAttrTypeCast<std::vector<int32_t>> {
}
};

template <>
struct IrAttrTypeCast<std::vector<int64_t>> {
static std::vector<int64_t> To(const ir::Attribute& attr) {
std::vector<int64_t> result;
if (attr.dyn_cast<ir::ArrayAttribute>()) {
auto array_attr = attr.dyn_cast<ir::ArrayAttribute>();
for (size_t i = 0; i < array_attr.size(); i++) {
result.push_back(
array_attr.at(i).dyn_cast<ir::Int64Attribute>().data());
}
return result;
} else if (attr.dyn_cast<paddle::dialect::IntArrayAttribute>()) {
result =
attr.dyn_cast<paddle::dialect::IntArrayAttribute>().data().GetData();
return result;
}
PADDLE_THROW("Dynamic cast failed for IR attribute vector<int64_t>");
}
};

class MatchContextImpl final {
public:
MatchContextImpl() = default;
Expand Down
61 changes: 37 additions & 24 deletions test/cpp/ir/pattern_rewrite/drr_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,31 +45,33 @@ class RemoveRedundentReshapePattern
}
};

class FoldBroadcastToConstantPattern
: public ir::drr::DrrPatternBase<FoldBroadcastToConstantPattern> {
class FoldExpandToConstantPattern
: public ir::drr::DrrPatternBase<FoldExpandToConstantPattern> {
public:
void operator()(ir::drr::DrrPatternContext *ctx) const override {
ir::drr::SourcePattern pat = ctx->SourcePattern();
// Source Pattern 中可匹配的类型包括 Op 和 Tensor
const auto &fill_constant = pat.Op(
"pd.fill_constant",
{{"value", pat.Attr("value_1")}, {"dtype", pat.Attr("dtype_1")}});
const auto &broadcast_to =
pat.Op("pd.expand", {{"shape", pat.Attr("shape_1")}});
// 匹配fill_constant+broadcast_to,同时对输出张量标记为ret,方便后面加约束
pat.Tensor("ret") = broadcast_to(fill_constant());
// Constrains:本Pass无额外的约束规则
// Result patterns:要替换为的子图
ir::drr::SourcePattern pat = ctx->SourcePattern();
const auto &full1 = pat.Op("pd.full",
{{"shape", pat.Attr("shape_1")},
{"value", pat.Attr("value_1")},
{"dtype", pat.Attr("dtype_1")},
{"place", pat.Attr("place_1")}});
const auto &full_int_array1 =
pat.Op("pd.full_int_array",
{{"value", pat.Attr("expand_shape_value")},
{"dtype", pat.Attr("dtype_2")},
{"place", pat.Attr("place_2")}});
const auto &expand = pat.Op("pd.expand");
pat.Tensor("ret") = expand(full1(), full_int_array1());

// Result patterns:要替换为的子图. Constrains: 本Pass无额外约束规则
ir::drr::ResultPattern res = pat.ResultPattern();
// 使用 folded_fill_constant 替换
// broadcast_to(fill_constant()),注意shape属性已更新 所有 ret
// 参数均在Source Pattern中使用,对 ret 的赋值等同于对 ret 的 producer
// op的删除和重连接
const auto &folded_fill_constant = res.Op("pd.fill_constant",
{{"shape", res.Attr("shape_1")},
{"value", res.Attr("value_1")},
{"dtype", res.Attr("dtype_1")}});
res.Tensor("ret") = folded_fill_constant();
const auto &full2 = res.Op("pd.full",
{{"shape", pat.Attr("expand_shape_value")},
{"value", pat.Attr("value_1")},
{"dtype", pat.Attr("dtype_1")},
{"place", pat.Attr("place_1")}});
res.Tensor("ret") = full2();
}
};

Expand Down Expand Up @@ -133,14 +135,24 @@ class RemoveUselessCastPattern

void BuildProgram(ir::Builder &builder) { // NOLINT
paddle::dialect::FullOp full_input_op =
builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{4, 3, 16, 16},
builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{4, 3, 16},
1.5,
phi::DataType::FLOAT32,
phi::CPUPlace());

paddle::dialect::FullIntArrayOp full_int_array_op =
builder.Build<paddle::dialect::FullIntArrayOp>(
std::vector<int64_t>{4, 3, 16, 16},
phi::DataType::FLOAT32,
phi::CPUPlace());

paddle::dialect::ExpandOp expand_op =
builder.Build<paddle::dialect::ExpandOp>(full_input_op.out(),
full_int_array_op.out());

paddle::dialect::ReshapeOp reshape_op1 =
builder.Build<paddle::dialect::ReshapeOp>(
full_input_op.out(), std::vector<int64_t>{16, 3, 4, 16});
expand_op.out(), std::vector<int64_t>{16, 3, 4, 16});

paddle::dialect::ReshapeOp reshape_op2 =
builder.Build<paddle::dialect::ReshapeOp>(
Expand Down Expand Up @@ -179,6 +191,7 @@ class DrrPatternRewritePass : public ir::Pass {
ps.Add(RemoveRedundentTransposePattern().Build(context));
ps.Add(RemoveRedundentCastPattern().Build(context));
ps.Add(RemoveUselessCastPattern().Build(context));
ps.Add(FoldExpandToConstantPattern().Build(context));

patterns_ = ir::FrozenRewritePatternSet(std::move(ps));
return true;
Expand Down Expand Up @@ -206,7 +219,7 @@ TEST(DrrTest, drr) {
ir::Builder builder = ir::Builder(ctx, program.block());
BuildProgram(builder);

EXPECT_EQ(program.block()->size(), 12u);
EXPECT_EQ(program.block()->size(), 14u);

ir::PassManager pm(ctx);
pm.AddPass(std::make_unique<DrrPatternRewritePass>());
Expand Down

0 comments on commit c251ca9

Please sign in to comment.