Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions include/tvm/arith/analyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ class ConstIntBoundAnalyzer {
* \param allow_override whether we allow override of existing information.
*/
TVM_DLL void Update(const Var& var, const ConstIntBound& info, bool allow_override = false);

/*!
* \brief Bind variable to a range.
*
Expand All @@ -163,6 +164,13 @@ class ConstIntBoundAnalyzer {
*/
TVM_DLL void Bind(const Var& var, const Range& range, bool allow_override = false);

/*!
* \brief Check if a variable is bound to a range.
* \param var The variable.
* \return Whether the variable is bound to a range.
*/
TVM_DLL bool IsBound(const Var& var) const;

private:
friend class Analyzer;
friend class ConstraintContext;
Expand Down
43 changes: 31 additions & 12 deletions python/tvm/arith/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
from typing import Union

import tvm.ffi
from tvm import tir, ir
from tvm import ir, tir
from tvm.arith import IntSet
from tvm.runtime import Object

from . import _ffi_api
Expand Down Expand Up @@ -109,6 +110,7 @@ def __init__(self):
_mod = _ffi_api.CreateAnalyzer()
self._const_int_bound = _mod("const_int_bound")
self._const_int_bound_update = _mod("const_int_bound_update")
self._const_int_bound_is_bound = _mod("const_int_bound_is_bound")
self._bind = _mod("bind")
self._modular_set = _mod("modular_set")
self._simplify = _mod("Simplify")
Expand All @@ -123,7 +125,7 @@ def __init__(self):
self._get_enabled_extensions = _mod("get_enabled_extensions")
self._set_enabled_extensions = _mod("set_enabled_extensions")

def const_int_bound(self, expr):
def const_int_bound(self, expr: tir.PrimExpr) -> ConstIntBound:
"""Find constant integer bound for expr.

Parameters
Expand All @@ -138,7 +140,22 @@ def const_int_bound(self, expr):
"""
return self._const_int_bound(expr)

def modular_set(self, expr):
def const_int_bound_is_bound(self, var: tir.Var) -> bool:
"""Check if a variable is bound to a range.

Parameters
----------
var : tvm.tir.Var
The variable.

Returns
-------
result : bool
Whether the variable is bound to a range.
"""
return self._const_int_bound_is_bound(var)

def modular_set(self, expr: tir.PrimExpr) -> ModularSet:
"""Find a modular set that expr belongs to.

Parameters
Expand All @@ -153,7 +170,7 @@ def modular_set(self, expr):
"""
return self._modular_set(expr)

def simplify(self, expr, steps=2):
def simplify(self, expr: tir.PrimExpr, steps: int = 2) -> tir.PrimExpr:
"""Simplify expression via both rewrite and canonicalization.

Parameters
Expand All @@ -173,7 +190,7 @@ def simplify(self, expr, steps=2):
"""
return self._simplify(expr, steps)

def rewrite_simplify(self, expr):
def rewrite_simplify(self, expr: tir.PrimExpr) -> tir.PrimExpr:
"""Simplify expression via rewriting rules.

Parameters
Expand All @@ -195,7 +212,7 @@ def rewrite_simplify_stats(self):
def reset_rewrite_simplify_stats(self):
self._reset_rewrite_simplify_stats()

def canonical_simplify(self, expr):
def canonical_simplify(self, expr: tir.PrimExpr) -> tir.PrimExpr:
"""Simplify expression via canonicalization.

Parameters
Expand All @@ -210,7 +227,7 @@ def canonical_simplify(self, expr):
"""
return self._canonical_simplify(expr)

def int_set(self, expr, dom_map):
def int_set(self, expr: tir.PrimExpr, dom_map: dict[tir.Var, IntSet]) -> IntSet:
"""Compute a symbolic IntSet that covers expr for all values in dom_map.

Parameters
Expand All @@ -228,7 +245,9 @@ def int_set(self, expr, dom_map):
"""
return self._int_set(expr, dom_map)

def can_prove(self, expr, strength=ProofStrength.DEFAULT):
def can_prove(
self, expr: tir.PrimExpr, strength: ProofStrength = ProofStrength.DEFAULT
) -> bool:
"""Check whether we can prove expr to be true.

Parameters
Expand All @@ -246,7 +265,7 @@ def can_prove(self, expr, strength=ProofStrength.DEFAULT):
"""
return self._can_prove(expr, strength)

def bind(self, var: tir.Var, expr: Union[tir.PrimExpr, ir.Range]):
def bind(self, var: tir.Var, expr: Union[tir.PrimExpr, ir.Range]) -> None:
"""Bind a variable to the expression.

Parameters
Expand All @@ -259,7 +278,7 @@ def bind(self, var: tir.Var, expr: Union[tir.PrimExpr, ir.Range]):
"""
return self._bind(var, expr)

def constraint_scope(self, constraint):
def constraint_scope(self, constraint: tir.PrimExpr) -> ConstraintScope:
"""Create a constraint scope.

Parameters
Expand Down Expand Up @@ -290,7 +309,7 @@ def _fenter():

return ConstraintScope(_fenter)

def update(self, var, info, override=False):
def update(self, var: tir.Var, info: ConstIntBound, override: bool = False) -> None:
"""Update infomation about var

Parameters
Expand All @@ -309,7 +328,7 @@ def update(self, var, info, override=False):
else:
raise TypeError("Do not know how to handle type {}".format(type(info)))

def can_prove_equal(self, lhs: "PrimExpr", rhs: "PrimExpr"):
def can_prove_equal(self, lhs: tir.PrimExpr, rhs: tir.PrimExpr) -> bool:
"""Whether we can prove that lhs == rhs

Parameters
Expand Down
4 changes: 4 additions & 0 deletions src/arith/analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,10 @@ TVM_FFI_REGISTER_GLOBAL("arith.CreateAnalyzer")
self->const_int_bound.Update(args[0].cast<Var>(), args[1].cast<ConstIntBound>(),
args[2].cast<bool>());
});
} else if (name == "const_int_bound_is_bound") {
return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) {
*ret = self->const_int_bound.IsBound(args[0].cast<Var>());
});
} else if (name == "Simplify") {
return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) {
if (args.size() == 1) {
Expand Down
4 changes: 4 additions & 0 deletions src/arith/const_int_bound.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ class ConstIntBoundAnalyzer::Impl
BoundInfo(PrimExpr expr, Entry bound) : expr(expr), bound(bound) {}
};

bool IsBound(const Var& var) const { return var_map_.find(var) != var_map_.end(); }

void Bind(const Var& var, const Range& range, bool allow_override) {
Entry a = VisitExpr(range->min);
Entry b = VisitExpr(range->extent);
Expand Down Expand Up @@ -793,6 +795,8 @@ void ConstIntBoundAnalyzer::Bind(const Var& var, const Range& range, bool allow_
impl_->Bind(var, range, allow_override);
}

bool ConstIntBoundAnalyzer::IsBound(const Var& var) const { return impl_->IsBound(var); }

std::function<void()> ConstIntBoundAnalyzer::EnterConstraint(const PrimExpr& constraint) {
return impl_->EnterConstraint(constraint);
}
Expand Down
1 change: 1 addition & 0 deletions tests/python/arith/test_arith_const_int_bound.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def test_const_bounds(self, test_case):

for var, bounds in test_case.known_bounds.items():
analyzer.update(var, ConstIntBound(*bounds))
assert analyzer.const_int_bound_is_bound(var)

with contextlib.ExitStack() as stack:
if test_case.constraint is not None:
Expand Down
Loading