Skip to content

Commit

Permalink
Distinguish between asserts on values and asserts for unpacking
Browse files Browse the repository at this point in the history
Asserts that are required for correct unpacking (e.g. asserting the
number of buffer dimensions) must go in the `init_nest_`.  Asserts on
the unpacked values themselves must go in the `asserts_`, as they may
depend on symbolic variables defined by not-yet-unpacked buffers.
  • Loading branch information
Lunderberg committed Feb 23, 2024
1 parent d102983 commit a7f822a
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 36 deletions.
27 changes: 8 additions & 19 deletions src/tir/transforms/arg_binder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,10 @@ bool ArgBinder::Bind_(const PrimExpr& arg, const PrimExpr& value, const std::str
}
return true;
} else {
BinderAddAssert(&analyzer_, it->second == value, arg_name, &init_nest_);
BinderAddAssert(&analyzer_, it->second == value, arg_name, &asserts_);
}
} else {
BinderAddAssert(&analyzer_, arg == value, arg_name, &init_nest_);
BinderAddAssert(&analyzer_, arg == value, arg_name, &asserts_);
}
return false;
}
Expand Down Expand Up @@ -113,7 +113,7 @@ void ArgBinder::BindBuffer(const Buffer& arg, const Buffer& value, const std::st
PrimExpr factor = make_const(offset.dtype(), arg->offset_factor);
PrimExpr zero = make_zero(offset.dtype());
BinderAddAssert(&analyzer_, truncmod(offset, factor) == zero, arg_name + ".elem_offset",
&init_nest_);
&asserts_);
}
}
}
Expand Down Expand Up @@ -191,8 +191,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type,
if (!(buffer->dtype == DataType::Int(1) || buffer->dtype == DataType::Int(4) ||
buffer->dtype == DataType::UInt(4))) {
auto type_msg = tvm::tir::StringImm(type_err_msg.str());
init_nest_.emplace_back(AssertStmt(a_ndim == v_ndim, msg, nop));
init_nest_.emplace_back(AssertStmt(cond, type_msg, nop));
asserts_.emplace_back(AssertStmt(cond, type_msg, nop));
}

// shape field
Expand Down Expand Up @@ -239,8 +238,8 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type,
foldl([](PrimExpr a, PrimExpr b, Span span) { return logical_and(a, b, span); },
const_true(1), conds),
stride_msg, Evaluate(0));
check = IfThenElse(Not(v_strides_is_null), check, Stmt());
init_nest_.emplace_back(SeqStmt({check, Evaluate(0)}));
check = IfThenElse(Not(v_strides_is_null), check);
asserts_.emplace_back(SeqStmt({check, Evaluate(0)}));
}
} else if (buffer->buffer_type == kAutoBroadcast) {
DataType stype = buffer->DefaultIndexType();
Expand Down Expand Up @@ -287,7 +286,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type,
PrimExpr factor = make_const(offset.dtype(), buffer->offset_factor);
PrimExpr zero = make_zero(offset.dtype());
BinderAddAssert(&analyzer_, truncmod(offset, factor) == zero, arg_name + ".elem_offset",
&init_nest_);
&asserts_);
}
}
}
Expand All @@ -314,7 +313,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type,
}
return product;
}();
init_nest_.emplace_back(AssertStmt(
asserts_.emplace_back(AssertStmt(
alloc_size == 0 || !Call(DataType::Bool(), builtin::isnullptr(), {vptr}),
tvm::tir::StringImm(arg_name + " is expected to have non-NULL data pointer"), nop));

Expand All @@ -325,15 +324,5 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type,
}
}

std::vector<Stmt> ArgBinder::asserts() const {
std::vector<Stmt> asserts;
for (const auto& stmt : init_nest_) {
if (stmt->IsInstance<AssertStmtNode>()) {
asserts.push_back(stmt);
}
}
return asserts;
}

} // namespace tir
} // namespace tvm
33 changes: 16 additions & 17 deletions src/tir/transforms/arg_binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,29 +107,26 @@ class ArgBinder {

/*! \return The asserts generated in binding
*
* This function is maintained for backwards compatibility. The
* `init_nest()` function should be used instead.
*
* In earlier implementations, the variable bindings and asserts
* produced by `ArgBinder` were kept as separate lists. However,
* this causes issues when generating the full initialization
* sequence. If all asserts are applied before any variable
* bindings, then the asserts would depend on not-yet-defined
* variables. If all variable bindings are applied before any
* asserts, then the variable defitions may illegally access
* out-of-bounds memory.
*
* In general, the initialization sequence may need to contain both
* variable bindings and assert statements, in mixed order.
*
* This contains statements that assert the correct value has been
* bound. For example, `binder.Bind(var, expr_1)` will produce an
* entry mapping `var` to `expr_1` in the `binder.defs()`. If
* `binder.Bind(var, expr_2)` is called later, then this will
* produce an assert statemtn that `expr_1 == expr_2`.
*
* Note: Some assert statements produced by BindDLTensor are located
* in `binder.init_nest()`, not within `binder.asserts()`. This is
* deliberate, as some values may require checks prior to
* initialization. (e.g. Intializing `m = dl_tensor->shape[3]`
* requires first asserting that `3 < dl_tensor->ndim`.)
*/
std::vector<Stmt> asserts() const;
const std::vector<Stmt>& asserts() const { return asserts_; }

/*!
* \brief Initialization nest generated
*
* This contains both variable bindings and assert statements.
* This contains both variable bindings and any assert statements
* that are required in order to safely produce those variable
* bindings.
*
* \note Variable bindings may be implemented either as a `LetStmt`
* that defines the variable, or as a variable replacement. Any
Expand Down Expand Up @@ -162,6 +159,8 @@ class ArgBinder {
std::vector<Stmt> init_nest_;
/*! \brief handle data type in the defintiions */
Map<Var, PrimExpr> def_handle_dtype_;
/*! \brief asserts generated */
std::vector<Stmt> asserts_;
/*! \brief internal analyzer. */
arith::Analyzer analyzer_;
};
Expand Down

0 comments on commit a7f822a

Please sign in to comment.