Skip to content

Commit

Permalink
[TIR] Avoid re-defining var = arg_var in ArgBinder
Browse files Browse the repository at this point in the history
Prior to this commit, `ArgBinder` would always introduce a new
variable to represent the input argument, even if the argument already
a primitive type.  This introduces trivial let bindings that are
expected to be simplified out, but which can produce dangling
`tir::Var` usage in some cases (see
apache#14951).

This commit updates `ArgBinder` to prefer using the original
`tir::Var` when possible.  That is, when a function takes `n: T.int32`
as input, the packed function should produce a binding `n: T.int32 =
T.tvm_struct_get(...)`, rather than producing a binding `arg_n =
T.tvm_struct_get(...)` followed by `n = arg_n`.
  • Loading branch information
Lunderberg committed May 25, 2023
1 parent 94f4e25 commit 11e2e41
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 33 deletions.
32 changes: 9 additions & 23 deletions src/tir/transforms/make_packed_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -259,24 +259,13 @@ PrimFunc MakePackedAPI(PrimFunc func) {
return res;
};

// Need to re-declare vars, in case some arguments also appears in the buffer.
std::vector<std::pair<Var, Var>> var_def;
// Need to delay binding of the buffers, in case some arguments also
// appear in the buffer.
std::vector<std::pair<PrimExpr, Var>> var_def;
std::vector<std::pair<Var, Buffer>> buffer_def;

for (int i = 0; i < static_cast<int>(func_ptr->params.size()); ++i) {
Var param = func_ptr->params[i];
std::string param_name = [&]() {
std::ostringstream oss;
oss << "arg";
if (param->name_hint.defined() && (!param->name_hint.empty())) {
oss << "." << param->name_hint;

} else {
oss << i;
}
return oss.str();
}();
Var v_arg = Var(param_name, param->dtype);

// Pluck the device API context out based on name
if (param->name_hint == kDeviceContextVar) {
Expand All @@ -285,19 +274,16 @@ PrimFunc MakePackedAPI(PrimFunc func) {
continue;
}

var_def.emplace_back(f_arg_value(param.dtype(), i), param);
if (func_ptr->buffer_map.count(param)) {
buffer_def.emplace_back(v_arg, func_ptr->buffer_map[param]);
} else {
var_def.emplace_back(v_arg, param);
buffer_def.emplace_back(param, func_ptr->buffer_map[param]);
}

// Value loads
seq_init.emplace_back(LetStmt(v_arg, f_arg_value(v_arg.dtype(), i), nop));
// type code checks
Var tcode(v_arg->name_hint + ".code", DataType::Int(32));
Var tcode(param->name_hint + ".code", DataType::Int(32));
seq_init.emplace_back(
LetStmt(tcode, BufferLoad(buf_packed_arg_type_ids, {IntImm(DataType::Int(32), i)}), nop));
DataType t = v_arg.dtype();
DataType t = param.dtype();
if (t.is_handle()) {
std::ostringstream msg;
msg << name_hint << ": Expect arg[" << i << "] to be pointer";
Expand Down Expand Up @@ -327,8 +313,8 @@ PrimFunc MakePackedAPI(PrimFunc func) {
// either 0 or the original stride will be correctly used. Checks here have
// to use the args that may have no let binding yet. Therefore, hoisting let
// binding for args before buffer declaration is needed.
for (const auto& kv : var_def) {
binder.Bind(kv.second, kv.first, name_hint + "." + kv.first->name_hint, true);
for (const auto& [expr, param] : var_def) {
binder.Bind(param, expr, name_hint + "." + param->name_hint, true);
}

for (const auto& kv : buffer_def) {
Expand Down
17 changes: 7 additions & 10 deletions tests/python/unittest/test_tir_transform_make_packed_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,18 +101,15 @@ def test_variable_passed_from_args():
assert func.body.condition.b == 2

# Arguments unpacking
assignment = _find_assignment(func.body, "arg.input_buffer")
assignment = _find_assignment(func.body, "input_buffer")
assert str(assignment.value) == 'T.tvm_struct_get(args, 0, 12, "handle")'

assignment = _find_assignment(func.body, "arg.not_device_context")
assert str(assignment.value) == 'T.tvm_struct_get(args, 1, 12, "handle")'

assignment = _find_assignment(func.body, "input_buffer")
assert str(assignment.value) == 'T.tvm_struct_get(arg_input_buffer, 0, 1, "handle")'
assignment = _find_assignment(assignment.body, "input_buffer")
assert str(assignment.value) == 'T.tvm_struct_get(input_buffer, 0, 1, "handle")'
unpacked_input_buffer = assignment.var

assignment = _find_assignment(func.body, "not_device_context")
assert str(assignment.value) == "arg_not_device_context"
assert str(assignment.value) == 'T.tvm_struct_get(args, 1, 12, "handle")'
unpacked_not_device_context = assignment.var

seq_stmt = _find_next(assignment, tvm.tir.SeqStmt)
Expand Down Expand Up @@ -147,11 +144,11 @@ def test_device_api_context_implicit_resource_handle():
assert func.body.condition.b == 1

# Arguments unpacking
assignment = _find_assignment(func.body, "arg.input_buffer")
assignment = _find_assignment(func.body, "input_buffer")
assert str(assignment.value) == 'T.tvm_struct_get(args, 0, 12, "handle")'

assignment = _find_assignment(func.body, "input_buffer")
assert str(assignment.value) == 'T.tvm_struct_get(arg_input_buffer, 0, 1, "handle")'
assignment = _find_assignment(assignment.body, "input_buffer")
assert str(assignment.value) == 'T.tvm_struct_get(input_buffer, 0, 1, "handle")'
unpacked_input_buffer = assignment.var

seq_stmt = _find_next(assignment, tvm.tir.SeqStmt)
Expand Down

0 comments on commit 11e2e41

Please sign in to comment.