Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ARITH] DeduceBound #40

Merged
merged 27 commits into from
Feb 17, 2017
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion HalideIR
Submodule HalideIR updated from 642ae5 to e68ae6
1 change: 1 addition & 0 deletions python/tvm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from ._ctypes._node import register_node

from . import tensor
from . import arith
from . import expr
from . import stmt
from . import make
Expand Down
1 change: 1 addition & 0 deletions python/tvm/_ctypes/_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ def _init_api_functions(root_namespace):
module_internal = sys.modules["%s._api_internal" % root_namespace]
namespace_match = {
"_make_": sys.modules["%s.make" % root_namespace],
"_arith_": sys.modules["%s.arith" % root_namespace],
"_pass_": sys.modules["%s.ir_pass" % root_namespace],
"_codegen_": sys.modules["%s.codegen" % root_namespace],
"_schedule_": sys.modules["%s.schedule" % root_namespace]
Expand Down
22 changes: 22 additions & 0 deletions python/tvm/arith.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# pylint: disable=protected-access, no-member
"""Arithmetic data structure and utility"""
from __future__ import absolute_import as _abs
from ._ctypes._node import NodeBase, register_node
from . import _api_internal

@register_node
class IntSet(NodeBase):
pass

@register_node
class IntervalSet(IntSet):
def min(self):
return _api_internal._IntervalSetGetMin(self)

def max(self):
return _api_internal._IntervalSetGetMax(self)

@register_node
class StrideSet(IntSet):
pass

43 changes: 43 additions & 0 deletions src/api/api_arith.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*!
* Copyright (c) 2016 by Contributors
* Implementation of API functions related to arith
* \file api_arith.cc
*/
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/api_registry.h>
#include "../arithmetic/int_set.h"
#include "../arithmetic/int_set_internal.h"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove this include


namespace tvm {
namespace arith {

TVM_REGISTER_API(_arith_intset_single_point)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = IntSet::single_point(args[0]);
});

TVM_REGISTER_API(_arith_intset_range)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = IntSet::range(args[0], args[1]);
});

TVM_REGISTER_API(_arith_DeduceBound)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = DeduceBound(args[0], args[1], args[2]);
});

TVM_REGISTER_API(_IntervalSetGetMin)
.set_body([](TVMArgs args, TVMRetValue *ret) {
IntSet s = args[0].operator IntSet();
*ret = s.as<IntervalSet>()->i.min;
});

TVM_REGISTER_API(_IntervalSetGetMax)
.set_body([](TVMArgs args, TVMRetValue *ret) {
IntSet s = args[0].operator IntSet();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

consider do it in one line

* ret = args[0].operator IntSet()
           .max()

*ret = s.as<IntervalSet>()->i.max;
Copy link
Member

@tqchen tqchen Feb 17, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe consider add a function max to integer set, instead of exposing IntervalSet

});

} // namespace arith
} // namespace tvm
203 changes: 203 additions & 0 deletions src/arithmetic/bound_deducer.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
/*!
* Copyright (c) 2017 by Contributors
* \file bound_deducer.cc
* \brief Utility to deduce bound of expression
*/
#include <tvm/expr.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/api_registry.h>
#include <unordered_set>
#include <unordered_map>
#include "./int_set.h"

