Skip to content

Commit

Permalink
[ARITH] Allow Analyzer to MarkGlobalNonNegValue
Browse files Browse the repository at this point in the history
This PR introduces an utility function MarkGlobalNonNegValue.
This function allows analyzer to mark buffer shapes in function arguments
as positive globally and opens doors for more symbolic simplification.
  • Loading branch information
tqchen committed Jul 2, 2023
1 parent e178375 commit 68d038b
Show file tree
Hide file tree
Showing 9 changed files with 140 additions and 19 deletions.
16 changes: 16 additions & 0 deletions include/tvm/arith/analyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,22 @@ class TVM_DLL Analyzer {
TransitiveComparisonAnalyzer transitive_comparisons;
/*! \brief constructor */
Analyzer();
/*!
* \brief Mark the value as non-negative value globally in analyzer.
*
* Only call this function if the non-neg condition is global and
* not context-dependent.
*
* This function does best-effort propagations to the sub-analyzers
*
* \note We expose this function because non-negative global values,
* such as symbolic buffer shapes in function arguments are really
* important to ensure the best simplification, and usually they
* can be handled in a simpler way than the generic constraints.
*
* This function may call into the Update function of the sub-analyzers.
*/
void MarkGlobalNonNegValue(const PrimExpr& value);
/*!
* \brief Notify all the sub-analyzers that var
* is created and binded to expr.
Expand Down
33 changes: 33 additions & 0 deletions src/arith/analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>

#include "const_fold.h"
#include "product_normal_form.h"

namespace tvm {
Expand Down Expand Up @@ -63,6 +64,38 @@ void Analyzer::Bind(const Var& var, const Range& range, bool allow_override) {
// skip rewrite simplify
}

void Analyzer::MarkGlobalNonNegValue(const PrimExpr& value) {
// split out the symbolic and non-symbolic part
int64_t cscale = 1;
PrimExpr symbolic = tir::make_const(value.dtype(), 1);
auto fcollect = [&](PrimExpr val) {
if (const auto* intimm = val.as<IntImmNode>()) {
cscale *= intimm->value;
} else {
symbolic = symbolic * val;
}
};
UnpackReduction<tir::MulNode>(value, fcollect);
if (cscale <= 0) return;
// override the constant int bound by marking it as non-negative
// NOTE: there might be future opportunities of more bound hint
// this is a simple step and covers all the current needs
//
// We may consider enhance the sub analyzer to directly take
// MarkPositiveVar so their bounds do not overlap
if (const auto* var_ptr = symbolic.as<VarNode>()) {
Var var = GetRef<Var>(var_ptr);
// skip non-index type, keep it to be compatible
// with any_dim that do not represent any value
if (!IsIndexType(var.dtype())) return;
bool allow_override = true;
// mark the constant bound is sufficient
// we cannot mark interval set as that will cause relaxation of the var
// during bound proof which is not our intention
this->const_int_bound.Update(var, ConstIntBound(0, ConstIntBound::kPosInf), allow_override);
}
}

void Analyzer::Bind(const Map<Var, Range>& variables, bool allow_override) {
for (const auto& iter : variables) {
this->Bind(iter.first, iter.second, allow_override);
Expand Down
38 changes: 31 additions & 7 deletions src/arith/const_int_bound.cc
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,31 @@ class ConstIntBoundAnalyzer::Impl
return Intersect(a, b);
}

/*!
* \brief Process the divisor by making assumption that divide by zero
* won't happen in a valid program.
*
* This is important for us to get a lot of symbolic shape bound right
* now that the shape n >= 0, but in cases
* when mod or divide of n occur, the intention is actually n > 0
*
* \param divisor The input divsor entry
* \return The processed entry
*/
Entry AssumeNoZeroDivisor(Entry divisor) {
ICHECK(!divisor.is_const(0)) << "Find divide by zero";
// NOTE: here we make the assumption that
// divide by zero won't happen in a valid program
// this is important for us to get a lot of symbolic shape bound right
// where most conditions know that the shape n >= 0, but in cases
// when mod or divide of n occur, the intention is actually n > 0
if (divisor.min_value == 0) {
divisor.min_value = 1;
ICHECK_GE(divisor.max_value, 1);
}
return divisor;
}

Entry VisitExpr_(const IntImmNode* op) final { return MakeBound(op->value, op->value); }

Entry VisitExpr_(const AddNode* op) final {
Expand Down Expand Up @@ -223,14 +248,14 @@ class ConstIntBoundAnalyzer::Impl

Entry VisitExpr_(const DivNode* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
ICHECK(!b.is_const(0)) << "divide by zero";
Entry b = AssumeNoZeroDivisor(VisitExpr(op->b));
return HandleDivision(a, b, op->dtype, InfAwareDiv);
}

Entry VisitExpr_(const ModNode* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
Entry b = AssumeNoZeroDivisor(VisitExpr(op->b));

if (b.min_value > 0) {
int64_t b_max_cap = InfAwareAdd(b.max_value, -1);
if (a.min_value >= 0) {
Expand All @@ -252,8 +277,7 @@ class ConstIntBoundAnalyzer::Impl

Entry VisitExpr_(const FloorDivNode* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
ICHECK(!b.is_const(0)) << "floordiv by zero";
Entry b = AssumeNoZeroDivisor(VisitExpr(op->b));
return HandleDivision(a, b, op->dtype, InfAwareFloorDiv);
}

Expand All @@ -276,7 +300,8 @@ class ConstIntBoundAnalyzer::Impl
* That is, min(0, b_min + 1) <= floormod(a, b) <= max(0, b_max - 1)
*/
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
Entry b = AssumeNoZeroDivisor(VisitExpr(op->b));

if (b.min_value > 0) {
int64_t b_max_cap = InfAwareAdd(b.max_value, -1);
if (a.min_value >= 0) {
Expand Down Expand Up @@ -457,7 +482,6 @@ class ConstIntBoundAnalyzer::Impl
// at a negative value and ends at a positive one, narrow it down to
// be closer to 0, because BinaryOpBoundary only checks end-points of
// the domain ranges.

// If the range of b contains 0, then some infinity will be involved
if (b.min_value <= 0 && 0 <= b.max_value && dt.is_int()) {
Entry b_neg = b.min_value < 0 ? MakeBound(b.min_value, -1) : Everything(dt);
Expand Down
9 changes: 9 additions & 0 deletions src/arith/ir_mutator_with_analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,15 @@ namespace arith {

using namespace tir;

void IRMutatorWithAnalyzer::MarkBufferMapShapes(const tir::PrimFunc& func) {
// Mark the all the symbolic buffer shape values in the buffer map as positive value.
for (auto kv : func->buffer_map) {
for (PrimExpr shape : kv.second->shape) {
analyzer_->MarkGlobalNonNegValue(shape);
}
}
}

Stmt IRMutatorWithAnalyzer::VisitStmt_(const ForNode* op) {
// record the loop variable as iterators
Range dom = Range::FromMinExtent(op->min, op->extent);
Expand Down
8 changes: 8 additions & 0 deletions src/arith/ir_mutator_with_analyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,14 @@ class IRMutatorWithAnalyzer : public tir::StmtExprMutator {
PrimExpr VisitExpr_(const tir::ReduceNode* op) override;

protected:
/*!
* \brief Mark the all the buffer shape values in the buffer map as positive value.
*
* \note call this function before Visit function's body to maximize
* simplification efficiency
*/
void MarkBufferMapShapes(const tir::PrimFunc& func);

/*! \brief internal analyzer field. */
Analyzer* analyzer_;
// the following two fields are useful in case we want
Expand Down
1 change: 1 addition & 0 deletions src/tir/transforms/flatten_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer {
arith::Analyzer ana;
auto pass = BufferFlattener(&ana);
auto writer = func.CopyOnWrite();
pass.MarkBufferMapShapes(func);
writer->body = pass.VisitStmt(func->body);
// The buffers in func->buffer_map are deliberately left
// unflattened, as they are used for validation of user-provided
Expand Down
21 changes: 9 additions & 12 deletions src/tir/transforms/simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -142,20 +142,24 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.Simplify", SimplifyConfig);

class StmtSimplifier : public IRMutatorWithAnalyzer {
public:
static Stmt Apply(Stmt stmt, Analyzer* analyzer, Optional<SimplifyConfig> config_opt = NullOpt) {
static PrimFunc Apply(PrimFunc func, Analyzer* analyzer,
Optional<SimplifyConfig> config_opt = NullOpt) {
auto config = config_opt.value_or(AttrsWithDefaultValues<arith::SimplifyConfig>());
analyzer->rewrite_simplify.SetEnabledExtensions(config->GetEnabledExtensions());

std::optional<ControlFlowGraph> touch_pattern = std::nullopt;
if (config->propagate_knowns_to_prove_conditional ||
config->propagate_knowns_to_simplify_expressions) {
touch_pattern = ControlFlowGraph(stmt);
touch_pattern = ControlFlowGraph(func->body);
}

std::unordered_set<const VarNode*> used_in_buffer_def = CollectVarsUsedInBufferDefinition(stmt);
std::unordered_set<const VarNode*> used_in_buffer_def =
CollectVarsUsedInBufferDefinition(func->body);
StmtSimplifier simplifier(analyzer, config, std::move(touch_pattern),
std::move(used_in_buffer_def));
return simplifier(std::move(stmt));
simplifier.MarkBufferMapShapes(func);
func.CopyOnWrite()->body = simplifier(func->body);
return func;
}

private:
Expand Down Expand Up @@ -335,21 +339,14 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
} // namespace arith

namespace tir {

Stmt Simplify(Stmt stmt, arith::Analyzer* analyzer) {
return arith::StmtSimplifier::Apply(stmt, analyzer);
}

namespace transform {

Pass Simplify() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
arith::Analyzer analyzer;
auto cfg = ctx->GetConfig<arith::SimplifyConfig>("tir.Simplify");

auto* n = f.CopyOnWrite();
n->body = arith::StmtSimplifier::Apply(std::move(n->body), &analyzer, cfg);
return f;
return arith::StmtSimplifier::Apply(f, &analyzer, cfg);
};
return CreatePrimFuncPass(pass_func, 0, "tir.Simplify", {});
}
Expand Down
17 changes: 17 additions & 0 deletions tests/python/unittest/test_arith_const_int_bound.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,23 @@ def test_floormod_negative_divisor():
assert bd.max_value == 6


def test_divmod_assume_no_zero_divsor():
# Divmod non negative expression makes assumption that divide by zero won't occur
# this assumption is important to get best result from symbolic shape programs
analyzer = tvm.arith.Analyzer()
flm, fld = tvm.te.floormod, tvm.te.floordiv
a, b = te.var("a"), te.var("b")
analyzer.update(a, tvm.arith.ConstIntBound(0, 6))
analyzer.update(b, tvm.arith.ConstIntBound(0, tvm.arith.ConstIntBound.POS_INF))
bd = analyzer.const_int_bound(fld(a, b))
assert bd.min_value == 0
assert bd.max_value == 6

bd = analyzer.const_int_bound(flm(a, b))
assert bd.min_value == 0
assert bd.max_value == 6


def test_multiple_condition():
analyzer = tvm.arith.Analyzer()
flm, fld = tvm.te.floormod, tvm.te.floordiv
Expand Down
16 changes: 16 additions & 0 deletions tests/python/unittest/test_tir_transform_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -1733,5 +1733,21 @@ def before(A_ptr: T.handle("float32"), A_stride: T.int32):
expected = before


class TestBufferShapeConstraint(BaseBeforeAfter):
"""If enabled, rewrite boolean expressions into AND of OR"""

convert_boolean_to_and_of_ors = True

def before(a: T.handle):
n = T.int64()
A = T.match_buffer(a, (n * 32,), "float32")
A[T.min(T.int64(0), n)] = T.float32(0)

def expected(a: T.handle):
n = T.int64()
A = T.match_buffer(a, (n * 32,), "float32")
A[T.int64(0)] = T.float32(0)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 68d038b

Please sign in to comment.