Skip to content

Commit

Permalink
CAS support mod inference (PaddlePaddle#107)
Browse files Browse the repository at this point in the history
* enable mod compute

* code clean and bug fix
  • Loading branch information
Superjomn authored Mar 29, 2020
1 parent d01cfc8 commit 13b63ea
Show file tree
Hide file tree
Showing 17 changed files with 268 additions and 291 deletions.
1 change: 1 addition & 0 deletions cinn/backends/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ std::string CodeGenC::Compile(const lang::Module &module, OutputKind output_kind
return ss_.str();
}
std::string CodeGenC::Compile(const ir::LoweredFunc &function) {
CHECK(function.defined());
Print(function);
os() << "\n\n";
return ss_.str();
Expand Down
3 changes: 3 additions & 0 deletions cinn/common/arithmatic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -226,13 +226,16 @@ bool IsPureMath(Expr expr) {
IrNodeTy ::Sub,
IrNodeTy ::Div,
IrNodeTy ::Mul,
IrNodeTy::Mod,
IrNodeTy ::Minus,
});

auto complex_nodes = ir::CollectIRNodes(expr, [&](const Expr* n) { return !valid_node_tys.count(n->node_type()); });
#ifdef CINN_DEBUG
for (auto& node : complex_nodes) {
VLOG(3) << "Found " << node->node_type() << " " << Expr(node);
}
#endif
return complex_nodes.empty();
}

