Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CINN] Refine fully_insert_broadcast_pass #60676

Merged
merged 7 commits
Jan 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@ if(NOT CINN_ONLY)
cinn_runtime_dialect
pir_compiler)

# Build the CINN transforms library via the CINN-specific wrapper so it
# inherits CINN-wide compile options and dependency handling instead of
# the plain cc_library defaults. (The scraped diff had left the old
# cc_library() call fused with this one, declaring the target twice.)
cinn_cc_library(cinn_transforms SRCS ${cinn_transforms_srcs} DEPS
                ${cinn_transforms_deps})

endif()
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/hlir/dialect/operator/transforms/fully_insert_broadcast_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.h"

#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h"
#include "paddle/cinn/hlir/framework/pir/utils.h"
Expand Down Expand Up @@ -50,6 +50,14 @@ bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) {
}
pir::Value x = op->operand_source(0);
pir::Value y = op->operand_source(1);
pir::ShapeConstraintIRAnalysis& shape_analysis =
pir::ShapeAnalysisManager::Instance().Get(op->GetParentProgram());
const auto& x_shape = shape_analysis.GetShapeOrDataForValue(x);
const auto& y_shape = shape_analysis.GetShapeOrDataForValue(y);
if (x_shape.shape() == y_shape.shape() && x_shape.data() == y_shape.data()) {
return false;
}

pir::Value output_dim_tensor = GetOutputDimTensor(rewriter, x, y);
{
pir::Value broadcasted_x =
Expand All @@ -67,7 +75,7 @@ bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) {
} // namespace

