Skip to content

Commit

Permalink
[tir] fix buffer_decl buffer allocation
Browse files Browse the repository at this point in the history
- Fix a bug where `buffer_decl`, combined with certain
  usage patterns of the resulting buffer, cause an TVM-internal
  assert failure during TIR-compilation.
  • Loading branch information
Christian Convey committed Feb 2, 2023
1 parent ea34e6e commit 2035a7e
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 0 deletions.
7 changes: 7 additions & 0 deletions src/tir/transforms/plan_update_buffer_allocation_location.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <tvm/tir/analysis.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/var.h>

#include "ir_utils.h"

Expand Down Expand Up @@ -111,6 +112,12 @@ class BufferAllocationLocator : public StmtExprMutator {
collector(func->body);
unmanaged_allocations_ = collector.unmanaged_allocations;

for (const Var& param : func->params) {
if (param->type_annotation.defined() && param->type_annotation.as<PointerTypeNode>()) {
unmanaged_allocations_.insert(param.get());
}
}

for (const auto& kv : func->buffer_map) {
const Buffer& buffer = kv.second;
arg_buffer_vars.emplace(buffer->data.get());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -416,5 +416,23 @@ def test_allocate_const_after_tensorize():
_ = seq(sch.mod)


def test_buffer_conditional_lowering():
"""
Confirm that the `tir.PlanAndUpdateBufferAllocationLocation` pass
leaves (Buffer nodes corresponding to pointer-typed PrimFunc arguments)
unchanged, rather than lowering them to `reads`, `writes`, and `alloc_buffer` nodes.
"""

@T.prim_func
def before(A: T.Ptr("float32")):
T.func_attr({"global_symbol": "main", "tir.noalias": True})
for i in range(1):
A_1 = T.Buffer((1,), data=A)
A_1[i] = 0

after = before
_check(before, after)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 2035a7e

Please sign in to comment.