[TE] Promote substituted variable to iter_var's dtype (apache#10571)
* [TE] Promote substituted variable to iter_var's dtype

This fixes a bug where an iteration variable and its associated loop
variable could have mismatched dtypes (see the reproducer sketch below).

* add check to iter var constructor. fix two bad uses

* problem is more complicated than I thought

* one more fix

* remove old comments
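
A sketch of the kind of schedule that triggers the mismatch (illustrative only; the shapes and names below are not from this commit):

import tvm
from tvm import te

# The axis of B carries an int64 range because the extent is an int64 var.
n = te.var("n", dtype="int64")
A = te.placeholder((n,), name="A")
B = te.compute((n,), lambda i: A[i] + 1.0, name="B")

s = te.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=64)
# Thread itervars default to int32; binding them to int64 axes used to leave
# mismatched dtypes in the lowered TIR. With this commit, substituted values
# are cast to the iteration variable's dtype instead.
s[B].bind(xo, te.thread_axis("blockIdx.x"))
s[B].bind(xi, te.thread_axis("threadIdx.x"))
print(tvm.lower(s, [A, B], simple_mode=True))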
Tristan Konolige authored Mar 12, 2022
1 parent 5dc4015 commit 4cdbf5c
Showing 12 changed files with 69 additions and 43 deletions.
4 changes: 4 additions & 0 deletions include/tvm/tir/var.h
@@ -241,6 +241,8 @@ enum IterVarType : int {
/*!
* \brief An iteration variable representing an iteration
* over a one-dimensional interval.
*
* The dtype of the extent of an IterVar's `dom` must match the dtype of its internal Var.
*/
class IterVarNode : public Object {
public:
@@ -293,6 +295,8 @@ class IterVarNode : public Object {
/*!
* \brief Iteration Variable,
* represents an iteration over an integer interval.
*
* The dtype of the extent of an IterVar's `dom` must match the dtype of its internal Var.
*/
class IterVar : public ObjectRef {
public:
4 changes: 4 additions & 0 deletions python/tvm/tir/expr.py
@@ -435,6 +435,10 @@ def __init__(self, dom, var, iter_type, thread_tag="", span=None):
name = var if var is not None else "iter"
dtype = "int32" if dom is None else dom.extent.dtype
var = Var(name, dtype=dtype, span=span) if not isinstance(var, Var) else var
if dom is not None:
assert (
var.dtype == dom.extent.dtype
), "IterVar's Var dtype must match its domain's extent's dtype"
self.__init_handle_by_constructor__(
_ffi_api.IterVar, dom, var, iter_type, thread_tag, span # type: ignore
)
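
A minimal sketch of the invariant the new assert enforces (assumes a build with this patch):

import tvm
from tvm import tir

dom = tvm.ir.Range.from_min_extent(tir.const(0, "int64"), tir.const(16, "int64"))
iv = tir.IterVar(dom, tir.Var("i", "int64"), tir.IterVar.DataPar)  # OK: both int64
# tir.IterVar(dom, tir.Var("i", "int32"), tir.IterVar.DataPar)     # now raises AssertionError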
6 changes: 2 additions & 4 deletions src/te/operation/create_primfunc.cc
@@ -96,12 +96,10 @@ BlockRealize GenerateBlockFromTensor(const te::ComputeOp& compute_op, const te::
Var new_var(iter_var->var->name_hint, iter_var->var->dtype);
var_map[iter_var->var.get()] = new_var;

IterVarNode* iter_var_node = iter_var.CopyOnWrite();
const PrimExpr& dom_min = analyzer->Simplify(iter_var->dom->min);
const PrimExpr& dom_extent = analyzer->Simplify(iter_var->dom->extent);
iter_var_node->dom = Range::FromMinExtent(dom_min, dom_extent);
iter_var_node->var = new_var;
iter_vars.push_back(iter_var);
iter_vars.push_back(IterVar(Range::FromMinExtent(dom_min, dom_extent), new_var,
iter_var->iter_type, iter_var->thread_tag, iter_var->span));
}
};
f_push_block_vars(compute_op->axis);
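
From the user's side, the rebuilt itervars flow through te.create_prim_func as before; a small usage sketch (assumes this patch):

import tvm
from tvm import te

A = te.placeholder((128,), name="A")
B = te.compute((128,), lambda i: A[i] * 2.0, name="B")
# create_prim_func now builds fresh IterVars with simplified, dtype-consistent
# domains instead of mutating the originals via CopyOnWrite.
func = te.create_prim_func([A, B])
print(func)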
2 changes: 1 addition & 1 deletion src/te/operation/hybrid_op.cc
@@ -448,7 +448,7 @@ std::vector<IterVar> GatherLoopVars(Stmt stmt) {
PostOrderVisit(stmt, [&res_](const ObjectRef& node) {
if (const ForNode* op = node.as<ForNode>()) {
Var loop_var(op->loop_var);
Range dom = Range::FromMinExtent(op->min, op->extent);
Range dom = Range::FromMinExtent(op->min, cast(loop_var.dtype(), op->extent));
res_.push_back(IterVar(dom, loop_var, ForKindToIterVarType(op->kind)));
}
});
41 changes: 27 additions & 14 deletions src/te/operation/op_utils.cc
@@ -38,6 +38,8 @@ namespace te {
using namespace arith;
using namespace tir;

DataType LargerDataType(DataType a, DataType b) { return a.bits() > b.bits() ? a : b; }

std::vector<std::vector<Stmt> > MakeLoopNest(const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
size_t begin_iter_pos, bool new_loop_var,
@@ -67,6 +69,17 @@ std::vector<std::vector<Stmt> > MakeLoopNest(const Stage& stage,

Range dom = dom_map.at(iv);

// This is a hack to ensure that the replacing expression has the same
// dtype as the expression it replaces. The mismatch happens when a
// thread/block itervar is bound to another itervar. Because the
// thread/block itervar has no way to know its correct dtype before it is
// bound, it defaults to int32, while the itervar it is bound to may have a
// different dtype. The thread/block dtype really should be promoted to the
// dtype of whatever it is bound to (in `bind`), but that would require
// in-place modification of the itervar.
// XXX: we will get integer overflow if the bound itervar's extent exceeds int32::max.
auto promote_to_bound_dtype = [&iv](PrimExpr e) { return cast(iv->var.dtype(), e); };

// initialize the offset and loop_level
Var var = bind_iv->var;

@@ -112,15 +125,15 @@ std::vector<std::vector<Stmt> > MakeLoopNest(const Stage& stage,
}
}
if (!debug_keep_trivial_loop && is_one(dom->extent)) {
nest[i + 1].emplace_back(LetStmt(var, cast(var.dtype(), dom->min), no_op));
value_map[iv] = cast(var.dtype(), dom->min);
nest[i + 1].emplace_back(LetStmt(var, promote_to_bound_dtype(dom->min), no_op));
value_map[iv] = promote_to_bound_dtype(dom->min);
} else if (is_zero(dom->min)) {
nest[i + 1].emplace_back(For(var, 0, dom->extent, kind, no_op));
value_map[iv] = var;
value_map[iv] = promote_to_bound_dtype(var);
} else {
Var idx(bind_iv->var->name_hint + ".idx", bind_iv->var.dtype());
nest[i + 1].emplace_back(For(idx, 0, dom->extent, kind, no_op));
PrimExpr new_value = dom->min + idx;
Var idx(bind_iv->var->name_hint + ".idx", iv->var.dtype());
nest[i + 1].emplace_back(For(idx, 0, promote_to_bound_dtype(dom->extent), kind, no_op));
PrimExpr new_value = promote_to_bound_dtype(dom->min + idx);
value_map[iv] = new_value;
nest[i + 1].emplace_back(LetStmt(var, new_value, no_op));
}
@@ -139,44 +152,44 @@ std::vector<std::vector<Stmt> > MakeLoopNest(const Stage& stage,
ICHECK(is_positive_const(dom->extent));
// annotate the extent of the IterVar
nest[i + 1].emplace_back(AttrStmt(bind_iv, tir::attr::virtual_thread, dom->extent, no_op));
value_map[iv] = var;
value_map[iv] = promote_to_bound_dtype(var);
} else if (bind_iv->thread_tag == "pipeline") {
// pipeline marker.
ICHECK(is_zero(dom->min));
ICHECK(is_one(dom->extent));
// annotate the extent of the IterVar
nest[i + 1].emplace_back(
AttrStmt(bind_iv, tir::attr::pipeline_exec_scope, dom->extent, no_op));
value_map[iv] = dom->min;
value_map[iv] = promote_to_bound_dtype(dom->min);
} else {
// Always restrict threaded IterVar to starts from 0.
ICHECK(is_zero(dom->min)) << "Itervar " << iv << " must start at zero, but it starts at "
<< dom->min;
// annotate the extent of the IterVar
nest[i + 1].emplace_back(AttrStmt(bind_iv, tir::attr::thread_extent, dom->extent, no_op));
if (!debug_keep_trivial_loop && is_one(dom->extent)) {
value_map[iv] = dom->min;
value_map[iv] = promote_to_bound_dtype(dom->min);
} else if (stage->scope == "") {
value_map[iv] = var;
value_map[iv] = promote_to_bound_dtype(var);
} else {
runtime::ThreadScope ts = runtime::ThreadScope::Create(bind_iv->thread_tag);
runtime::StorageScope ss = runtime::StorageScope::Create(stage->scope);
if (static_cast<int>(ss.rank) <= ts.rank) {
value_map[iv] = var;
value_map[iv] = promote_to_bound_dtype(var);
} else if (stage->scope == "warp" && ts.rank == 1) {
// To determine whether a thread index is inside or outside a warp, we need
// to know the thread extent. We leave a warning for now.
if (ts.dim_index == 0) {
value_map[iv] = var;
value_map[iv] = promote_to_bound_dtype(var);
} else {
LOG(WARNING)
<< "WARNING: threadIdx.y or threadIdx.z accessing warp-scope memory detected. "
<< "TVM assumes only threadIdx.x indicates threads inside a warp, "
<< "while threadIdx.y and threadIdx.z indicates different warps.";
value_map[iv] = dom->min;
value_map[iv] = promote_to_bound_dtype(dom->min);
}
} else {
value_map[iv] = dom->min;
value_map[iv] = promote_to_bound_dtype(dom->min);
}
}
}
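
At the expression level, promote_to_bound_dtype is just a cast to the itervar's dtype; a tiny Python-side sketch of the equivalent construction (illustrative, not this file's code):

import tvm
from tvm import tir

iv_var = tir.Var("i", "int64")        # dtype of the itervar being substituted for
tx = tir.Var("threadIdx.x", "int32")  # bound thread var, which defaults to int32
promoted = tir.Cast("int64", tx)      # mirrors cast(iv->var.dtype(), e)
assert promoted.dtype == iv_var.dtype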
8 changes: 5 additions & 3 deletions src/te/schedule/bound.cc
@@ -246,9 +246,11 @@ Map<IterVar, Range> InferBound(const Schedule& sch) {
ret[iv] = iv->dom;
}
}
for (auto& p : ret) {
ret[p.first] =
Range::FromMinExtent(analyzer.Simplify(p.second->min), analyzer.Simplify(p.second->extent));
for (auto it = ret.begin(); it != ret.end(); it++) {
it->second = Range::FromMinExtent(
analyzer.Simplify(it->second->min),
// The range associated with each itervar must have the same dtype as it
cast(it->first->var.dtype(), analyzer.Simplify(it->second->extent)));
}
return Map<IterVar, Range>(ret.begin(), ret.end());
}
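
The resulting invariant, sketched from the Python side (assumes this patch): every inferred range's extent carries its itervar's dtype.

import tvm
from tvm import te

n = te.var("n", dtype="int64")
A = te.placeholder((n,), name="A")
B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
s = te.create_schedule(B.op)
for iv, rng in tvm.te.schedule.InferBound(s).items():
    assert rng.extent.dtype == iv.var.dtype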
8 changes: 6 additions & 2 deletions src/te/schedule/message_passing.cc
@@ -148,12 +148,16 @@ void PassDownDomain(const Stage& stage, std::unordered_map<IterVar, Range>* p_st
};
if (r->factor.defined()) {
Update(p_state, r->inner,
Range::FromMinExtent(0, resolve_min_extent_for_split(r->inner, r->factor)), actx);
Range::FromMinExtent(0, cast(range_parent->extent.dtype(),
resolve_min_extent_for_split(r->inner, r->factor))),
actx);
Update(p_state, r->outer,
Range::FromMinExtent(0, ceil_div(range_parent->extent, r->factor)), actx);
} else {
Update(p_state, r->outer,
Range::FromMinExtent(0, resolve_min_extent_for_split(r->outer, r->nparts)), actx);
Range::FromMinExtent(0, cast(range_parent->extent.dtype(),
resolve_min_extent_for_split(r->outer, r->nparts))),
actx);
Update(p_state, r->inner,
Range::FromMinExtent(0, ceil_div(range_parent->extent, r->nparts)), actx);
}
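
With the cast in place, splitting an int64 axis keeps both sub-ranges in the parent's dtype; a sketch (assumes this patch):

import tvm
from tvm import te

n = te.var("n", dtype="int64")
A = te.placeholder((n,), name="A")
B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
s = te.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=4)
bounds = tvm.te.schedule.InferBound(s)
# Both the outer (ceildiv(n, 4)) and inner (4) extents are promoted to int64.
assert bounds[xo].extent.dtype == "int64" and bounds[xi].extent.dtype == "int64"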
16 changes: 6 additions & 10 deletions src/te/schedule/schedule_dataflow_rewrite.cc
@@ -789,21 +789,18 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor, const IterVar& axis, int f
n->name = compute_op->name + ".rf";
{
// axis replacement.
auto iv_node = make_object<IterVarNode>();
iv_node->dom = dom_map.at(axis);
ICHECK(is_zero(iv_node->dom->min)) << "Can only factor reduction domain starting from 0";
iv_node->var = axis->var;
iv_node->iter_type = kDataPar;
IterVar iv(dom_map.at(axis), axis->var, kDataPar);
ICHECK(is_zero(iv->dom->min)) << "Can only factor reduction domain starting from 0";

const int size = compute_op->axis.size();
for (int idx = 0; idx < size; ++idx) {
if (factor_axis_pos == idx) {
n->axis.push_back(IterVar(iv_node));
n->axis.push_back(iv);
}
n->axis.push_back(compute_op->axis[idx]);
}
if (factor_axis_pos == size) {
n->axis.push_back(IterVar(iv_node));
n->axis.push_back(iv);
}
}
// predicate generation, copy not touched axis.
@@ -832,9 +829,8 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor, const IterVar& axis, int f
for (IterVar iv : reduce_stage->leaf_iter_vars) {
if (touch_map.count(iv) && !iv.same_as(axis)) {
ICHECK_EQ(iv->iter_type, kCommReduce);
auto ncpy = make_object<IterVarNode>(*iv.operator->());
ncpy->dom = dom_map.at(iv);
n->reduce_axis.push_back(IterVar(ncpy));
IterVar ncpy(dom_map.at(iv), iv->var, iv->iter_type, iv->thread_tag, iv->span);
n->reduce_axis.push_back(ncpy);
}
}
VarReplacer replacer(vsub);
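
rfactor usage is unchanged from the user's perspective; the standard pattern still applies, with the factored itervars now built through the checked IterVar constructor:

import tvm
from tvm import te

n = te.var("n")
A = te.placeholder((n,), name="A")
k = te.reduce_axis((0, n), name="k")
B = te.compute((1,), lambda i: te.sum(A[k], axis=k), name="B")
s = te.create_schedule(B.op)
ko, ki = s[B].split(B.op.reduce_axis[0], factor=16)
BF = s.rfactor(B, ki)  # BF's reduction itervars come from the new constructor calls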
7 changes: 5 additions & 2 deletions src/te/tensor.cc
@@ -31,10 +31,13 @@ namespace tvm {
namespace te {

IterVar thread_axis(Range dom, std::string tag) {
return IterVar(dom, Var(tag), kThreadIndex, tag);
return IterVar(dom, Var(tag, dom.defined() ? dom->extent.dtype() : DataType::Int(32)),
kThreadIndex, tag);
}

IterVar reduce_axis(Range dom, std::string name) { return IterVar(dom, Var(name), kCommReduce); }
IterVar reduce_axis(Range dom, std::string name) {
return IterVar(dom, Var(name, dom->extent.dtype()), kCommReduce);
}

Var var(std::string name_hint, DataType t) { return Var(name_hint, t); }

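
The Python-side te.reduce_axis goes through the same checked constructor, so the Var now inherits the domain extent's dtype; a sketch (assumes this patch):

import tvm
from tvm import te

dom = tvm.ir.Range.from_min_extent(tvm.tir.const(0, "int64"), tvm.tir.const(16, "int64"))
k = te.reduce_axis(dom, name="k")
assert k.var.dtype == "int64"  # previously the Var defaulted to int32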
5 changes: 5 additions & 0 deletions src/tir/ir/expr.cc
@@ -146,6 +146,11 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
// IterVar
IterVar::IterVar(Range dom, Var var, IterVarType t, String thread_tag, Span span) {
ObjectPtr<IterVarNode> n = make_object<IterVarNode>();
if (dom.defined() && dom->extent.defined()) {
CHECK_EQ(dom->extent.dtype(), var.dtype())
<< "The dtype of the extent of an IterVar (" << dom->extent.dtype()
<< ") must match its associated Var's dtype (" << var.dtype() << ")";
}
n->dom = dom;
n->var = var;
n->iter_type = t;
5 changes: 2 additions & 3 deletions src/tir/schedule/primitive/blockize_tensorize.cc
@@ -322,9 +322,8 @@ class BlockizedBindingExtractor {
outer_iter_vars.push_back(outer_var);
PrimExpr base = is_one(division[i][0]->extent) ? 0 : outer_var * division[i][1]->extent;
// create iter var for the inner block
IterVar new_iter = iter_var;
auto* new_iter_node = new_iter.CopyOnWrite();
new_iter_node->dom = Range::FromMinExtent(0, division[i][1]->extent);
IterVar new_iter(Range::FromMinExtent(0, division[i][1]->extent), Var(iter_var->var),
iter_var->iter_type, iter_var->thread_tag, iter_var->span);
inner_iter_dom_map.Set(new_iter->var, arith::IntSet::FromRange(new_iter->dom));
analyzer->Bind(new_iter->var, new_iter->dom);
inner_iter_vars.push_back(new_iter);
6 changes: 2 additions & 4 deletions src/tir/transforms/unify_thread_binding.cc
@@ -102,10 +102,8 @@ class ThreadBindingUnifier : public StmtExprMutator {
<< "` should have the same extent. However, there are two loops with extent "
<< new_iter_var->dom->extent << " and " << dom->extent << ", which are not equal";
} else {
ObjectPtr<IterVarNode> p_new_iter_var = make_object<IterVarNode>(*old_iter_var.get());
p_new_iter_var->var = Var(thread_tag);
p_new_iter_var->dom = dom;
new_iter_var = IterVar(p_new_iter_var);
new_iter_var = IterVar(dom, Var(thread_tag, dom->extent.dtype()), old_iter_var->iter_type,
old_iter_var->thread_tag);
thread_tag2iter_var_map_.Set(thread_tag, new_iter_var);
launch_threads_.push_back(new_iter_var);
}