template <typename OPTYPE>
class FullyInsertBroadcastPattern : public pir::OpRewritePattern<OPTYPE> {
class InsertBroadcastPattern : public pir::OpRewritePattern<OPTYPE> {
public:
using pir::OpRewritePattern<OPTYPE>::OpRewritePattern;

Expand All @@ -77,42 +85,46 @@ class FullyInsertBroadcastPattern : public pir::OpRewritePattern<OPTYPE> {
}
};

// Pre-rename constructor (removed side of this diff): registers the pass
// under its old name "fully_insert_broadcast_pass" at opt-level 1. The
// replacement class below registers as "insert_broadcast_pass".
FullyInsertBroadcastPass::FullyInsertBroadcastPass()
    : pir::PatternRewritePass("fully_insert_broadcast_pass", 1) {}

// Pre-rename pattern registration (removed side of this diff). Installs one
// FullyInsertBroadcastPattern instantiation per supported binary op so that
// each gets an explicit broadcast inserted on its operands. The op list is
// identical to the replacement InsertBroadcastPass below.
pir::RewritePatternSet FullyInsertBroadcastPass::InitializePatterns(
    pir::IrContext* context) {
  pir::RewritePatternSet ps(context);
  // elementwise ops
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::AddOp>>(context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::SubtractOp>>(context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::MultiplyOp>>(context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::DivideOp>>(context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::ElementwisePowOp>>(
      context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::RemainderOp>>(context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::FloorDivideOp>>(context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::MaximumOp>>(context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::MinimumOp>>(context);

  // compare ops
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::LessThanOp>>(context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::LessEqualOp>>(context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::EqualOp>>(context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::NotEqualOp>>(context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::GreaterThanOp>>(context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::GreaterEqualOp>>(context);

  // bitwise ops
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::BitwiseOrOp>>(context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::BitwiseXorOp>>(context);
  // NOTE(review): BitwiseNotOp is a unary op; registering a binary broadcast
  // pattern for it looks suspicious — confirm whether this is intentional.
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::BitwiseNotOp>>(context);

  return ps;
}
// Rewrite pass that makes implicit elementwise broadcasting explicit:
// for each supported binary op it installs an InsertBroadcastPattern,
// which (per ProcessOp above) inserts broadcast ops on both operands —
// skipping the op when both operands already have identical symbolic
// shape and data.
class InsertBroadcastPass : public pir::PatternRewritePass {
 public:
  // Registers the pass as "insert_broadcast_pass" at opt-level 1.
  InsertBroadcastPass() : pir::PatternRewritePass("insert_broadcast_pass", 1) {}

  // Installs one InsertBroadcastPattern instantiation per supported binary
  // elementwise / compare / bitwise op type.
  pir::RewritePatternSet InitializePatterns(pir::IrContext* context) override {
    pir::RewritePatternSet ps(context);
    // elementwise ops
    ps.Add<InsertBroadcastPattern<paddle::dialect::AddOp>>(context);
    ps.Add<InsertBroadcastPattern<paddle::dialect::SubtractOp>>(context);
    ps.Add<InsertBroadcastPattern<paddle::dialect::MultiplyOp>>(context);
    ps.Add<InsertBroadcastPattern<paddle::dialect::DivideOp>>(context);
    ps.Add<InsertBroadcastPattern<paddle::dialect::ElementwisePowOp>>(context);
    ps.Add<InsertBroadcastPattern<paddle::dialect::RemainderOp>>(context);
    ps.Add<InsertBroadcastPattern<paddle::dialect::FloorDivideOp>>(context);
    ps.Add<InsertBroadcastPattern<paddle::dialect::MaximumOp>>(context);
    ps.Add<InsertBroadcastPattern<paddle::dialect::MinimumOp>>(context);

    // compare ops
    ps.Add<InsertBroadcastPattern<paddle::dialect::LessThanOp>>(context);
    ps.Add<InsertBroadcastPattern<paddle::dialect::LessEqualOp>>(context);
    ps.Add<InsertBroadcastPattern<paddle::dialect::EqualOp>>(context);
    ps.Add<InsertBroadcastPattern<paddle::dialect::NotEqualOp>>(context);
    ps.Add<InsertBroadcastPattern<paddle::dialect::GreaterThanOp>>(context);
    ps.Add<InsertBroadcastPattern<paddle::dialect::GreaterEqualOp>>(context);

    // bitwise ops
    ps.Add<InsertBroadcastPattern<paddle::dialect::BitwiseOrOp>>(context);
    ps.Add<InsertBroadcastPattern<paddle::dialect::BitwiseXorOp>>(context);
    // NOTE(review): BitwiseNotOp is a unary op; a binary broadcast pattern
    // for it looks suspicious — confirm this registration is intentional.
    ps.Add<InsertBroadcastPattern<paddle::dialect::BitwiseNotOp>>(context);

    return ps;
  }

  // Only run on module ops that actually own at least one region.
  bool CanApplyOn(pir::Operation* op) const override {
    return op->isa<pir::ModuleOp>() && op->num_regions() > 0;
  }
};

bool FullyInsertBroadcastPass::CanApplyOn(pir::Operation* op) const {
return op->isa<pir::ModuleOp>() && op->num_regions() > 0;
std::unique_ptr<pir::Pass> CreateInsertBroadcastPass() {
return std::make_unique<InsertBroadcastPass>();
}

} // namespace ir
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,12 @@
#pragma once

#include "paddle/pir/pass/pass.h"
#include "paddle/pir/pattern_rewrite/frozen_rewrite_pattern_set.h"

namespace cinn {
namespace dialect {
namespace ir {

// Pre-rename public interface (removed side of this diff): the concrete
// pass type was exported and constructed directly by callers.
class FullyInsertBroadcastPass : public pir::PatternRewritePass {
 public:
  FullyInsertBroadcastPass();

  // Registers the broadcast-insertion patterns for the supported binary ops.
  pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override;

  // Limits the pass to module ops that own at least one region.
  bool CanApplyOn(pir::Operation *op) const override;
};
// Replacement interface (added side of this diff): the concrete pass type is
// hidden behind a factory that returns the abstract pir::Pass.
IR_API std::unique_ptr<pir::Pass> CreateInsertBroadcastPass();

} // namespace ir
} // namespace dialect
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,11 @@ bool InferSymbolicShapeElementWiseBinary(
std::vector<symbol::DimExpr> shapes;
symbol::DimExprBuilder builder{nullptr};
for (size_t i = 0; i < shape_0.size(); i++) {
shapes.emplace_back(builder.Broadcast(shape_0[i], shape_1[i]));
if (shape_0[i] == shape_1[i]) {
shapes.emplace_back(shape_0[i]);
} else {
shapes.emplace_back(builder.Broadcast(shape_0[i], shape_1[i]));
}
}

// TODO(lanxianghit): fill data when the operation is on shape computation
Expand Down
3 changes: 1 addition & 2 deletions paddle/fluid/pir/transforms/shape_optimization_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,8 @@ void DebugPrintOpInfo(
}

void InferSymExprForAllValues(ModuleOp module_op) {
auto shape_analysis_mgr = ShapeAnalysisManager::Instance();
ShapeConstraintIRAnalysis& shape_analysis =
shape_analysis_mgr.Get(module_op.program());
ShapeAnalysisManager::Instance().Get(module_op.program());
for (uint32_t i = 0; i < module_op->num_regions(); i++) {
for (auto& block : module_op->region(i)) {
for (auto& op : block) {
Expand Down
4 changes: 4 additions & 0 deletions paddle/pir/dialect/shape/utils/shape_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,10 @@ class IR_API ShapeAnalysisManager {
static ShapeAnalysisManager& Instance();
ShapeConstraintIRAnalysis& Get(pir::Program* program);

ShapeAnalysisManager(const ShapeAnalysisManager&) = delete;
ShapeAnalysisManager(ShapeAnalysisManager&&) = delete;
ShapeAnalysisManager& operator=(const ShapeAnalysisManager&) = delete;

private:
ShapeAnalysisManager() {}
std::unordered_map<uint64_t, ShapeConstraintIRAnalysis> tables_;
Expand Down
7 changes: 4 additions & 3 deletions paddle/pir/pass/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@ namespace pir {
namespace detail {

// Prints a banner around `header`:
//   ===----------===
//           <header>
//   ===----------===
// The dashed rule is sized from the header length so long pass names fit
// (the removed code used a fixed 73-dash rule and computed a centering
// padding that underflowed `unsigned` for headers longer than 80 chars).
// The scraped diff had both versions' statements interleaved here —
// `padding` declared twice and the rule emitted twice; this is the merged
// (added-side) version.
void PrintHeader(const std::string &header, std::ostream &os) {
  const size_t padding = 8;  // spaces printed before the header text
  // Each "===" contributes 3 chars per side, so subtract 3 from the padding
  // on both sides to make the rule exactly as wide as the header line.
  size_t line_len = header.size() + ((padding - 3) * 2);
  os << "===" << std::string(line_len, '-') << "===\n";
  os << std::string(padding, ' ') << header << "\n";
  os << "===" << std::string(line_len, '-') << "===\n";
}

} // namespace detail
Expand Down