Skip to content

Commit

Permalink
[tir] fix buffer_decl buffer allocation
Browse files Browse the repository at this point in the history
- Fix a bug where `buffer_decl`, combined with certain
  usage patterns of the resulting buffer, cause an TVM-internal
  assert failure during TIR-compilation.
  • Loading branch information
Christian Convey committed Feb 2, 2023
1 parent ea34e6e commit 9166859
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 0 deletions.
7 changes: 7 additions & 0 deletions src/tir/transforms/plan_update_buffer_allocation_location.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <tvm/tir/analysis.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/var.h>

#include "ir_utils.h"

Expand Down Expand Up @@ -111,6 +112,12 @@ class BufferAllocationLocator : public StmtExprMutator {
collector(func->body);
unmanaged_allocations_ = collector.unmanaged_allocations;

for (const Var& param : func->params) {
if (param->type_annotation.defined() && param->type_annotation.as<PointerTypeNode>()) {
unmanaged_allocations_.insert(param.get());
}
}

for (const auto& kv : func->buffer_map) {
const Buffer& buffer = kv.second;
arg_buffer_vars.emplace(buffer->data.get());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -416,5 +416,29 @@ def test_allocate_const_after_tensorize():
_ = seq(sch.mod)


def test_buffer_decl_allocation():
"""
Confirm that buffer_decl is sufficient for creating
precisely one buffer.
This confirms a fix to
`src/tir/transforms/plan_update_buffer_allocation_location.cc`
in which a declared buffer was erroneously duplicated, resulting in a
TIR-compilation failure.
"""

@tvm.script.ir_module
class IRMod:
@T.prim_func
def func(a: T.Ptr[T.float32]):
T.func_attr({"global_symbol": "main", "tir.noalias": True})
A = T.buffer_decl(1, "float32", data=a)
for i in range(1):
A[i] = 0

ir_mod = IRMod
built_mod = tvm.build(ir_mod)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 9166859

Please sign in to comment.