namespace tvm {
namespace arith {

using namespace ir;
using Halide::Internal::Interval;

// a visitor to find the path to the target variable
// from a expression.
class VariablePathFinder: public IRVisitor {
public:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

VariablePathFinder

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to look out for errors when a variable appears in multiple locations in the expression

explicit VariablePathFinder(Var target) : target_(target) {}

void Visit(const NodeRef& node) final {
if (!success) return;
if (visited_.count(node.get()) != 0 &&
!node.same_as(target_)) {
return;
}
visited_.insert(node.get());

if (!found_) path_.push_back(node.get());
if (node.same_as(target_)) {
if (!found_) {
found_ = true;
} else {
// target variable appears at multiple location
success = false;
return;
}
}
IRVisitor::Visit(node);
if (!found_) path_.pop_back();
}

std::vector<const Node*> path_;
bool success{true};

private:
bool found_{false};
Var target_;
std::unordered_set<const Node*> visited_;
};

// get the path to the variable,
// return empty vector to represent failure
std::vector<const Node*> GetPath(Var target, Expr expr) {
VariablePathFinder v(target);
v.Visit(expr);
return v.success ? v.path_ : std::vector<const Node*>();
}

// a visitor to deduce the bound of a variable from a expression
class BoundDeducer: public IRVisitor {
public:
BoundDeducer(Var target, Expr expr,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doing everything in constructor have a problem of not being able to throw exception out, consider do it in another function, say Deduce

const std::unordered_map<const Variable*, IntSet>& dom_map)
: target_(target), expr_(expr), dom_map_(dom_map) {
// get the path
path_ = GetPath(target, expr);
if (path_.empty()) {
success = false;
return;
}
iter_ = 0;
result = make_zero(expr.type());
// get the sign of every subexpr
expr_map_ = EvalSetForEachSubExpr(expr, dom_map);

Visit(expr);
}

void Visit(const NodeRef& e) final {
if (!success) return;
if (e.get() == path_[iter_++]) {
IRVisitor::Visit(e);
} else {
success = false;
return;
}
}

void Visit_(const LT* op) final {
is_greater = false;
is_equal = false;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems to be bad practice to detect it inside visitor. What if we have something like

a > 10 && b < 8

Since only simple case is handled, consider do the detection outside the visitor

Copy link
Contributor Author

@ZihengJiang ZihengJiang Feb 17, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I add a check before deduce

result = op->b;
Visit(op->a);
}

void Visit_(const LE* op) final {
is_greater = false;
is_equal = true;
result = op->b;
Visit(op->a);
}

void Visit_(const GT* op) final {
is_greater = true;
is_equal = false;
result = op->b;
Visit(op->a);
}

void Visit_(const GE* op) final {
is_greater = true;
is_equal = true;
result = op->b;
Visit(op->a);
}

void Visit_(const Add* op) final {
bool left = op->a.get() == path_[iter_];
result -= left ? op->b : op->a;
Visit(left ? op->a : op->b);
}

void Visit_(const Sub* op) final {
bool left = op->a.get() == path_[iter_];
if (left) {
result += op->b;
} else {
result -= op->a;
result = -1 * result;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

change to -result, or 0- result, negation should be overloaded already?

is_greater = !is_greater;
}
Visit(left ? op->a : op->b);
}

void Visit_(const Mul* op) final {
bool left = op->a.get() == path_[iter_];
Expr operand = left ? op->b : op->a;

SignType sign;
if (operand.type().is_uint()) {
sign = kPositive;
} else {
sign = expr_map_[operand].sign_type();
}

if (sign == SignType::kNegative) {
is_greater = !is_greater;
} else if (sign == SignType::kUnknown) {
// unable to get the sign of operand
success = false;
return;
}

// always use relax bound
if (is_greater) {
result = result / operand + 1;
} else {
result = result / operand - 1;
}
Visit(left ? op->a : op->b);
}

Expr result;
bool is_greater{true};
bool is_equal{true};
bool success{true};

private:
Var target_;
Expr expr_;
const std::unordered_map<const Variable*, IntSet>& dom_map_;
std::vector<const Node*> path_;
size_t iter_;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

directly initialize it to 0 here

ExprIntSetMap expr_map_;
};

// assuming e >= 0, deduce the bound of variable from it.
// return empty set to represent deduce failure.
IntSet DeduceBound(Var v, Expr e,
const Map<Var, IntSet>& dom_map) {
std::unordered_map<const Variable*, IntSet> dmap;
for (auto kv : dom_map) {
dmap[kv.first.get()] = kv.second;
}
BoundDeducer d(v, e, dmap);
if (!d.success) return IntSet();
Expr min = Interval::neg_inf, max = Interval::pos_inf;
if (d.is_greater) {
min = d.is_equal ? d.result : d.result+1;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

space between add

} else {
max = d.is_equal ? d.result : d.result-1;
}
return IntSet::range(min, max);
}

} // namespace arith
} // namespace tvm
Loading