Skip to content

Commit

Permalink
[CINN] Refine fully_insert_broadcast_pass (#60676)
Browse files Browse the repository at this point in the history
* refine fully_insert_broadcast_pass

* fix complie bug

* fix complie

* fix conflict
  • Loading branch information
zyfncg authored Jan 12, 2024
1 parent d604bcd commit e328cf7
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 56 deletions.
6 changes: 2 additions & 4 deletions paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@ if(NOT CINN_ONLY)
cinn_runtime_dialect
pir_compiler)

cc_library(
cinn_transforms
SRCS ${cinn_transforms_srcs}
DEPS ${cinn_transforms_deps})
cinn_cc_library(cinn_transforms SRCS ${cinn_transforms_srcs} DEPS
${cinn_transforms_deps})

endif()
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/hlir/dialect/operator/transforms/fully_insert_broadcast_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.h"

#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h"
#include "paddle/cinn/hlir/framework/pir/utils.h"
Expand Down Expand Up @@ -50,6 +50,14 @@ bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) {
}
pir::Value x = op->operand_source(0);
pir::Value y = op->operand_source(1);
pir::ShapeConstraintIRAnalysis& shape_analysis =
pir::ShapeAnalysisManager::Instance().Get(op->GetParentProgram());
const auto& x_shape = shape_analysis.GetShapeOrDataForValue(x);
const auto& y_shape = shape_analysis.GetShapeOrDataForValue(y);
if (x_shape.shape() == y_shape.shape() && x_shape.data() == y_shape.data()) {
return false;
}

pir::Value output_dim_tensor = GetOutputDimTensor(rewriter, x, y);
{
pir::Value broadcasted_x =
Expand All @@ -67,7 +75,7 @@ bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) {
} // namespace