Expand Down
179 changes: 107 additions & 72 deletions cinn/common/cas.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,71 @@ Expr Exponent(Expr v) {

namespace detail {

// Is a Divisiable to b.
// @{
bool IsDivisible(int64_t a, int64_t b) {
CHECK_NE(b, 0);
return a % b == 0;
}
bool IsDivisible(const Sum* a, int b);
bool IsDivisible(const Product* a, int b) {
for (auto& item : a->operands()) {
if (item.As<IntImm>() && IsDivisible(item.As<IntImm>()->value, b)) return true;
if (item.As<Sum>() && IsDivisible(item.As<Sum>(), b)) return true;
}
return false;
}
bool IsDivisible(const Sum* a, int b) {
for (auto& item : a->operands()) {
auto* vi = item.As<IntImm>();
auto* vp = item.As<Product>();
if (vi && IsDivisible(vi->value, b)) continue;
if (vp && IsDivisible(vp, b)) continue;
return false;
}
return true;
}
bool IsDivisible(Expr a, int b) {
auto* ai = a.As<IntImm>();
auto* as = a.As<Sum>();
auto* ap = a.As<Product>();

if (ai) return IsDivisible(ai->value, b);
if (as) return IsDivisible(as, b);
if (ap) return IsDivisible(ap, b);
return false;
}
// @}

//! Divide a by b, NOTE that a should be divisible by b.
// @{
Expr Divide(const Product* a, int b);
Expr Divide(const Sum* a, int b) {
std::vector<Expr> args;
for (auto& item : a->operands()) {
if (item.As<IntImm>())
args.push_back(make_const(item.type(), item.As<IntImm>()->value / b));
else if (item.As<Product>())
args.push_back(Divide(item.As<Product>(), b));
else
NOT_IMPLEMENTED
}
return Sum::Make(args);
}
Expr Divide(const Product* a, int b) {
auto* a_first_i = a->operand(0).As<IntImm>();
CHECK(a_first_i);
int times = a_first_i->value / b;
if (times == 1) {
return Product::Make(Rest(a->operands()));
} else {
auto args = Rest(a->operands());
args.insert(std::begin(args), make_const(a->type(), times));
return Product::Make(args);
}
}
// @}

inline int Iquot(int n, int d) { return n / d; }

inline int Irem(int n, int d) {
Expand Down Expand Up @@ -798,33 +863,59 @@ Expr CasSimplifyMutator::SimplifyMod(Expr u) {
auto a = CasSimplify(node->a(), var_intervals);
auto b = CasSimplify(node->b(), var_intervals);

auto* ai = a.As<IntImm>();
auto* bi = b.As<IntImm>();
auto* a_i = a.As<IntImm>();
auto* a_product = a.As<Product>();
auto* a_sum = a.As<Sum>();
auto* a_var = a.As<_Var_>();
auto* a_mod = a.As<Mod>();

auto* b_i = b.As<IntImm>();

// 7 % 3
if (ai && bi) {
return make_const(ai->type(), ai->value % bi->value);
if (a_i && b_i) {
return make_const(a_i->type(), a_i->value % b_i->value);
}

// x % 1 = 0
if (bi && bi->value == 1) return make_const(bi->type(), 0);
if (b_i && b_i->value == 1) return make_const(b_i->type(), 0);

// 2x % 2 = 0
if (bi) {
auto* ap = a.As<Product>();
if (ap && ap->operand(0).As<IntImm>()) {
if (ap->operand(0).As<IntImm>()->value % bi->value == 0) return make_const(ap->type(), 0);
}
if (b_i && a_product && a_product->operand(0).As<IntImm>()) {
if (a_product->operand(0).As<IntImm>()->value % b_i->value == 0) return make_const(a_product->type(), 0);
}

// 0 % x = 1, 1 % x = 1
if (a_i && (a_i->value == 0 || a_i->value == 1)) return a;

if (b_i && a_var && var_intervals.count(a_var->name)) {
auto& interval = var_intervals.at(a_var->name);
int b_abs = std::abs(b_i->value);
// x\in[1, 3] % 4 = x
if (std::abs(interval.l) < b_abs && std::abs(interval.r) < b_abs) return a;
// [3,3] % 3 = 0
if (interval.l == interval.r && interval.l % b_abs == 0) return make_const(b_i->type(), 0);
}

if (ai && (ai->value == 0 || ai->value == 1)) return a;
if (a_product && b_i) {
if (IsDivisible(a_product, b_i->value)) {
return make_const(Int(32), 0);
}
}

// (x+y) % 2 = x%2 + y%2
if (a.As<Sum>()) {
// (2x+y+z) % 2 = (y+z) % 2
if (a_sum && b_i) {
std::vector<Expr> sum_args;
for (auto& v : a->operands) {
sum_args.push_back(Mod::Make(v, b));
for (auto& v : a_sum->operands()) {
if (!IsDivisible(v, b_i->value)) {
sum_args.push_back(v);
}
}
return CasSimplify(Sum::Make(sum_args), var_intervals);
if (sum_args.size() == a_sum->operands().size()) return Mod::Make(a, b);
if (sum_args.empty()) return make_const(b_i->type(), 0);
if (sum_args.size() == 1) {
return SimplifyMod(Mod::Make(sum_args.front(), b));
}
return SimplifyMod(Mod::Make(Sum::Make(sum_args), b));
}

return Mod::Make(a, b);
Expand Down Expand Up @@ -1030,61 +1121,6 @@ bool IsExprCasCompatible(Expr expr) {
return ir::CollectIRNodes(expr, teller).empty();
}

bool IsDivisible(int64_t a, int64_t b) {
CHECK_NE(b, 0);
return a % b == 0;
}

bool IsDivisible(const Sum* a, int b);

bool IsDivisible(const Product* a, int b) {
for (auto& item : a->operands()) {
if (item.As<IntImm>() && IsDivisible(item.As<IntImm>()->value, b)) return true;
if (item.As<Sum>() && IsDivisible(item.As<Sum>(), b)) return true;
}
return false;
}

bool IsDivisible(const Sum* a, int b) {
for (auto& item : a->operands()) {
auto* vi = item.As<IntImm>();
auto* vp = item.As<Product>();
if (vi && IsDivisible(vi->value, b)) continue;
if (vp && IsDivisible(vp, b)) continue;
return false;
}
return true;
}

//! Divide a by b, NOTE that a should be divisible by b.
// @{
Expr Divide(const Product* a, int b);
Expr Divide(const Sum* a, int b) {
std::vector<Expr> args;
for (auto& item : a->operands()) {
if (item.As<IntImm>())
args.push_back(make_const(item.type(), item.As<IntImm>()->value / b));
else if (item.As<Product>())
args.push_back(Divide(item.As<Product>(), b));
else
NOT_IMPLEMENTED
}
return Sum::Make(args);
}
Expr Divide(const Product* a, int b) {
auto* a_first_i = a->operand(0).As<IntImm>();
CHECK(a_first_i);
int times = a_first_i->value / b;
if (times == 1) {
return Product::Make(Rest(a->operands()));
} else {
auto args = Rest(a->operands());
args.insert(std::begin(args), make_const(a->type(), times));
return Product::Make(args);
}
}
// @}

// Partially divide a by b. e.g. (2x+y)/2 => x + y/2
Expr DividePartially(Sum* a, int b) {
std::vector<Expr> external_sum_args, sum_args;
Expand Down Expand Up @@ -1138,7 +1174,6 @@ Expr CasSimplifyMutator::FurtherSimplifyFracWithInterval(
if (bi) {
if (av) {
auto it = var_intervals.find(av->name);
if (it != var_intervals.end()) LOG(INFO) << "found " << av->name;
if (it != var_intervals.end() && std::abs(it->second.r) < std::abs(bi->value) &&
std::abs(it->second.l) < std::abs(bi->value))
return make_const(a.type(), 0);
Expand Down
4 changes: 3 additions & 1 deletion cinn/common/cas.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ namespace common {
* Interval of a _Var_.
*/
struct CasInterval {
CasInterval(int l, int r) : l(l), r(r) {}
CasInterval(int64_t l, int64_t r) : l(l), r(r) {}
int l, r;

friend std::ostream& operator<<(std::ostream& os, const CasInterval& i) {
Expand All @@ -19,6 +19,8 @@ struct CasInterval {
}
};

using cas_intervals_t = std::unordered_map<std::string, CasInterval>;

Expr AutoSimplify(Expr u, const std::unordered_map<std::string, CasInterval>& var_intervals = {});

//! Simplify a CAS expression.
Expand Down
44 changes: 43 additions & 1 deletion cinn/common/cas_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ TEST(CAS, SimplifyMod) {
{Power::Make(x, Expr(3)), Mod::Make(x, Expr(5)), y, Expr(1), Mod::Make(Product::Make({x, Expr(4)}), Expr(5))}));

EXPECT_EQ(GetStreamCnt(u1), "0");
EXPECT_EQ(GetStreamCnt(u2), "((x % 2) + (y % 2) + (z % 2))");
EXPECT_EQ(GetStreamCnt(u2), "((x + y + z) % 2)");
EXPECT_EQ(GetStreamCnt(u3), "1");
EXPECT_EQ(GetStreamCnt(u4), "(1 + (x^3) + y)");
}
Expand Down Expand Up @@ -222,6 +222,48 @@ TEST(CAS, FracOp) {
EXPECT_EQ(GetStreamCnt(u4), "((32768 * (y/32)) + (32768 * x))");
}

#define OUTPUT_EQUAL(s__) EXPECT_EQ(GetStreamCnt(u), s__);

TEST(CAS, Mod) {
Var x = ir::_Var_::Make("x", Int(32));
Var y = ir::_Var_::Make("y", Int(32));
Var z = ir::_Var_::Make("z", Int(32));
Var k = ir::_Var_::Make("k", Int(32));

std::unordered_map<std::string, CasInterval> var_intervals0, var_intervals1;
var_intervals0.emplace("x", CasInterval{0, 3});
var_intervals0.emplace("y", CasInterval{0, 3});
var_intervals0.emplace("z", CasInterval{0, 3});
var_intervals0.emplace("k", CasInterval{0, 3});

Expr u;
u = AutoSimplify(x % 5);
EXPECT_EQ(GetStreamCnt(u), "(x % 5)");
OUTPUT_EQUAL("(x % 5)")

u = AutoSimplify((5 + x) % 5);
OUTPUT_EQUAL("(x % 5)")

u = AutoSimplify((x + 5 * y + 1 + 1 + 3 - z * 3) % 5);
OUTPUT_EQUAL("((x + (-3 * z)) % 5)")

u = AutoSimplify((x + 5) % 5, var_intervals0);
OUTPUT_EQUAL("x")

u = AutoSimplify((x + y + 5) % 5, var_intervals0);
OUTPUT_EQUAL("((x + y) % 5)")

u = AutoSimplify((x + 20 * y + 5) % 5, var_intervals0);
OUTPUT_EQUAL("x")

u = AutoSimplify((x % 32) + ((32768 * (x / 32)) + ((32768 * y) + ((32 * z) + (128 * k)))));
OUTPUT_EQUAL("((32768 * (x/32)) + ((x % 32) + ((128 * k) + ((32768 * y) + (32 * z)))))");

u = AutoSimplify((x % 32) + ((32768 * (x / 32)) + ((32768 * y) + ((32 * z) + (128 * k)))), var_intervals0);
OUTPUT_EQUAL("((128 * k) + (x + ((32768 * y) + (32 * z))))")
LOG(INFO) << u;
}

TEST(CAS, IntConnerCase) {
Var x = ir::_Var_::Make("x", Int(32));
Var y = ir::_Var_::Make("y", Int(32));
Expand Down
6 changes: 6 additions & 0 deletions cinn/ir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,12 @@ Expr Load::index() const {
return res;
}

const std::string &Load::name() const {
auto *t = tensor.As<ir::_Tensor_>();
CHECK(t);
return t->name;
}

Expr Ramp::Make(Expr base, Expr stride, int lanes) {
CHECK(base.defined());
CHECK(stride.defined());
Expand Down
2 changes: 2 additions & 0 deletions cinn/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,8 @@ struct Load : public ExprNode<Load> {
std::vector<Expr*> expr_fields() override;
std::vector<const Expr*> expr_fields() const override;

const std::string& name() const;

Type type() const override;

static const IrNodeTy _node_type_ = IrNodeTy::Load;
Expand Down
Loading

0 comments on commit 13b63ea

Please sign in to comment.