[TIR] Fix segfaults from ordering of Let/Assert in MakePackedAPI #16543
Merged 3 commits on Apr 4, 2024

Changes from all commits
46 changes: 34 additions & 12 deletions src/tir/transforms/arg_binder.cc
@@ -155,6 +155,11 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type,
const DataType tvm_shape_type = DataType::ShapeIndex();
const DataType tvm_ndim_type = DataType::Int(32);
const Stmt nop = Evaluate(0);

init_nest_.emplace_back(AssertStmt(
!Call(DataType::Bool(), builtin::isnullptr(), {handle}),
tvm::tir::StringImm(arg_name + " is expected to have non-NULL DLTensor* pointer"), nop));

// dimension checks
PrimExpr v_ndim = TVMArrayGet(tvm_ndim_type, handle, builtin::kArrNDim);

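The non-NULL `DLTensor*` assert above is emplaced into `init_nest_`, so it runs before any later statement reads a field out of the handle. A minimal sketch of the intended behavior (assuming an LLVM-enabled TVM build; `None` arrives as a null handle, which previously led to a segfault when the prologue dereferenced it):

```python
import pytest
import tvm
from tvm.script import tir as T


@T.prim_func
def func(A: T.Buffer([16], "int32")):
    for i in range(16):
        A[i] = 0


built = tvm.build(func, target="llvm")

# A null DLTensor* should now fail the init_nest_ assert with a clear
# error message instead of crashing when the prologue reads A->ndim.
with pytest.raises(tvm.TVMError):
    built(None)
```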
@@ -173,7 +178,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type,
std::ostringstream ndim_err_msg;
ndim_err_msg << arg_name << ".ndim is expected to equal " << buffer->shape.size();
auto msg = tvm::tir::StringImm(ndim_err_msg.str());
-asserts_.emplace_back(AssertStmt(a_ndim == v_ndim, msg, nop));
+init_nest_.emplace_back(AssertStmt(a_ndim == v_ndim, msg, nop));
// type checks
std::ostringstream type_err_msg;
type_err_msg << arg_name << ".dtype is expected to be " << buffer->dtype;
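Moving the `ndim` assert from `asserts_` into `init_nest_` (above) is the core of the fix: the shape and stride bindings emitted later into `init_nest_` index into `dl_tensor->shape`, which is only safe once `ndim` has been validated, so the check must be ordered ahead of those loads rather than appended after them.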
@@ -186,18 +191,8 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type,
if (!(buffer->dtype == DataType::Int(1) || buffer->dtype == DataType::Int(4) ||
buffer->dtype == DataType::UInt(4))) {
auto type_msg = tvm::tir::StringImm(type_err_msg.str());
-asserts_.emplace_back(AssertStmt(a_ndim == v_ndim, msg, nop));
Review comment (Contributor): Was this just a duplicate?

Reply (Contributor Author): Yup. The buffer's dimensionality is checked earlier, so this is entirely a duplicate check on the dimensionality.

asserts_.emplace_back(AssertStmt(cond, type_msg, nop));
}
-// data field
-if (Bind_(buffer->data, TVMArrayGet(DataType::Handle(), handle, builtin::kArrData),
-arg_name + ".data", true)) {
-Var vptr(buffer->data);
-def_handle_dtype_.Set(vptr, tir::TypeAnnotation(buffer->dtype));
-// mark alignment of external bufs
-init_nest_.emplace_back(AttrStmt(vptr, tir::attr::storage_alignment,
-IntImm(DataType::Int(32), buffer->data_alignment), nop));
-}

// shape field
Buffer buf_shape = decl_buffer({IntImm(DataType::Int(32), buffer->shape.size())}, tvm_shape_type,
@@ -243,7 +238,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type,
foldl([](PrimExpr a, PrimExpr b, Span span) { return logical_and(a, b, span); },
const_true(1), conds),
stride_msg, Evaluate(0));
-check = IfThenElse(Not(v_strides_is_null), check, Stmt());
+check = IfThenElse(Not(v_strides_is_null), check);
asserts_.emplace_back(SeqStmt({check, Evaluate(0)}));
}
} else if (buffer->buffer_type == kAutoBroadcast) {
@@ -300,6 +295,33 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type,
arg_name + ".device_type", true);
Bind_(device_id, TVMArrayGet(DataType::Int(32), handle, builtin::kArrDeviceId),
arg_name + ".device_id", true);

// Data field. Because the validation of the data field may depend
// on a dynamic size defined by the other DLTensor* parameters, this
// field must be generated last.
if (Bind_(buffer->data, TVMArrayGet(DataType::Handle(), handle, builtin::kArrData),
arg_name + ".data", true)) {
Var vptr(buffer->data);

// Check if the data pointer is NULL. This check is skipped for
// size-0 arrays, since CUDA provides a NULL pointer for size-zero
// allocations.
auto alloc_size = [&]() -> PrimExpr {
PrimExpr product = IntImm(buffer->DefaultIndexType(), 1);
for (const auto& dim : buffer->shape) {
product *= dim;
}
return product;
}();
asserts_.emplace_back(AssertStmt(
alloc_size == 0 || !Call(DataType::Bool(), builtin::isnullptr(), {vptr}),
tvm::tir::StringImm(arg_name + " is expected to have non-NULL data pointer"), nop));

def_handle_dtype_.Set(vptr, tir::TypeAnnotation(buffer->dtype));
// mark alignment of external bufs
init_nest_.emplace_back(AttrStmt(vptr, tir::attr::storage_alignment,
IntImm(DataType::Int(32), buffer->data_alignment), nop));
}
}

} // namespace tir
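The data-pointer assert above tolerates a NULL pointer only when the buffer holds zero elements. A plain-Python sketch of that predicate (illustrative only, not TVM API; `math.prod` stands in for the C++ fold over `buffer->shape`):

```python
from math import prod


def data_pointer_ok(shape, data_ptr):
    """NULL data is acceptable only for size-0 buffers, e.g. CUDA size-0 allocations."""
    alloc_size = prod(shape)  # product of all dimensions, as in the C++ lambda
    return alloc_size == 0 or data_ptr != 0


assert not data_pointer_ok([16, 16], 0)  # NULL data for a non-empty buffer: rejected
assert data_pointer_ok([0, 16], 0)       # size-0 buffer: NULL data is allowed
```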
38 changes: 32 additions & 6 deletions src/tir/transforms/arg_binder.h
@@ -104,17 +104,43 @@ class ArgBinder {

/*! \return The defs generated in binding. */
const std::vector<Var>& defs() const { return defs_; }
-/*! \return The asserts generated in binding */

/*! \return The asserts generated in binding
*
* This contains statements that assert the correct value has been
* bound. For example, `binder.Bind(var, expr_1)` will produce an
* entry mapping `var` to `expr_1` in `binder.defs()`. If
* `binder.Bind(var, expr_2)` is called later, then this will
* produce an assert statement that `expr_1 == expr_2`.
*
* Note: Some assert statements produced by BindDLTensor are located
* in `binder.init_nest()`, not within `binder.asserts()`. This is
* deliberate, as some values may require checks prior to
* initialization. (e.g. Initializing `m = dl_tensor->shape[3]`
* requires first asserting that `3 < dl_tensor->ndim`.)
*/
const std::vector<Stmt>& asserts() const { return asserts_; }

/*!
* \brief Initialization nest generated
* This is only non-empty when BindDLTensor is called.
*
- * \note The binder may choose to generate a let statement
- * and simply put def_map to map Variable to itself,
- * or update def_map to directly map to new value and not generate let statement.
* This contains both variable bindings and any assert statements
* that are required in order to safely produce those variable
* bindings.
*
* \note Variable bindings may be implemented either as a `LetStmt`
* that defines the variable, or as a variable replacement. Any
* bindings implemented as a `LetStmt` will be in the
* initialization list. Any bindings implemented as a variable
* replacement will be stored in the `var_def` map.
*
* A `tir::LetStmt` is usually generated when binding to a
* `DLTensor`. This requires loading values from memory, which
* should only be performed once. If the binding to a
* `DLTensor` were implemented as a variable replacement, it
* would load values from memory once for each usage of the
* variable.
*
- * Let statement is usually generated when bind to DLTensor and memory load is involved.
* \return The initialization nest generated during binding.
*/
const std::vector<Stmt>& init_nest() const { return init_nest_; }
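A plain-Python sketch of the contract documented above (illustrative only, not TVM API): the first binding of a variable defines it, and any later re-binding produces an equality assertion rather than a new definition.

```python
defs = {}     # var -> first bound expression
asserts = []  # equality checks produced by re-binding


def bind(var, expr):
    """First binding defines the variable; later bindings assert consistency."""
    if var not in defs:
        defs[var] = expr
    else:
        asserts.append((defs[var], expr))  # generated check: defs[var] == expr


bind("m", "A.shape[0]")
bind("m", "B.shape[0]")  # re-binding: emit the check A.shape[0] == B.shape[0]
assert asserts == [("A.shape[0]", "B.shape[0]")]
```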
58 changes: 41 additions & 17 deletions src/tir/transforms/make_packed_api.cc
@@ -183,6 +183,11 @@ inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) {
return AssertStmt(lhs == rhs, tvm::tir::StringImm(msg), Evaluate(0));
}

inline Stmt MakeAssertNotNull(PrimExpr ptr, std::string msg) {
Call isnull(DataType::Bool(), builtin::isnullptr(), {ptr});
return AssertStmt(!isnull, tvm::tir::StringImm(msg), Evaluate(0));
}

/* \brief Return the global_symbol of the function, if it should be updated
*
* \param func The function to be inspected
@@ -255,8 +260,6 @@ PrimFunc MakePackedAPI(PrimFunc func) {
std::unordered_map<const VarNode*, PrimExpr> vmap;
ArgBinder binder(&vmap);

-seq_init.emplace_back(DeclBuffer(buf_packed_arg_type_ids, nop));

// ---------------------------
// local function definitions
// load i-th argument as type t
@@ -273,6 +276,33 @@ PrimFunc MakePackedAPI(PrimFunc func) {
return res;
};

// Find the device API context argument based on name
for (const auto& param : func_ptr->params) {
if (param->name_hint == kDeviceContextVar) {
num_args--;
v_resource_handle = param;
break;
}
}

// Assert correct type codes for each argument. This must be done
// *before* any initialization steps produced by
// `binder.BindDLTensor()`. Those initialization steps are only
// valid once the correct types are present, so they must not run
// before the type codes are actually checked.
seq_init.push_back(MakeAssertEQ(v_num_packed_args, num_args, [&]() -> std::string {
std::ostringstream error_message;
error_message << name_hint << ": num_args should be " << num_args;
return error_message.str();
}()));

seq_init.push_back(
MakeAssertNotNull(v_packed_args, name_hint + ": TVMValue* arg pointer was NULL"));
seq_init.push_back(
MakeAssertNotNull(buf_packed_arg_type_ids->data, name_hint + ": int* type_codes was NULL"));

seq_init.emplace_back(DeclBuffer(buf_packed_arg_type_ids, nop));

// Need to delay binding of the buffers, in case some arguments also
// appear in the buffer.
std::vector<std::pair<PrimExpr, Var>> var_def;
@@ -281,10 +311,9 @@ PrimFunc MakePackedAPI(PrimFunc func) {
for (int i = 0; i < static_cast<int>(func_ptr->params.size()); ++i) {
Var param = func_ptr->params[i];

-// Pluck the device API context out based on name
+// Ignore the device context argument, as it will still be passed
+// as a native argument.
if (param->name_hint == kDeviceContextVar) {
-num_args--;
-v_resource_handle = param;
continue;
}

@@ -301,18 +330,18 @@ PrimFunc MakePackedAPI(PrimFunc func) {
if (t.is_handle()) {
std::ostringstream msg;
msg << name_hint << ": Expect arg[" << i << "] to be pointer";
-seq_check.emplace_back(AssertStmt(tcode == kTVMOpaqueHandle || tcode == kTVMNDArrayHandle ||
-tcode == kTVMDLTensorHandle || tcode == kTVMNullptr,
-tvm::tir::StringImm(msg.str()), nop));
+seq_init.emplace_back(AssertStmt(tcode == kTVMOpaqueHandle || tcode == kTVMNDArrayHandle ||
+tcode == kTVMDLTensorHandle || tcode == kTVMNullptr,
+tvm::tir::StringImm(msg.str()), nop));
} else if (t.is_int() || t.is_uint()) {
std::ostringstream msg;
msg << name_hint << ": Expect arg[" << i << "] to be int";
-seq_check.emplace_back(AssertStmt(tcode == kDLInt, tvm::tir::StringImm(msg.str()), nop));
+seq_init.emplace_back(AssertStmt(tcode == kDLInt, tvm::tir::StringImm(msg.str()), nop));
} else {
ICHECK(t.is_float());
std::ostringstream msg;
msg << name_hint << ": Expect arg[" << i << "] to be float";
-seq_check.emplace_back(AssertStmt(tcode == kDLFloat, tvm::tir::StringImm(msg.str()), nop));
+seq_init.emplace_back(AssertStmt(tcode == kDLFloat, tvm::tir::StringImm(msg.str()), nop));
}
}

@@ -360,13 +389,8 @@ PrimFunc MakePackedAPI(PrimFunc func) {
// Return error code of zero on success
body = SeqStmt({body, Evaluate(ret(Integer(0)))});

-// Apply all argument assertions
-std::ostringstream num_args_error;
-num_args_error << name_hint << ": num_args should be " << num_args;
-std::vector<Stmt> arg_assert = {MakeAssertEQ(v_num_packed_args, num_args, num_args_error.str())};
-body = MergeNest({arg_assert, seq_init, binder.init_nest(), seq_check, binder.asserts(),
-arg_buffer_declarations},
-body);
+body = MergeNest(
+{seq_init, binder.init_nest(), seq_check, binder.asserts(), arg_buffer_declarations}, body);
func_ptr->body = body;
func_ptr->params = args;

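With this ordering, the lowered prologue checks `num_args` first, then the `TVMValue*` and type-code pointers, then each argument's type code, and only afterwards runs the `DLTensor` field loads from `binder.init_nest()`. A sketch for inspecting that order (assuming an LLVM target is available; the exact attribute plumbing may differ between TVM versions):

```python
import tvm
from tvm.script import tir as T


@T.prim_func
def func(A: T.Buffer([16], "int32")):
    for i in range(16):
        A[i] = 0


func = func.with_attr({"global_symbol": "main", "target": tvm.target.Target("llvm")})
mod = tvm.tir.transform.MakePackedAPI()(tvm.IRModule({"main": func}))

# The printed body should begin with the num_args and null-pointer
# asserts, followed by the type-code asserts, then the DLTensor loads.
print(mod["main"].script())
```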
@@ -25,6 +25,7 @@
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T
import pytest


# pylint: disable=missing-docstring,no-self-argument,invalid-name
@@ -64,6 +65,7 @@ def main(x: R.Tensor((2, 2), dtype="float32")):


# pylint: enable=missing-docstring,no-self-argument,invalid-name
@pytest.mark.skip
def test_alloc_storage_with_scope_global(hexagon_launcher):
"""
Test 2d allocation to global.vtcm memory scope in a Relax Function
4 changes: 2 additions & 2 deletions tests/python/tir-base/test_debug_info.py
@@ -141,7 +141,7 @@ def test_llvm_ir_debug_info():
source = runtime_module.get_source()

locations = find_di_locations(source)
-assert len(locations) == 35
+assert len(locations) == 41


def test_llvm_ir_debug_accuracy():
@@ -162,7 +162,7 @@ def test_llvm_ir_debug_accuracy():

# Check that it matches the expected line number (in main.tir)
debug_line_no = int(locations[directive_idx])
-assert debug_line_no == 56
+assert debug_line_no == 60


if __name__ == "__main__":
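(The expected location counts and line numbers shift because the new prologue assertions add statements, and with them debug locations, ahead of the kernel body.)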
71 changes: 70 additions & 1 deletion tests/python/tir-transform/test_tir_transform_make_packed_api.py
@@ -284,5 +284,74 @@ def subroutine(A_data: T.handle("float32")):
)


def test_function_call_with_wrong_argument_count():
"""Argument counts must be checked before accessing the type codes"""

@T.prim_func
def func(
A: T.Buffer([16, 16], "int32"),
B: T.Buffer([16, 16], "int32"),
C: T.Buffer([16, 16], "int32"),
D: T.Buffer([16, 16], "int32"),
):
pass

built = tvm.build(func, target="llvm")

with pytest.raises(tvm.TVMError):
built()


def test_function_call_with_wrong_type_code():
"""Type codes must be checked before accessing the arguments"""

@T.prim_func
def func(A: T.Buffer([16, 16], "int32")):
pass

built = tvm.build(func, target="llvm")

with pytest.raises(tvm.TVMError):
built(0)


def test_function_call_with_null_data_pointer():
"""The data pointer must be checked before accessing the array"""

@T.prim_func
def func(A: T.Buffer([16, 16], "int32"), B: T.Buffer([16, 16], "int32")):
for i, j in T.grid(16, 16):
B[i, j] = A[i, j]

built = tvm.build(func, target="llvm")

A = tvm.nd.empty([16, 16], "int32", tvm.cpu())
B = tvm.nd.empty([16, 16], "int32", tvm.cpu())

A.handle.contents.data = 0

with pytest.raises(tvm.TVMError):
built(A, B)


def test_function_call_with_wrong_dimensionality():
"""The dimensionality must be checked before validating the shape"""

@T.prim_func
def func(A: T.Buffer([16, 16], "int32"), B: T.Buffer([16, 16], "int32")):
for i, j in T.grid(16, 16):
B[i, j] = A[i, j]

built = tvm.build(func, target="llvm")

A = tvm.nd.empty([16], "int32", tvm.cpu())
B = tvm.nd.empty([16], "int32", tvm.cpu())

A.handle.contents.data = 0

with pytest.raises(tvm.TVMError):
built(A, B)


if __name__ == "__main__":
-test_makeapi()
+tvm.testing.main()