diff --git a/cinn/backends/codegen_c.cc b/cinn/backends/codegen_c.cc index bb9aab9e8992a..fa160f14f81a2 100644 --- a/cinn/backends/codegen_c.cc +++ b/cinn/backends/codegen_c.cc @@ -53,6 +53,7 @@ std::string CodeGenC::Compile(const lang::Module &module, OutputKind output_kind return ss_.str(); } std::string CodeGenC::Compile(const ir::LoweredFunc &function) { + CHECK(function.defined()); Print(function); os() << "\n\n"; return ss_.str(); diff --git a/cinn/common/arithmatic.cc b/cinn/common/arithmatic.cc index f9b6211e4564d..19a0ccbfd8918 100644 --- a/cinn/common/arithmatic.cc +++ b/cinn/common/arithmatic.cc @@ -226,13 +226,16 @@ bool IsPureMath(Expr expr) { IrNodeTy ::Sub, IrNodeTy ::Div, IrNodeTy ::Mul, + IrNodeTy::Mod, IrNodeTy ::Minus, }); auto complex_nodes = ir::CollectIRNodes(expr, [&](const Expr* n) { return !valid_node_tys.count(n->node_type()); }); +#ifdef CINN_DEBUG for (auto& node : complex_nodes) { VLOG(3) << "Found " << node->node_type() << " " << Expr(node); } +#endif return complex_nodes.empty(); } diff --git a/cinn/common/cas.cc b/cinn/common/cas.cc index c5331310824ce..e379207a4c19f 100644 --- a/cinn/common/cas.cc +++ b/cinn/common/cas.cc @@ -94,6 +94,71 @@ Expr Exponent(Expr v) { namespace detail { +// Is a Divisiable to b. +// @{ +bool IsDivisible(int64_t a, int64_t b) { + CHECK_NE(b, 0); + return a % b == 0; +} +bool IsDivisible(const Sum* a, int b); +bool IsDivisible(const Product* a, int b) { + for (auto& item : a->operands()) { + if (item.As() && IsDivisible(item.As()->value, b)) return true; + if (item.As() && IsDivisible(item.As(), b)) return true; + } + return false; +} +bool IsDivisible(const Sum* a, int b) { + for (auto& item : a->operands()) { + auto* vi = item.As(); + auto* vp = item.As(); + if (vi && IsDivisible(vi->value, b)) continue; + if (vp && IsDivisible(vp, b)) continue; + return false; + } + return true; +} +bool IsDivisible(Expr a, int b) { + auto* ai = a.As(); + auto* as = a.As(); + auto* ap = a.As(); + + if (ai) return IsDivisible(ai->value, b); + if (as) return IsDivisible(as, b); + if (ap) return IsDivisible(ap, b); + return false; +} +// @} + +//! Divide a by b, NOTE that a should be divisible by b. +// @{ +Expr Divide(const Product* a, int b); +Expr Divide(const Sum* a, int b) { + std::vector args; + for (auto& item : a->operands()) { + if (item.As()) + args.push_back(make_const(item.type(), item.As()->value / b)); + else if (item.As()) + args.push_back(Divide(item.As(), b)); + else + NOT_IMPLEMENTED + } + return Sum::Make(args); +} +Expr Divide(const Product* a, int b) { + auto* a_first_i = a->operand(0).As(); + CHECK(a_first_i); + int times = a_first_i->value / b; + if (times == 1) { + return Product::Make(Rest(a->operands())); + } else { + auto args = Rest(a->operands()); + args.insert(std::begin(args), make_const(a->type(), times)); + return Product::Make(args); + } +} +// @} + inline int Iquot(int n, int d) { return n / d; } inline int Irem(int n, int d) { @@ -798,33 +863,59 @@ Expr CasSimplifyMutator::SimplifyMod(Expr u) { auto a = CasSimplify(node->a(), var_intervals); auto b = CasSimplify(node->b(), var_intervals); - auto* ai = a.As(); - auto* bi = b.As(); + auto* a_i = a.As(); + auto* a_product = a.As(); + auto* a_sum = a.As(); + auto* a_var = a.As<_Var_>(); + auto* a_mod = a.As(); + + auto* b_i = b.As(); + // 7 % 3 - if (ai && bi) { - return make_const(ai->type(), ai->value % bi->value); + if (a_i && b_i) { + return make_const(a_i->type(), a_i->value % b_i->value); } // x % 1 = 0 - if (bi && bi->value == 1) return make_const(bi->type(), 0); + if (b_i && b_i->value == 1) return make_const(b_i->type(), 0); // 2x % 2 = 0 - if (bi) { - auto* ap = a.As(); - if (ap && ap->operand(0).As()) { - if (ap->operand(0).As()->value % bi->value == 0) return make_const(ap->type(), 0); - } + if (b_i && a_product && a_product->operand(0).As()) { + if (a_product->operand(0).As()->value % b_i->value == 0) return make_const(a_product->type(), 0); + } + + // 0 % x = 1, 1 % x = 1 + if (a_i && (a_i->value == 0 || a_i->value == 1)) return a; + + if (b_i && a_var && var_intervals.count(a_var->name)) { + auto& interval = var_intervals.at(a_var->name); + int b_abs = std::abs(b_i->value); + // x\in[1, 3] % 4 = x + if (std::abs(interval.l) < b_abs && std::abs(interval.r) < b_abs) return a; + // [3,3] % 3 = 0 + if (interval.l == interval.r && interval.l % b_abs == 0) return make_const(b_i->type(), 0); } - if (ai && (ai->value == 0 || ai->value == 1)) return a; + if (a_product && b_i) { + if (IsDivisible(a_product, b_i->value)) { + return make_const(Int(32), 0); + } + } - // (x+y) % 2 = x%2 + y%2 - if (a.As()) { + // (2x+y+z) % 2 = (y+z) % 2 + if (a_sum && b_i) { std::vector sum_args; - for (auto& v : a->operands) { - sum_args.push_back(Mod::Make(v, b)); + for (auto& v : a_sum->operands()) { + if (!IsDivisible(v, b_i->value)) { + sum_args.push_back(v); + } } - return CasSimplify(Sum::Make(sum_args), var_intervals); + if (sum_args.size() == a_sum->operands().size()) return Mod::Make(a, b); + if (sum_args.empty()) return make_const(b_i->type(), 0); + if (sum_args.size() == 1) { + return SimplifyMod(Mod::Make(sum_args.front(), b)); + } + return SimplifyMod(Mod::Make(Sum::Make(sum_args), b)); } return Mod::Make(a, b); @@ -1030,61 +1121,6 @@ bool IsExprCasCompatible(Expr expr) { return ir::CollectIRNodes(expr, teller).empty(); } -bool IsDivisible(int64_t a, int64_t b) { - CHECK_NE(b, 0); - return a % b == 0; -} - -bool IsDivisible(const Sum* a, int b); - -bool IsDivisible(const Product* a, int b) { - for (auto& item : a->operands()) { - if (item.As() && IsDivisible(item.As()->value, b)) return true; - if (item.As() && IsDivisible(item.As(), b)) return true; - } - return false; -} - -bool IsDivisible(const Sum* a, int b) { - for (auto& item : a->operands()) { - auto* vi = item.As(); - auto* vp = item.As(); - if (vi && IsDivisible(vi->value, b)) continue; - if (vp && IsDivisible(vp, b)) continue; - return false; - } - return true; -} - -//! Divide a by b, NOTE that a should be divisible by b. -// @{ -Expr Divide(const Product* a, int b); -Expr Divide(const Sum* a, int b) { - std::vector args; - for (auto& item : a->operands()) { - if (item.As()) - args.push_back(make_const(item.type(), item.As()->value / b)); - else if (item.As()) - args.push_back(Divide(item.As(), b)); - else - NOT_IMPLEMENTED - } - return Sum::Make(args); -} -Expr Divide(const Product* a, int b) { - auto* a_first_i = a->operand(0).As(); - CHECK(a_first_i); - int times = a_first_i->value / b; - if (times == 1) { - return Product::Make(Rest(a->operands())); - } else { - auto args = Rest(a->operands()); - args.insert(std::begin(args), make_const(a->type(), times)); - return Product::Make(args); - } -} -// @} - // Partially divide a by b. e.g. (2x+y)/2 => x + y/2 Expr DividePartially(Sum* a, int b) { std::vector external_sum_args, sum_args; @@ -1138,7 +1174,6 @@ Expr CasSimplifyMutator::FurtherSimplifyFracWithInterval( if (bi) { if (av) { auto it = var_intervals.find(av->name); - if (it != var_intervals.end()) LOG(INFO) << "found " << av->name; if (it != var_intervals.end() && std::abs(it->second.r) < std::abs(bi->value) && std::abs(it->second.l) < std::abs(bi->value)) return make_const(a.type(), 0); diff --git a/cinn/common/cas.h b/cinn/common/cas.h index a5f32a950e1db..10bb9216cf1a2 100644 --- a/cinn/common/cas.h +++ b/cinn/common/cas.h @@ -10,7 +10,7 @@ namespace common { * Interval of a _Var_. */ struct CasInterval { - CasInterval(int l, int r) : l(l), r(r) {} + CasInterval(int64_t l, int64_t r) : l(l), r(r) {} int l, r; friend std::ostream& operator<<(std::ostream& os, const CasInterval& i) { @@ -19,6 +19,8 @@ struct CasInterval { } }; +using cas_intervals_t = std::unordered_map; + Expr AutoSimplify(Expr u, const std::unordered_map& var_intervals = {}); //! Simplify a CAS expression. diff --git a/cinn/common/cas_test.cc b/cinn/common/cas_test.cc index b126c44dc26e4..0c5453ee72c76 100644 --- a/cinn/common/cas_test.cc +++ b/cinn/common/cas_test.cc @@ -180,7 +180,7 @@ TEST(CAS, SimplifyMod) { {Power::Make(x, Expr(3)), Mod::Make(x, Expr(5)), y, Expr(1), Mod::Make(Product::Make({x, Expr(4)}), Expr(5))})); EXPECT_EQ(GetStreamCnt(u1), "0"); - EXPECT_EQ(GetStreamCnt(u2), "((x % 2) + (y % 2) + (z % 2))"); + EXPECT_EQ(GetStreamCnt(u2), "((x + y + z) % 2)"); EXPECT_EQ(GetStreamCnt(u3), "1"); EXPECT_EQ(GetStreamCnt(u4), "(1 + (x^3) + y)"); } @@ -222,6 +222,48 @@ TEST(CAS, FracOp) { EXPECT_EQ(GetStreamCnt(u4), "((32768 * (y/32)) + (32768 * x))"); } +#define OUTPUT_EQUAL(s__) EXPECT_EQ(GetStreamCnt(u), s__); + +TEST(CAS, Mod) { + Var x = ir::_Var_::Make("x", Int(32)); + Var y = ir::_Var_::Make("y", Int(32)); + Var z = ir::_Var_::Make("z", Int(32)); + Var k = ir::_Var_::Make("k", Int(32)); + + std::unordered_map var_intervals0, var_intervals1; + var_intervals0.emplace("x", CasInterval{0, 3}); + var_intervals0.emplace("y", CasInterval{0, 3}); + var_intervals0.emplace("z", CasInterval{0, 3}); + var_intervals0.emplace("k", CasInterval{0, 3}); + + Expr u; + u = AutoSimplify(x % 5); + EXPECT_EQ(GetStreamCnt(u), "(x % 5)"); + OUTPUT_EQUAL("(x % 5)") + + u = AutoSimplify((5 + x) % 5); + OUTPUT_EQUAL("(x % 5)") + + u = AutoSimplify((x + 5 * y + 1 + 1 + 3 - z * 3) % 5); + OUTPUT_EQUAL("((x + (-3 * z)) % 5)") + + u = AutoSimplify((x + 5) % 5, var_intervals0); + OUTPUT_EQUAL("x") + + u = AutoSimplify((x + y + 5) % 5, var_intervals0); + OUTPUT_EQUAL("((x + y) % 5)") + + u = AutoSimplify((x + 20 * y + 5) % 5, var_intervals0); + OUTPUT_EQUAL("x") + + u = AutoSimplify((x % 32) + ((32768 * (x / 32)) + ((32768 * y) + ((32 * z) + (128 * k))))); + OUTPUT_EQUAL("((32768 * (x/32)) + ((x % 32) + ((128 * k) + ((32768 * y) + (32 * z)))))"); + + u = AutoSimplify((x % 32) + ((32768 * (x / 32)) + ((32768 * y) + ((32 * z) + (128 * k)))), var_intervals0); + OUTPUT_EQUAL("((128 * k) + (x + ((32768 * y) + (32 * z))))") + LOG(INFO) << u; +} + TEST(CAS, IntConnerCase) { Var x = ir::_Var_::Make("x", Int(32)); Var y = ir::_Var_::Make("y", Int(32)); diff --git a/cinn/ir/ir.cc b/cinn/ir/ir.cc index 8419d48305dbd..db10a99316ddd 100644 --- a/cinn/ir/ir.cc +++ b/cinn/ir/ir.cc @@ -448,6 +448,12 @@ Expr Load::index() const { return res; } +const std::string &Load::name() const { + auto *t = tensor.As(); + CHECK(t); + return t->name; +} + Expr Ramp::Make(Expr base, Expr stride, int lanes) { CHECK(base.defined()); CHECK(stride.defined()); diff --git a/cinn/ir/ir.h b/cinn/ir/ir.h index 1b721f73e8c5d..c0d5d7dee2804 100644 --- a/cinn/ir/ir.h +++ b/cinn/ir/ir.h @@ -398,6 +398,8 @@ struct Load : public ExprNode { std::vector expr_fields() override; std::vector expr_fields() const override; + const std::string& name() const; + Type type() const override; static const IrNodeTy _node_type_ = IrNodeTy::Load; diff --git a/cinn/lang/lower.cc b/cinn/lang/lower.cc index 2e3abf0cc9f4d..af5393e412ebd 100644 --- a/cinn/lang/lower.cc +++ b/cinn/lang/lower.cc @@ -85,119 +85,6 @@ struct MarkUnrollMutator : public ir::IRMutator { std::vector stack; }; -/** - * Expand the split transforms. - * This should takes the expression generated from isl ast as input(without relacing the statement with the real - * computation), it takes each Call to identify the statement. Each time it can only deal with one statement. - * - * NOTE this is discarded, to be clean up latter. - */ -struct SplitExpandMutator : public ir::IRMutator { - SplitExpandMutator(const std::string& statement, const std::map& strategies) - : statement_(statement), strategies_(strategies) {} - - void operator()(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); } - - void Visit(const ir::PolyFor* op, Expr* expr) override { - auto* node = expr->As(); - forloop_stack_.push(expr); - - ir::IRMutator<>::Visit(op, expr); - - forloop_stack_.pop(); - - // The Split transform always split one forloop into outer and inner, and we do separation on the inner one, so - // there should be at least one forloop remaining if the current level is the inner. - if (!forloop_stack_.empty()) { - if (cur_statement_ == statement_ && strategies_.count(op->iterator->name) && - strategies_.at(op->iterator->name) == poly::SplitRestStrategy::kSeparate) { - auto* outer = forloop_stack_.top()->As(); - DoSaparation(outer, expr); - } - } - } - - void Visit(const ir::Call* op, Expr* expr) override { - auto* node = expr->As(); - // We reach the call node that represents the statment, just mark the current(innermost) forloop to separate. - if (node->call_type == ir::Call::CallType::ISL) { - cur_statement_ = node->name; - } - } - - //! Do the separation on the \p inner forloop, add new forloops to \p outer forloop. - void DoSaparation(ir::PolyFor* outer, Expr* inner) { - VLOG(3) << "Doing separation"; - auto* inner_node = inner->As(); - CHECK(inner_node); - auto* lt = inner_node->condition.As(); - auto* le = inner_node->condition.As(); - - auto create_forloop = [&](Expr cond) { - return ir::PolyFor::Make(inner_node->iterator, - inner_node->init, - cond, - inner_node->inc, - inner_node->for_type, - inner_node->device_api, - inner_node->body); - }; - - auto insert_new_forloops_to_upper = [&](ir::PolyFor* origin, Expr if_then_else) { - auto* outer_block = outer->body.As(); - CHECK(outer_block); - auto it = std::find_if(outer_block->stmts.begin(), outer_block->stmts.end(), [&](const Expr& e) { - auto* a_for = e.As(); - if (!a_for) return false; - return a_for == origin; - }); - CHECK(it != outer_block->stmts.end()); - - *it = if_then_else; - }; - - Expr cond0, cond1; - if (!(lt || le)) { - LOG(ERROR) << "The condition of the forloop don't contains LT or LE operator, skip seperation, the condition is " - << inner_node->condition; - return; - } - - ir::Min* min_n = lt ? lt->b().As() : le->b().As(); - - if (min_n) { - auto upper_bound0 = min_n->a(); - auto upper_bound1 = min_n->b(); - - Expr forloop0, forloop1; - if (lt) { - forloop0 = create_forloop(ir::LT::Make(Expr(inner_node->iterator), upper_bound0)); - forloop1 = create_forloop(ir::LT::Make(Expr(inner_node->iterator), upper_bound1)); - } else { - forloop0 = create_forloop(ir::LE::Make(Expr(inner_node->iterator), upper_bound0)); - forloop1 = create_forloop(ir::LE::Make(Expr(inner_node->iterator), upper_bound1)); - } - - // the new forloops should be wrapped by a if-then-else - Expr if_then_else_cond = ir::LE::Make(upper_bound0, upper_bound1); - auto if_then_else = ir::IfThenElse::Make(if_then_else_cond, forloop0, forloop1); - VLOG(2) << "Separate two new forloops"; - VLOG(2) << forloop0; - VLOG(2) << forloop1; - insert_new_forloops_to_upper(inner_node, if_then_else); - } - } - - private: - std::string statement_; - //! A stack to record the forloops call stack to the current statement. - std::stack forloop_stack_; - const std::map& strategies_; - ir::Expr* forloop_to_separate_{}; - //! The statement in the innermost forloop, used to determine whether the forloops in the stack need to separate. - std::string cur_statement_; -}; // namespace lang - //! Lower a single group. A LoweredFunc is composed of several group. Expr LowerGroup(const poly::ScheduleGroup& group, const std::map& tuple_to_expr) { std::vector stages; @@ -212,13 +99,6 @@ Expr LowerGroup(const poly::ScheduleGroup& group, const std::mapid() << " " << stage->split_strageties().size() << " strategies"; - SplitExpandMutator(stage->id(), stage->split_strageties())(&e); - } - */ - // replace call to the corresponding statement for (auto& statement : tuple_to_expr) { auto axis_ast_map = gen.axis2ast(statement.first); diff --git a/cinn/optim/CMakeLists.txt b/cinn/optim/CMakeLists.txt index f56acaa4c6035..dfee9c6708c24 100644 --- a/cinn/optim/CMakeLists.txt +++ b/cinn/optim/CMakeLists.txt @@ -7,7 +7,6 @@ set(srcs remove_nested_block.cc replace_call_with_expr.cc ir_copy.cc vectorize_loops.cc unroll_loops.cc transform_polyfor_to_for.cc - ir_eliminate_mod.cc ) foreach(cpp ${srcs}) diff --git a/cinn/optim/ir_eliminate_mod.cc b/cinn/optim/ir_eliminate_mod.cc deleted file mode 100644 index 0b6b82efb9fc4..0000000000000 --- a/cinn/optim/ir_eliminate_mod.cc +++ /dev/null @@ -1,24 +0,0 @@ -#include "cinn/optim/ir_eliminate_mod.h" - -#include "cinn/ir/ir_mutator.h" -#include "cinn/ir/ir_printer.h" - -namespace cinn { -namespace optim { - -void IrEliminateMod(Expr* expr) { - struct Modifier : public ir::IRMutator { - void operator()(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); } - - void Visit(const ir::Mod* op, Expr* expr) { - LOG(ERROR) << "Not Implemented"; - LOG(INFO) << "elimiate " << *expr << " to " << op->a(); - *expr = op->a(); - } - }; - - Modifier()(expr); -} - -} // namespace optim -} // namespace cinn diff --git a/cinn/optim/ir_eliminate_mod.h b/cinn/optim/ir_eliminate_mod.h deleted file mode 100644 index 0cb6edf79cd16..0000000000000 --- a/cinn/optim/ir_eliminate_mod.h +++ /dev/null @@ -1,11 +0,0 @@ -#pragma once - -#include "cinn/ir/ir.h" - -namespace cinn { -namespace optim { - -void IrEliminateMod(Expr* expr); - -} // namespace optim -} // namespace cinn diff --git a/cinn/optim/ir_simplify.cc b/cinn/optim/ir_simplify.cc index d632e6302784e..b57148dbe4a86 100644 --- a/cinn/optim/ir_simplify.cc +++ b/cinn/optim/ir_simplify.cc @@ -25,16 +25,21 @@ namespace { //! Simplify some sub-expression in the `expr`. Due to the simplify strategy just fit several kinds of IR noedes, we //! partition the original expression to several sub-expression those supported by simplify, and process each of them. -void PartialSimplify(Expr* expr) { *expr = common::AutoSimplify(*expr); } +void PartialSimplify(Expr* expr, const std::unordered_map& var_intervals = {}) { + *expr = common::AutoSimplify(*expr, var_intervals); +} //! Simplify the expression but Load. struct SimplifyButStoreLoadMutator : public ir::IRMutator { + const common::cas_intervals_t& var_intervals; + explicit SimplifyButStoreLoadMutator(const common::cas_intervals_t& var_intervals) : var_intervals(var_intervals) {} + void operator()(Expr* x) { ir::IRMutator::Visit(x, x); } using ir::IRMutator<>::Visit; #define __(op__) \ - void Visit(const op__* op, Expr* expr) override { PartialSimplify(expr); } + void Visit(const op__* op, Expr* expr) override { PartialSimplify(expr, var_intervals); } __(Add) __(Mul) @@ -46,10 +51,10 @@ struct SimplifyButStoreLoadMutator : public ir::IRMutator { auto* node = expr->As(); CHECK(common::IsPureMath(node->base)); CHECK(common::IsPureMath(node->stride)); - PartialSimplify(&node->base); - PartialSimplify(&node->stride); + PartialSimplify(&node->base, var_intervals); + PartialSimplify(&node->stride, var_intervals); } -}; // namespace +}; struct SimplifyLoadMutator : public ir::IRMutator { void operator()(Expr* x) { ir::IRMutator::Visit(x, x); } @@ -57,14 +62,33 @@ struct SimplifyLoadMutator : public ir::IRMutator { void Visit(const Load* expr, Expr* op) override { auto* node = op->As(); for (auto& idx : node->indices) { - if (common::IsPureMath(node->index())) { - PartialSimplify(&idx); + if (common::IsPureMath(idx)) { + PartialSimplify(&idx, var_intervals_); } else { - SimplifyButStoreLoadMutator mutator; + SimplifyButStoreLoadMutator mutator(var_intervals_); mutator(&idx); } } } + + void Visit(const For* op, Expr* expr) override { + auto* min_i = op->min.As(); + auto* extent_i = op->extent.As(); + if (min_i && extent_i) { + var_intervals_.emplace(op->loop_var->name, common::CasInterval{min_i->value, extent_i->value - 1}); + } + + auto* node = expr->As(); + + operator()(&node->body); + operator()(&node->extent); + + if (min_i && extent_i) { + var_intervals_.erase(op->loop_var->name); + } + } + + common::cas_intervals_t var_intervals_; }; struct SimplifyStoreMutator : public ir::IRMutator { @@ -74,14 +98,33 @@ struct SimplifyStoreMutator : public ir::IRMutator { auto* node = op->As(); for (auto& idx : node->indices) { - if (common::IsPureMath(node->index())) { - PartialSimplify(&idx); + if (common::IsPureMath(idx)) { + PartialSimplify(&idx, var_intervals_); } else { - SimplifyButStoreLoadMutator mutator; + SimplifyButStoreLoadMutator mutator(var_intervals_); mutator(&idx); } } } + + void Visit(const For* op, Expr* expr) override { + auto* min_i = op->min.As(); + auto* extent_i = op->extent.As(); + if (min_i && extent_i) { + var_intervals_.emplace(op->loop_var->name, common::CasInterval{min_i->value, extent_i->value - 1}); + } + + auto* node = expr->As(); + + operator()(&node->body); + operator()(&node->extent); + + if (min_i && extent_i) { + var_intervals_.erase(op->loop_var->name); + } + } + + common::cas_intervals_t var_intervals_; }; struct SimplifyRampMutator : public ir::IRMutator { @@ -103,8 +146,10 @@ void Simplify(Expr* expr) { SimplifyRampMutator()(expr); SimplifyLoadMutator()(expr); SimplifyStoreMutator()(expr); - SimplifyButStoreLoadMutator()(expr); -} + common::cas_intervals_t var_intervals; + SimplifyButStoreLoadMutator mutator(var_intervals); + mutator(expr); +} } // namespace optim } // namespace cinn diff --git a/cinn/optim/optimize.cc b/cinn/optim/optimize.cc index 1266ac5e6d61c..639353720cf76 100644 --- a/cinn/optim/optimize.cc +++ b/cinn/optim/optimize.cc @@ -1,9 +1,9 @@ #include "cinn/optim/optimize.h" #include "cinn/optim/ir_copy.h" -#include "cinn/optim/ir_eliminate_mod.h" #include "cinn/optim/ir_simplify.h" #include "cinn/optim/remove_nested_block.h" +#include "cinn/optim/transform_polyfor_to_for.h" #include "cinn/optim/unroll_loops.h" #include "cinn/optim/vectorize_loops.h" @@ -11,9 +11,11 @@ namespace cinn { namespace optim { Expr Optimize(Expr e) { + CHECK(e.defined()); auto copied = IRCopy(e); + + TransformPolyForToFor(&copied); Simplify(&copied); - // IrEliminateMod(&copied); VectorizeLoops(&copied, Target()); UnrollLoop(&copied); RemoveNestedBlock(&copied); diff --git a/cinn/optim/transform_polyfor_to_for.cc b/cinn/optim/transform_polyfor_to_for.cc index ebfa4d7f74a23..c0cd625ab7161 100644 --- a/cinn/optim/transform_polyfor_to_for.cc +++ b/cinn/optim/transform_polyfor_to_for.cc @@ -191,7 +191,8 @@ struct ForAutoSeparateMutator : ir::IRMutator { Expr* PolyForAutoSeparateHelper(Expr* expr) { ForAutoSeparateMutator mutator; - return mutator(expr); + auto* res = mutator(expr); + if (res) return res; } struct ForAutoSeparateMutatorMain : public ir::IRMutator { diff --git a/cinn/optim/vectorize_loops.cc b/cinn/optim/vectorize_loops.cc index 2fe2aa8cca6be..23f3a3cb04684 100644 --- a/cinn/optim/vectorize_loops.cc +++ b/cinn/optim/vectorize_loops.cc @@ -6,7 +6,6 @@ #include "cinn/common/ir.h" #include "cinn/ir/ir_printer.h" -#include "cinn/optim/transform_polyfor_to_for.h" #include "cinn/utils/functional.h" namespace cinn { @@ -312,11 +311,7 @@ struct VectorizeLoops_ : public IRMutator { } }; -void VectorizeLoops(Expr *expr, const Target &target) { - optim::TransformPolyForToFor(expr); - - return VectorizeLoops_(target)(expr); -} +void VectorizeLoops(Expr *expr, const Target &target) { return VectorizeLoops_(target)(expr); } namespace detail { diff --git a/cinn/optim/vectorize_loops_test.cc b/cinn/optim/vectorize_loops_test.cc index 7b9ed4d770613..208362f695bcb 100644 --- a/cinn/optim/vectorize_loops_test.cc +++ b/cinn/optim/vectorize_loops_test.cc @@ -35,7 +35,7 @@ TEST(VectorizeLoops, Split_sperate) { std::tie(i_outer, i_inner, j_outer, j_inner) = C->stage()->Tile(0, 1, bn, bn); std::tie(k_outer, k_inner) = C->stage()->Split(poly::Iterator("k"), 8); C->stage()->Reorder({i_outer, j_outer, k_outer, k_inner, i_inner, j_inner}); - C->stage()->Split(j_inner, 8, poly::SplitRestStrategy::kAuto); + C->stage()->Split(j_inner, 8); } // Code gen @@ -47,10 +47,11 @@ TEST(VectorizeLoops, Split_sperate) { target.bits = Target::Bit ::k32; target.os = Target::OS ::Linux; - optim::VectorizeLoops(&funcs[0]->body, target); + Expr body = optim::Optimize(Expr(funcs[0])); + LOG(INFO) << "body:\n" << body; lang::Module module("module1", target); - module.Append(funcs.front()); + module.Append(ir::LoweredFunc(body.As())); module.Append(C_buf); CodeGenC codegen(target); @@ -68,57 +69,55 @@ void matmul(const struct cinn_buffer_t *_A, const struct cinn_buffer_t *_B, stru const float* A = (const float*)(cinn_buffer_get_data_const_handle(_A)); const float* B = (const float*)(cinn_buffer_get_data_const_handle(_B)); float* C = (float*)(cinn_buffer_get_data_handle(_C)); - { - for (int32_t i_outer = 0; i_outer < 3; i_outer += 1) { - for (int32_t j_outer = 0; j_outer < 15; j_outer += 1) { - for (int32_t k_outer = 0; k_outer < 25; k_outer += 1) { - for (int32_t k_inner = 0; k_inner < 8; k_inner += 1) { - for (int32_t i_inner = 0; i_inner < 32; i_inner += 1) { - for (int32_t j_inner_outer = 0; j_inner_outer < 4; j_inner_outer += 1) { - for (int32_t j_inner_inner = 0; j_inner_inner < min(8, (500 + ((-8 * j_inner_outer) + (-32 * j_outer)))); j_inner_inner += 1) { - C[((500 * i_inner) + ((16000 * i_outer) + ((8 * j_inner_outer) + ((32 * j_outer) + j_inner_inner))))] = (C[((500 * i_inner) + ((16000 * i_outer) + ((8 * j_inner_outer) + ((32 * j_outer) + j_inner_inner))))] + (A[((200 * i_inner) + ((6400 * i_outer) + ((8 * k_outer) + k_inner)))] * B[((8 * j_inner_outer) + ((32 * j_outer) + ((500 * k_inner) + ((4000 * k_outer) + j_inner_inner))))])); - }; + for (int32_t i_outer = 0; i_outer < 3; i_outer += 1) { + for (int32_t j_outer = 0; j_outer < 15; j_outer += 1) { + for (int32_t k_outer = 0; k_outer < 25; k_outer += 1) { + for (int32_t k_inner = 0; k_inner < 8; k_inner += 1) { + for (int32_t i_inner = 0; i_inner < 32; i_inner += 1) { + for (int32_t j_inner_outer = 0; j_inner_outer < 4; j_inner_outer += 1) { + for (int32_t j_inner_inner = 0; j_inner_inner < min(8, (500 + ((-8 * j_inner_outer) + (-32 * j_outer)))); j_inner_inner += 1) { + C[((500 * i_inner) + ((16000 * i_outer) + ((8 * j_inner_outer) + ((32 * j_outer) + j_inner_inner))))] = (C[((500 * i_inner) + ((16000 * i_outer) + ((8 * j_inner_outer) + ((32 * j_outer) + j_inner_inner))))] + (A[((200 * i_inner) + ((6400 * i_outer) + ((8 * k_outer) + k_inner)))] * B[((8 * j_inner_outer) + ((32 * j_outer) + ((500 * k_inner) + ((4000 * k_outer) + j_inner_inner))))])); }; }; }; }; }; - for (int32_t j_outer = 15; j_outer < 16; j_outer += 1) { - for (int32_t k_outer = 0; k_outer < 25; k_outer += 1) { - for (int32_t k_inner = 0; k_inner < 8; k_inner += 1) { - for (int32_t i_inner = 0; i_inner < 32; i_inner += 1) { - for (int32_t j_inner_outer = 0; j_inner_outer < (63 + (-4 * j_outer)); j_inner_outer += 1) { - for (int32_t j_inner_inner = 0; j_inner_inner < min(8, (500 + ((-8 * j_inner_outer) + (-32 * j_outer)))); j_inner_inner += 1) { - C[((500 * i_inner) + ((16000 * i_outer) + ((8 * j_inner_outer) + ((32 * j_outer) + j_inner_inner))))] = (C[((500 * i_inner) + ((16000 * i_outer) + ((8 * j_inner_outer) + ((32 * j_outer) + j_inner_inner))))] + (A[((200 * i_inner) + ((6400 * i_outer) + ((8 * k_outer) + k_inner)))] * B[((8 * j_inner_outer) + ((32 * j_outer) + ((500 * k_inner) + ((4000 * k_outer) + j_inner_inner))))])); - }; + }; + for (int32_t j_outer = 15; j_outer < 16; j_outer += 1) { + for (int32_t k_outer = 0; k_outer < 25; k_outer += 1) { + for (int32_t k_inner = 0; k_inner < 8; k_inner += 1) { + for (int32_t i_inner = 0; i_inner < 32; i_inner += 1) { + for (int32_t j_inner_outer = 0; j_inner_outer < (63 + (-4 * j_outer)); j_inner_outer += 1) { + for (int32_t j_inner_inner = 0; j_inner_inner < min(8, (500 + ((-8 * j_inner_outer) + (-32 * j_outer)))); j_inner_inner += 1) { + C[((500 * i_inner) + ((16000 * i_outer) + ((8 * j_inner_outer) + ((32 * j_outer) + j_inner_inner))))] = (C[((500 * i_inner) + ((16000 * i_outer) + ((8 * j_inner_outer) + ((32 * j_outer) + j_inner_inner))))] + (A[((200 * i_inner) + ((6400 * i_outer) + ((8 * k_outer) + k_inner)))] * B[((8 * j_inner_outer) + ((32 * j_outer) + ((500 * k_inner) + ((4000 * k_outer) + j_inner_inner))))])); }; }; }; }; }; }; - for (int32_t i_outer = 3; i_outer < 4; i_outer += 1) { - for (int32_t j_outer = 0; j_outer < 15; j_outer += 1) { - for (int32_t k_outer = 0; k_outer < 25; k_outer += 1) { - for (int32_t k_inner = 0; k_inner < 8; k_inner += 1) { - for (int32_t i_inner = 0; i_inner < (100 + (-32 * i_outer)); i_inner += 1) { - for (int32_t j_inner_outer = 0; j_inner_outer < 4; j_inner_outer += 1) { - for (int32_t j_inner_inner = 0; j_inner_inner < min(8, (500 + ((-8 * j_inner_outer) + (-32 * j_outer)))); j_inner_inner += 1) { - C[((500 * i_inner) + ((16000 * i_outer) + ((8 * j_inner_outer) + ((32 * j_outer) + j_inner_inner))))] = (C[((500 * i_inner) + ((16000 * i_outer) + ((8 * j_inner_outer) + ((32 * j_outer) + j_inner_inner))))] + (A[((200 * i_inner) + ((6400 * i_outer) + ((8 * k_outer) + k_inner)))] * B[((8 * j_inner_outer) + ((32 * j_outer) + ((500 * k_inner) + ((4000 * k_outer) + j_inner_inner))))])); - }; + }; + for (int32_t i_outer = 3; i_outer < 4; i_outer += 1) { + for (int32_t j_outer = 0; j_outer < 15; j_outer += 1) { + for (int32_t k_outer = 0; k_outer < 25; k_outer += 1) { + for (int32_t k_inner = 0; k_inner < 8; k_inner += 1) { + for (int32_t i_inner = 0; i_inner < (100 + (-32 * i_outer)); i_inner += 1) { + for (int32_t j_inner_outer = 0; j_inner_outer < 4; j_inner_outer += 1) { + for (int32_t j_inner_inner = 0; j_inner_inner < min(8, (500 + ((-8 * j_inner_outer) + (-32 * j_outer)))); j_inner_inner += 1) { + C[((500 * i_inner) + ((16000 * i_outer) + ((8 * j_inner_outer) + ((32 * j_outer) + j_inner_inner))))] = (C[((500 * i_inner) + ((16000 * i_outer) + ((8 * j_inner_outer) + ((32 * j_outer) + j_inner_inner))))] + (A[((200 * i_inner) + ((6400 * i_outer) + ((8 * k_outer) + k_inner)))] * B[((8 * j_inner_outer) + ((32 * j_outer) + ((500 * k_inner) + ((4000 * k_outer) + j_inner_inner))))])); }; }; }; }; }; - for (int32_t j_outer = 15; j_outer < 16; j_outer += 1) { - for (int32_t k_outer = 0; k_outer < 25; k_outer += 1) { - for (int32_t k_inner = 0; k_inner < 8; k_inner += 1) { - for (int32_t i_inner = 0; i_inner < (100 + (-32 * i_outer)); i_inner += 1) { - for (int32_t j_inner_outer = 0; j_inner_outer < (63 + (-4 * j_outer)); j_inner_outer += 1) { - for (int32_t j_inner_inner = 0; j_inner_inner < min(8, (500 + ((-8 * j_inner_outer) + (-32 * j_outer)))); j_inner_inner += 1) { - C[((500 * i_inner) + ((16000 * i_outer) + ((8 * j_inner_outer) + ((32 * j_outer) + j_inner_inner))))] = (C[((500 * i_inner) + ((16000 * i_outer) + ((8 * j_inner_outer) + ((32 * j_outer) + j_inner_inner))))] + (A[((200 * i_inner) + ((6400 * i_outer) + ((8 * k_outer) + k_inner)))] * B[((8 * j_inner_outer) + ((32 * j_outer) + ((500 * k_inner) + ((4000 * k_outer) + j_inner_inner))))])); - }; + }; + for (int32_t j_outer = 15; j_outer < 16; j_outer += 1) { + for (int32_t k_outer = 0; k_outer < 25; k_outer += 1) { + for (int32_t k_inner = 0; k_inner < 8; k_inner += 1) { + for (int32_t i_inner = 0; i_inner < (100 + (-32 * i_outer)); i_inner += 1) { + for (int32_t j_inner_outer = 0; j_inner_outer < (63 + (-4 * j_outer)); j_inner_outer += 1) { + for (int32_t j_inner_inner = 0; j_inner_inner < min(8, (500 + ((-8 * j_inner_outer) + (-32 * j_outer)))); j_inner_inner += 1) { + C[((500 * i_inner) + ((16000 * i_outer) + ((8 * j_inner_outer) + ((32 * j_outer) + j_inner_inner))))] = (C[((500 * i_inner) + ((16000 * i_outer) + ((8 * j_inner_outer) + ((32 * j_outer) + j_inner_inner))))] + (A[((200 * i_inner) + ((6400 * i_outer) + ((8 * k_outer) + k_inner)))] * B[((8 * j_inner_outer) + ((32 * j_outer) + ((500 * k_inner) + ((4000 * k_outer) + j_inner_inner))))])); }; }; }; @@ -129,6 +128,7 @@ void matmul(const struct cinn_buffer_t *_A, const struct cinn_buffer_t *_B, stru } )ROC"; + std::cout << "\n" << out << std::endl; EXPECT_EQ(utils::Trim(target_out), utils::Trim(out)); } diff --git a/tests/test02_matmul_main.cc b/tests/test02_matmul_main.cc index 6cf0e83f0226a..5a77492962bcf 100644 --- a/tests/test02_matmul_main.cc +++ b/tests/test02_matmul_main.cc @@ -295,7 +295,6 @@ TEST(matmul, ArrayPacking) { auto func = Optimize(funcs.front()); module.Append(ir::LoweredFunc(func.As())); - // module.Append(funcs.front()); CodeGenCX86 compiler(target, CodeGenCX86::Feature::AVX256); Outputs outputs;