Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix][TensorIR] specialize() updates the attrs of PrimFuncs #8606

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/tvm/ir/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ class DictAttrsNode : public BaseAttrsNode {
class DictAttrs : public Attrs {
public:
/*!
* \brief Consruct a Attrs backed by DictAttrsNode.
* \brief Construct a Attrs backed by DictAttrsNode.
* \param dict The attributes.
* \return The dict attributes.
*/
Expand Down
35 changes: 28 additions & 7 deletions src/tir/ir/specialize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,25 +66,46 @@ class PrimFuncSpecializer : public StmtExprMutator {
}
}

// Updating parmeters
// Updating parameters
Array<Var> params;
bool param_updated = false;
for (const auto& var : f->params) {
// Remove parmeters which has been specialized.
// Remove parameters which have been specialized.
if (var_map.find(var) == var_map.end()) {
params.push_back(var);
} else {
param_updated = true;
}
}

// Updating the `attrs` dictionary
DictAttrs attrs{nullptr};
bool attrs_updated = false;
if (f->attrs.defined()) {
Map<String, ObjectRef> dict;
for (const std::pair<String, ObjectRef>& kv : f->attrs->dict) {
const auto* expr = kv.second.as<PrimExprNode>();
if (expr == nullptr) {
dict.Set(kv.first, kv.second);
continue;
}
PrimExpr result = Substitute(GetRef<PrimExpr>(expr), var_map);
dict.Set(kv.first, result);
if (result.get() != expr) {
attrs_updated = true;
}
}
attrs = DictAttrs(dict);
}

// Updating function body
Stmt body = specializer(f->body);

if (param_updated || buffer_map_updated || !f->body.same_as(body)) {
if (param_updated || buffer_map_updated || attrs_updated || !f->body.same_as(body)) {
PrimFuncNode* f_ptr = f.CopyOnWrite();
f_ptr->params = std::move(params);
f_ptr->buffer_map = std::move(buffer_map);
f_ptr->attrs = std::move(attrs);
f_ptr->body = std::move(body);
}
return f;
Expand Down Expand Up @@ -248,9 +269,9 @@ void UpdateSpecializeVarMap(const PrimFunc& func, const Var& param, const Buffer
// build var mapping using specific_buf's parameters
auto build_var_mapping = [&](const PrimExpr& new_expr, const PrimExpr& old_expr) {
if (!equal(new_expr, old_expr)) {
CHECK(old_expr->IsInstance<VarNode>())
<< "TypeError: The signature of target buffer exprected an independent Var, but got "
<< old_expr << ".";
CHECK(old_expr->IsInstance<VarNode>()) << "TypeError: The signature of target buffer is "
"expected to be an independent Var, but it is "
<< old_expr << ".";
const Var& var = Downcast<Var>(old_expr);
auto it = var_map->find(var);
if (it != var_map->end()) {
Expand Down Expand Up @@ -282,7 +303,7 @@ void UpdateSpecializeVarMap(const PrimFunc& func, const Var& param, const Buffer
build_var_mapping(specific_buf->elem_offset, buf_to_specialize->elem_offset);

// Check data_alignment and offset_factor.
// These two signatures are int, so we do not need map them.
// These two signatures are int, so we do not need to map them.
CHECK_EQ(specific_buf->data_alignment, buf_to_specialize->data_alignment)
<< "ValueError: The buffer data_alignment mismatched" << buf_to_specialize->data_alignment
<< " vs. " << specific_buf->data_alignment << ".";
Expand Down
36 changes: 36 additions & 0 deletions tests/python/unittest/test_tir_specialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,38 @@ def element_wise_128_n(a: ty.handle, c: ty.handle) -> None:
C[vi, vj] = B[vi, vj] + 1.0


@tvm.script.tir
def element_wise_func_attr(a: ty.handle, c: ty.handle) -> None:
    # Element-wise workload (B = A * 2, C = B + 1) with symbolic shape (m, n).
    # The func_attr value "test_attr" is a PrimExpr (m * n) that references the
    # same shape vars as the buffers, so specialize() must substitute inside
    # the attrs dict too — this is the behavior the PR under test adds.
    m = tir.var("int32")
    n = tir.var("int32")
    tir.func_attr({"test_attr": m * n})

    A = tir.match_buffer(a, (m, n), "float32")
    C = tir.match_buffer(c, (m, n), "float32")

    # Intermediate buffer; allocated inside the function, not a parameter.
    B = tir.alloc_buffer((m, n), "float32")

    with tir.block([m, n], "B") as [vi, vj]:
        B[vi, vj] = A[vi, vj] * 2.0

    with tir.block([m, n], "C") as [vi, vj]:
        C[vi, vj] = B[vi, vj] + 1.0


@tvm.script.tir
def element_wise_func_attr_128_64(a: ty.handle, c: ty.handle) -> None:
    # Expected result of specializing element_wise_func_attr with a as a
    # (128, 64) buffer: the shape vars m, n are gone and "test_attr" has been
    # folded to the concrete product 128 * 64, confirming that specialize()
    # rewrites PrimExpr values inside the attrs dict.
    tir.func_attr({"test_attr": 128 * 64})
    A = tir.match_buffer(a, (128, 64), "float32")
    C = tir.match_buffer(c, (128, 64), "float32")
    B = tir.alloc_buffer((128, 64), "float32")

    with tir.block([128, 64], "B") as [vi, vj]:
        B[vi, vj] = A[vi, vj] * 2.0

    with tir.block([128, 64], "C") as [vi, vj]:
        C[vi, vj] = B[vi, vj] + 1.0


@tvm.script.tir
def mem_copy(
a: ty.handle, b: ty.handle, m: ty.int32, n: ty.int32, p: ty.int32, q: ty.int32
Expand Down Expand Up @@ -172,6 +204,10 @@ def test_specialize_elemwise():
# partially specialized
func = element_wise.specialize({c: tir.decl_buffer((128, C.shape[1]))})
tvm.ir.assert_structural_equal(func, element_wise_128_n)
# specialization also updates the attributes dictionary
a_attr, _ = element_wise_func_attr.params
func = element_wise_func_attr.specialize({a_attr: tir.decl_buffer((128, 64))})
tvm.ir.assert_structural_equal(func, element_wise_func_attr_128_64)


def test_specialize_mem_copy():
Expand Down