Skip to content

Commit

Permalink
Add CanProveDivisible for symbolic calculation (#60572)
Browse files Browse the repository at this point in the history
* add CanProveDivisible for symbolic calculation

* delete extra cout for debug

* fix according to some comments
  • Loading branch information
Courtesy-Xs authored Jan 6, 2024
1 parent ed6f32d commit ee3d2fc
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 4 deletions.
121 changes: 117 additions & 4 deletions paddle/cinn/common/integer_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
// limitations under the License.

#include "paddle/cinn/common/integer_set.h"

#include "paddle/cinn/common/arithmatic.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
Expand Down Expand Up @@ -164,11 +166,115 @@ std::optional<bool> SymbolicExprAnalyzer::ProveLT(const ir::Expr& lhs,
return ProveGT(rhs, lhs);
}

// Tell whether lhs can be divisible by rhs, lhs must be a pure math expression
// and rhs must be a var
std::optional<bool> SymbolicExprAnalyzer::ProveDivisible(
const ir::Expr& lhs, const ir::Expr& rhs) const {
CHECK(rhs.is_var()) << "Rhs in ProveDivisible must be a var temporarily!\n";
CHECK(lhs.defined());
CHECK(rhs.defined());
CHECK(cinn::common::IsPureMath(lhs));

ir::Expr lhs_copy = ir::ir_utils::IRCopy(lhs);
if (cinn::common::is_zero(lhs_copy)) return true;

auto OptionalAnd = [](const std::optional<bool>& lhs,
const std::optional<bool>& rhs) -> std::optional<bool> {
if (lhs.has_value() && rhs.has_value()) {
return lhs.value() && rhs.value();
} else {
return std::nullopt;
}
};
auto OptionalOr = [](const std::optional<bool>& lhs,
const std::optional<bool>& rhs) -> std::optional<bool> {
if (lhs.has_value() && rhs.has_value()) {
return lhs.value() || rhs.value();
} else if ((!lhs.has_value()) && (!rhs.has_value())) {
return std::nullopt;
} else if (lhs.has_value() && (!rhs.has_value())) {
return lhs.value() ? std::optional<bool>(lhs.value())
: std::optional<bool>(std::nullopt);
} else {
return rhs.value() ? std::optional<bool>(rhs.value())
: std::optional<bool>(std::nullopt);
}
};

std::vector<ir::Expr> ops{};
std::optional<bool> res = std::nullopt;
ir::Expr zero(0);
ir::Expr tmp_expr;

auto is_ge = ProveGE(lhs, rhs);

switch (lhs.node_type()) {
case cinn::ir::IrNodeTy::_Var_:
return ProveEQ(lhs, rhs);
case cinn::ir::IrNodeTy::IntImm:
return false;
case cinn::ir::IrNodeTy::Sum:
res = true;
ops = lhs.As<ir::Sum>()->operands();
CHECK(!ops.empty());
std::for_each(ops.begin(), ops.end(), [&](const ir::Expr& expr) {
res = OptionalAnd(res, this->ProveDivisible(expr, rhs));
});
res = OptionalAnd(res, is_ge);
return res;
case cinn::ir::IrNodeTy::Product:
res = false;
ops = lhs.As<ir::Product>()->operands();
CHECK(!ops.empty());
std::for_each(ops.begin(), ops.end(), [&](const ir::Expr& expr) {
res = OptionalOr(res, this->ProveDivisible(expr, rhs));
if (res.has_value() && res.value()) return;
});
res = OptionalAnd(res, is_ge);
return res;
case cinn::ir::IrNodeTy::FracOp:
tmp_expr = cinn::common::AutoSimplify(lhs);
if (tmp_expr.node_type() == cinn::ir::IrNodeTy::FracOp)
return std::nullopt;
return OptionalAnd(ProveDivisible(tmp_expr, rhs), is_ge);
case cinn::ir::IrNodeTy::FloatImm:
return false;
case cinn::ir::IrNodeTy::Add:
return OptionalAnd(
OptionalAnd(ProveDivisible(lhs.As<ir::Add>()->a(), rhs),
ProveDivisible(lhs.As<ir::Add>()->b(), rhs)),
is_ge);
case cinn::ir::IrNodeTy::Sub:
return OptionalAnd(
OptionalAnd(ProveDivisible(lhs.As<ir::Sub>()->a(), rhs),
ProveDivisible(lhs.As<ir::Sub>()->b(), rhs)),
is_ge);
case cinn::ir::IrNodeTy::Div:
tmp_expr = cinn::common::AutoSimplify(lhs);
if (tmp_expr.node_type() == cinn::ir::IrNodeTy::Div) return std::nullopt;
return OptionalAnd(ProveDivisible(tmp_expr, rhs), is_ge);
case cinn::ir::IrNodeTy::Mul:
return OptionalAnd(
OptionalOr(ProveDivisible(lhs.As<ir::Mul>()->a(), rhs),
ProveDivisible(lhs.As<ir::Mul>()->b(), rhs)),
is_ge);
case cinn::ir::IrNodeTy::Mod:
return false;
case cinn::ir::IrNodeTy::Minus:
return ProveDivisible(lhs.As<ir::Minus>()->v(), rhs);
default:
LOG(FATAL) << "Not supported yet!";
break;
}
}

class BoundReplacer : public ir::IRMutator<> {
public:
explicit BoundReplacer(const cas_intervals_t& var_intervals,
bool is_lower_bound)
: var_intervals_(var_intervals), sign_(is_lower_bound) {}
: var_intervals_(var_intervals),
sign_(is_lower_bound),
var_visited_({}) {}

void operator()(ir::Expr* expr) { IRMutator::Visit(expr, expr); }

Expand All @@ -183,10 +289,16 @@ class BoundReplacer : public ir::IRMutator<> {
upper_bound =
interval.e_r.defined() ? interval.e_r : ir::Expr(interval.r);
}
if (sign_) {
*op = ir::ir_utils::IRCopy(lower_bound);
if (!var_visited_.count(var->name)) {
if (sign_) {
*op = ir::ir_utils::IRCopy(lower_bound);
var_visited_.insert({var->name, lower_bound});
} else {
*op = ir::ir_utils::IRCopy(upper_bound);
var_visited_.insert({var->name, upper_bound});
}
} else {
*op = ir::ir_utils::IRCopy(upper_bound);
*op = ir::ir_utils::IRCopy(var_visited_.at(var->name));
}
}

Expand Down Expand Up @@ -248,6 +360,7 @@ class BoundReplacer : public ir::IRMutator<> {

private:
const cas_intervals_t& var_intervals_;
std::unordered_map<std::string, ir::Expr> var_visited_;
// Determine replacing with upper or lower bound,
// True means lower bound and False means upper bound.
bool sign_;
Expand Down
2 changes: 2 additions & 0 deletions paddle/cinn/common/integer_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ class SymbolicExprAnalyzer {
std::optional<bool> ProveLE(const ir::Expr& lhs, const ir::Expr& rhs) const;
std::optional<bool> ProveGT(const ir::Expr& lhs, const ir::Expr& rhs) const;
std::optional<bool> ProveLT(const ir::Expr& lhs, const ir::Expr& rhs) const;
std::optional<bool> ProveDivisible(const ir::Expr& lhs,
const ir::Expr& rhs) const;

ir::Expr LowerBound(const ir::Expr& expr) const;
ir::Expr UpperBound(const ir::Expr& expr) const;
Expand Down
44 changes: 44 additions & 0 deletions paddle/cinn/common/integer_set_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,50 @@ TEST_F(TestSymbolicExprAnalyzer, compare) {
analyzer.Prove(e3 < e4).value());
}

TEST_F(TestSymbolicExprAnalyzer, Divisible) {
auto x = ir::Var(ir::Expr(1), ir::Expr(7), "x");
auto y = ir::Var(ir::Expr(1), ir::Expr(15), "y");
auto S = ir::Var(ir::Expr(16), ir::Expr(256), "S");

cas_intervals_t divisible_var_intervals = {
{"x", CasInterval(x->lower_bound, x->upper_bound)},
{"y", CasInterval(y->lower_bound, y->upper_bound)},
{"S", CasInterval(S->lower_bound, S->upper_bound)},
};
SymbolicExprAnalyzer divisible_analyzer{divisible_var_intervals};

// case 1
ir::Expr e1 = 4 * x + 2 * y * x;
ir::Expr e2 = x;
ir::Expr e3 = y;

EXPECT_TRUE(divisible_analyzer.ProveDivisible(e1, e2).value_or(false));
EXPECT_FALSE(divisible_analyzer.ProveDivisible(e1, e3).value_or(false));

// case 2
ir::Expr e4 = y + y * x + 4 * y - x * y;

EXPECT_TRUE(divisible_analyzer.ProveDivisible(e4, e3).value_or(false));
EXPECT_FALSE(divisible_analyzer.ProveDivisible(e4, e2).value_or(false));

// case 3
ir::Expr e5 = x / y + x + y;

EXPECT_FALSE(divisible_analyzer.ProveDivisible(e5, e3).value_or(false));
EXPECT_FALSE(divisible_analyzer.ProveDivisible(e5, e2).value_or(false));

// case 4
ir::Expr e6 = S * x / 4 + x * y;

EXPECT_FALSE(divisible_analyzer.ProveDivisible(e6, e2).value_or(false));
EXPECT_FALSE(divisible_analyzer.ProveDivisible(e6, e3).value_or(false));

ir::Expr e7 = 16 * x / 4 + x * y;

EXPECT_TRUE(divisible_analyzer.ProveDivisible(e7, e2).value_or(false));
EXPECT_FALSE(divisible_analyzer.ProveDivisible(e7, e3).value_or(false));
}

TEST(SingleIntervalIntSet, constant) {
SingleIntervalIntSet empty_set(ir::Expr(0), ir::Expr(-1));
SingleIntervalIntSet all_set(SymbolicExprLimit::negative_inf,
Expand Down

0 comments on commit ee3d2fc

Please sign in to comment.