Skip to content

Commit

Permalink
[CINN]Fix AutoSimplify in block(i, 0, 0) case (#65225)
Browse files Browse the repository at this point in the history
  • Loading branch information
Aurelius84 authored and pull[bot] committed Aug 15, 2024
1 parent 8f61cc2 commit 4102377
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 59 deletions.
94 changes: 45 additions & 49 deletions paddle/cinn/optim/ir_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,31 +42,60 @@ using utils::Replace;

namespace {

bool TryEmplaceVarIntervals(const For& op,
cinn::common::cas_intervals_t* var_intervals) {
VLOG(4) << "TryEmplaceVarIntervals with min: " << op.min << ", " << op.extent;
auto* min_i = op.min.As<IntImm>();
auto* extent_i = op.extent.As<IntImm>();
// For containing zero Shape case, skip it.
if (extent_i && extent_i->value <= 0) return false;

if (min_i && extent_i) {
var_intervals->emplace(
op.loop_var->name,
cinn::common::CasInterval{min_i->value, extent_i->value - 1});
} else {
var_intervals->emplace(op.loop_var->name,
cinn::common::CasInterval{op.min, op.extent - 1});
}
return true;
}

bool TryEraseVarIntervals(const For& op,
cinn::common::cas_intervals_t* var_intervals) {
auto* min_i = op.min.As<IntImm>();
auto* extent_i = op.extent.As<IntImm>();
const auto& name = op.loop_var->name;
const bool should_erase = min_i && extent_i && var_intervals->count(name);
if (should_erase) {
var_intervals->erase(name);
}
return should_erase;
}

//! Simplify some sub-expression in the `expr`. Due to the simplify strategy
//! just fit several kinds of IR nodes, we partition the original expression to
//! several sub-expression those supported by simplify, and process each of
//! them.
void PartialSimplify(
Expr* expr,
const absl::flat_hash_map<std::string, cinn::common::CasInterval>&
var_intervals = {}) {
void PartialSimplify(Expr* expr,
const cinn::common::cas_intervals_t& var_intervals = {}) {
*expr = cinn::common::AutoSimplify(*expr, var_intervals);
}

//! Simplify the expression but Load.
struct SimplifyNoPureMathMutator : public ir::IRMutator<ir::Expr*> {
cinn::common::cas_intervals_t& var_intervals;
cinn::common::cas_intervals_t& var_intervals_;
explicit SimplifyNoPureMathMutator(
cinn::common::cas_intervals_t& var_intervals) // NOLINT
: var_intervals(var_intervals) {}
: var_intervals_(var_intervals) {}

void operator()(Expr* x) { ir::IRMutator<ir::Expr*>::Visit(x, x); }

using ir::IRMutator<>::Visit;

#define __(op__) \
void Visit(const op__* op, Expr* expr) override { \
PartialSimplify(expr, var_intervals); \
PartialSimplify(expr, var_intervals_); \
}

__(Add)
Expand All @@ -89,31 +118,19 @@ struct SimplifyNoPureMathMutator : public ir::IRMutator<ir::Expr*> {
auto* node = expr->As<ir::For>();
Visit(&node->min, &node->min);
Visit(&node->extent, &node->extent);
auto* min_i = op->min.As<IntImm>();
auto* extent_i = op->extent.As<IntImm>();
if (min_i && extent_i && extent_i->value > min_i->value) {
var_intervals.emplace(
op->loop_var->name,
cinn::common::CasInterval{min_i->value, extent_i->value - 1});
} else {
var_intervals.emplace(op->loop_var->name,
cinn::common::CasInterval{op->min, op->extent - 1});
}

TryEmplaceVarIntervals(*op, &var_intervals_);
Visit(&node->body, &node->body);
if (min_i && extent_i) {
var_intervals.erase(op->loop_var->name);
}
TryEraseVarIntervals(*op, &var_intervals_);
}

void Visit(const _Tensor_* op, Expr* expr) override {
auto* node = expr->As<ir::_Tensor_>();

for (auto& e : node->shape) {
PartialSimplify(&e, var_intervals);
PartialSimplify(&e, var_intervals_);
}
for (auto& e : node->domain) {
PartialSimplify(&e, var_intervals);
PartialSimplify(&e, var_intervals_);
}
}
};
Expand All @@ -134,22 +151,12 @@ struct SimplifyLoadMutator : public ir::IRMutator<ir::Expr*> {
}

void Visit(const For* op, Expr* expr) override {
auto* min_i = op->min.As<IntImm>();
auto* extent_i = op->extent.As<IntImm>();
if (min_i && extent_i && extent_i->value > min_i->value) {
var_intervals_.emplace(
op->loop_var->name,
cinn::common::CasInterval{min_i->value, extent_i->value - 1});
}

TryEmplaceVarIntervals(*op, &var_intervals_);
auto* node = expr->As<For>();

operator()(&node->body);
operator()(&node->extent);

if (min_i && extent_i) {
var_intervals_.erase(op->loop_var->name);
}
TryEraseVarIntervals(*op, &var_intervals_);
}

cinn::common::cas_intervals_t var_intervals_;
Expand All @@ -172,22 +179,12 @@ struct SimplifyStoreMutator : public ir::IRMutator<ir::Expr*> {
}

void Visit(const For* op, Expr* expr) override {
auto* min_i = op->min.As<IntImm>();
auto* extent_i = op->extent.As<IntImm>();
if (min_i && extent_i) {
var_intervals_.emplace(
op->loop_var->name,
cinn::common::CasInterval{min_i->value, extent_i->value - 1});
}

TryEmplaceVarIntervals(*op, &var_intervals_);
auto* node = expr->As<For>();

operator()(&node->body);
operator()(&node->extent);

if (min_i && extent_i) {
var_intervals_.erase(op->loop_var->name);
}
TryEraseVarIntervals(*op, &var_intervals_);
}

cinn::common::cas_intervals_t var_intervals_;
Expand Down Expand Up @@ -350,8 +347,7 @@ struct SimplifyForLoopsMutator : public ir::IRMutator<> {
Visit(&node->extent, &node->extent);
auto* min_i = node->min.As<IntImm>();
auto* extent_i = node->extent.As<IntImm>();
if (min_i && extent_i && extent_i->value > min_i->value &&
extent_i->value - min_i->value == 1) {
if (min_i && extent_i && extent_i->value - min_i->value == 1) {
VLOG(6) << "Simplify current For Loop";
std::string var_name = node->loop_var->name;
var_intervals.emplace(
Expand Down
21 changes: 11 additions & 10 deletions paddle/cinn/optim/vectorize_loops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -731,12 +731,12 @@ struct VectorizeLoops_ : public IRMutator<Expr *> {
void Visit(const For *forloop, Expr *expr) {
auto *node = expr->As<For>();
auto loop_var_name = forloop->loop_var->name;
if (forloop->extent.As<IntImm>()) {
var_intervals.emplace(
loop_var_name,
cinn::common::CasInterval{static_cast<int64_t>(0),
forloop->extent.as_int64() - 1});
} else {
auto *extern_i = forloop->extent.As<IntImm>();
if (extern_i && extern_i->value > 0) {
var_intervals.emplace(loop_var_name,
cinn::common::CasInterval{static_cast<int64_t>(0),
extern_i->value - 1});
} else if (!extern_i) {
var_intervals.emplace(
loop_var_name,
cinn::common::CasInterval{Expr(0), forloop->extent - 1});
Expand Down Expand Up @@ -962,6 +962,7 @@ struct VectorizeLoops_ : public IRMutator<Expr *> {
auto *extent_ptr = forloop->extent.As<IntImm>();
Expr times;
if (extent_ptr) {
if (extent_ptr->value == 0) return Expr();
int extent_int = forloop->extent.as_int32();
int extent_trunc = extent_int / factor;
int extent_times =
Expand All @@ -978,10 +979,10 @@ struct VectorizeLoops_ : public IRMutator<Expr *> {
forloop->set_vectorized(false);

forloop->extent = times;
if (times_int && forloop->extent.as_int32() >= 1) {
var_intervals.emplace(
forloop->loop_var->name,
cinn::common::CasInterval{0, forloop->extent.as_int32() - 1});
if (times_int) {
var_intervals.emplace(forloop->loop_var->name,
cinn::common::CasInterval{static_cast<int64_t>(0),
times_int->value - 1});
} else {
var_intervals.erase(forloop->loop_var->name);
var_intervals.emplace(
Expand Down

0 comments on commit 4102377

Please sign in to comment.