template <typename OPTYPE>
class FullyInsertBroadcastPattern : public pir::OpRewritePattern<OPTYPE> {
class InsertBroadcastPattern : public pir::OpRewritePattern<OPTYPE> {
public:
using pir::OpRewritePattern<OPTYPE>::OpRewritePattern;

Expand All @@ -77,42 +85,46 @@ class FullyInsertBroadcastPattern : public pir::OpRewritePattern<OPTYPE> {
}
};

FullyInsertBroadcastPass::FullyInsertBroadcastPass()
: pir::PatternRewritePass("fully_insert_broadcast_pass", 1) {}

pir::RewritePatternSet FullyInsertBroadcastPass::InitializePatterns(
pir::IrContext* context) {
pir::RewritePatternSet ps(context);
// elementwise ops
ps.Add<FullyInsertBroadcastPattern<paddle::dialect::AddOp>>(context);
ps.Add<FullyInsertBroadcastPattern<paddle::dialect::SubtractOp>>(context);
ps.Add<FullyInsertBroadcastPattern<paddle::dialect::MultiplyOp>>(context);
ps.Add<FullyInsertBroadcastPattern<paddle::dialect::DivideOp>>(context);
ps.Add<FullyInsertBroadcastPattern<paddle::dialect::ElementwisePowOp>>(
context);
ps.Add<FullyInsertBroadcastPattern<paddle::dialect::RemainderOp>>(context);
ps.Add<FullyInsertBroadcastPattern<paddle::dialect::FloorDivideOp>>(context);
ps.Add<FullyInsertBroadcastPattern<paddle::dialect::MaximumOp>>(context);
ps.Add<FullyInsertBroadcastPattern<paddle::dialect::MinimumOp>>(context);

// compare ops
ps.Add<FullyInsertBroadcastPattern<paddle::dialect::LessThanOp>>(context);
ps.Add<FullyInsertBroadcastPattern<paddle::dialect::LessEqualOp>>(context);
ps.Add<FullyInsertBroadcastPattern<paddle::dialect::EqualOp>>(context);
ps.Add<FullyInsertBroadcastPattern<paddle::dialect::NotEqualOp>>(context);
ps.Add<FullyInsertBroadcastPattern<paddle::dialect::GreaterThanOp>>(context);
ps.Add<FullyInsertBroadcastPattern<paddle::dialect::GreaterEqualOp>>(context);

// bitwise ops
ps.Add<FullyInsertBroadcastPattern<paddle::dialect::BitwiseOrOp>>(context);
ps.Add<FullyInsertBroadcastPattern<paddle::dialect::BitwiseXorOp>>(context);
ps.Add<FullyInsertBroadcastPattern<paddle::dialect::BitwiseNotOp>>(context);

return ps;
}
class InsertBroadcastPass : public pir::PatternRewritePass {
public:
InsertBroadcastPass() : pir::PatternRewritePass("insert_broadcast_pass", 1) {}

pir::RewritePatternSet InitializePatterns(pir::IrContext* context) override {
pir::RewritePatternSet ps(context);
// elementwise ops
ps.Add<InsertBroadcastPattern<paddle::dialect::AddOp>>(context);
ps.Add<InsertBroadcastPattern<paddle::dialect::SubtractOp>>(context);
ps.Add<InsertBroadcastPattern<paddle::dialect::MultiplyOp>>(context);
ps.Add<InsertBroadcastPattern<paddle::dialect::DivideOp>>(context);
ps.Add<InsertBroadcastPattern<paddle::dialect::ElementwisePowOp>>(context);
ps.Add<InsertBroadcastPattern<paddle::dialect::RemainderOp>>(context);
ps.Add<InsertBroadcastPattern<paddle::dialect::FloorDivideOp>>(context);
ps.Add<InsertBroadcastPattern<paddle::dialect::MaximumOp>>(context);
ps.Add<InsertBroadcastPattern<paddle::dialect::MinimumOp>>(context);

// compare ops
ps.Add<InsertBroadcastPattern<paddle::dialect::LessThanOp>>(context);
ps.Add<InsertBroadcastPattern<paddle::dialect::LessEqualOp>>(context);
ps.Add<InsertBroadcastPattern<paddle::dialect::EqualOp>>(context);
ps.Add<InsertBroadcastPattern<paddle::dialect::NotEqualOp>>(context);
ps.Add<InsertBroadcastPattern<paddle::dialect::GreaterThanOp>>(context);
ps.Add<InsertBroadcastPattern<paddle::dialect::GreaterEqualOp>>(context);

// bitwise ops
ps.Add<InsertBroadcastPattern<paddle::dialect::BitwiseOrOp>>(context);
ps.Add<InsertBroadcastPattern<paddle::dialect::BitwiseXorOp>>(context);
ps.Add<InsertBroadcastPattern<paddle::dialect::BitwiseNotOp>>(context);

return ps;
}

bool CanApplyOn(pir::Operation* op) const override {
return op->isa<pir::ModuleOp>() && op->num_regions() > 0;
}
};

bool FullyInsertBroadcastPass::CanApplyOn(pir::Operation* op) const {
return op->isa<pir::ModuleOp>() && op->num_regions() > 0;
std::unique_ptr<pir::Pass> CreateInsertBroadcastPass() {
return std::make_unique<InsertBroadcastPass>();
}

} // namespace ir
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,12 @@
#pragma once

#include "paddle/pir/pass/pass.h"
#include "paddle/pir/pattern_rewrite/frozen_rewrite_pattern_set.h"

namespace cinn {
namespace dialect {
namespace ir {

class FullyInsertBroadcastPass : public pir::PatternRewritePass {
public:
FullyInsertBroadcastPass();

pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override;

bool CanApplyOn(pir::Operation *op) const override;
};
IR_API std::unique_ptr<pir::Pass> CreateInsertBroadcastPass();

} // namespace ir
} // namespace dialect
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,11 @@ bool InferSymbolicShapeElementWiseBinary(
std::vector<symbol::DimExpr> shapes;
symbol::DimExprBuilder builder{nullptr};
for (size_t i = 0; i < shape_0.size(); i++) {
shapes.emplace_back(builder.Broadcast(shape_0[i], shape_1[i]));
if (shape_0[i] == shape_1[i]) {
shapes.emplace_back(shape_0[i]);
} else {
shapes.emplace_back(builder.Broadcast(shape_0[i], shape_1[i]));
}
}

// TODO(lanxianghit): fill data when the operation is on shape computation
Expand Down
3 changes: 1 addition & 2 deletions paddle/fluid/pir/transforms/shape_optimization_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,8 @@ void DebugPrintOpInfo(
}

void InferSymExprForAllValues(ModuleOp module_op) {
auto shape_analysis_mgr = ShapeAnalysisManager::Instance();
ShapeConstraintIRAnalysis& shape_analysis =
shape_analysis_mgr.Get(module_op.program());
ShapeAnalysisManager::Instance().Get(module_op.program());
for (uint32_t i = 0; i < module_op->num_regions(); i++) {
for (auto& block : module_op->region(i)) {
for (auto& op : block) {
Expand Down
4 changes: 4 additions & 0 deletions paddle/pir/dialect/shape/utils/shape_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,10 @@ class IR_API ShapeAnalysisManager {
static ShapeAnalysisManager& Instance();
ShapeConstraintIRAnalysis& Get(pir::Program* program);

ShapeAnalysisManager(const ShapeAnalysisManager&) = delete;
ShapeAnalysisManager(ShapeAnalysisManager&&) = delete;
ShapeAnalysisManager& operator=(const ShapeAnalysisManager&) = delete;

private:
ShapeAnalysisManager() {}
std::unordered_map<uint64_t, ShapeConstraintIRAnalysis> tables_;
Expand Down
7 changes: 4 additions & 3 deletions paddle/pir/pass/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@ namespace pir {
namespace detail {

void PrintHeader(const std::string &header, std::ostream &os) {
unsigned padding = (80 - header.size()) / 2;
os << "===" << std::string(73, '-') << "===\n";
const size_t padding = 8;
size_t line_len = header.size() + ((padding - 3) * 2);
os << "===" << std::string(line_len, '-') << "===\n";
os << std::string(padding, ' ') << header << "\n";
os << "===" << std::string(73, '-') << "===\n";
os << "===" << std::string(line_len, '-') << "===\n";
}

} // namespace detail
Expand Down

0 comments on commit e328cf7

Please sign in to comment.