Skip to content

Commit

Permalink
[Refactor] Remove scope attribute from Buffer class (#8463)
Browse files Browse the repository at this point in the history
Co-authored-by: masa <masa@pop-os.localdomain>
  • Loading branch information
masahi and masa authored Jul 20, 2021
1 parent 1141709 commit 1a1be09
Show file tree
Hide file tree
Showing 16 changed files with 65 additions and 54 deletions.
15 changes: 8 additions & 7 deletions include/tvm/tir/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,6 @@ class BufferNode : public Object {
// Meta data
/*! \brief optional name of the buffer */
String name;
/*! \brief storage scope of the buffer, if other than global */
String scope;
/*! \brief Alignment requirement of data pointer in bytes. */
int data_alignment;
/*!
Expand All @@ -93,7 +91,6 @@ class BufferNode : public Object {
v->Visit("strides", &strides);
v->Visit("elem_offset", &elem_offset);
v->Visit("name", &name);
v->Visit("scope", &scope);
v->Visit("data_alignment", &data_alignment);
v->Visit("offset_factor", &offset_factor);
v->Visit("buffer_type", &buffer_type);
Expand All @@ -105,7 +102,7 @@ class BufferNode : public Object {
// in its semantics, skip name as name is not important.
return equal.DefEqual(data, other->data) && equal(dtype, other->dtype) &&
equal.DefEqual(shape, other->shape) && equal.DefEqual(strides, other->strides) &&
equal.DefEqual(elem_offset, other->elem_offset) && equal(scope, other->scope) &&
equal.DefEqual(elem_offset, other->elem_offset) &&
equal(data_alignment, other->data_alignment) && equal(buffer_type, other->buffer_type);
}

Expand All @@ -115,7 +112,6 @@ class BufferNode : public Object {
hash_reduce.DefHash(shape);
hash_reduce.DefHash(strides);
hash_reduce.DefHash(elem_offset);
hash_reduce(scope);
hash_reduce(data_alignment);
hash_reduce(buffer_type);
}
Expand All @@ -141,8 +137,8 @@ class Buffer : public ObjectRef {
// User can specify data_alignment and offset_factor to be 0
// A default value will be picked.
TVM_DLL Buffer(Var ptr, DataType dtype, Array<PrimExpr> shape, Array<PrimExpr> strides,
PrimExpr elem_offset, String name, String scope, int data_alignment,
int offset_factor, BufferType buffer_type, Span span = Span());
PrimExpr elem_offset, String name, int data_alignment, int offset_factor,
BufferType buffer_type, Span span = Span());

/*!
* \brief Return a new buffer that is equivalent with current one
Expand Down Expand Up @@ -182,6 +178,11 @@ class Buffer : public ObjectRef {
*/
TVM_DLL Stmt vstore(Array<PrimExpr> begin, PrimExpr value) const;

/*!
* \brief Return the storage scope associated with this buffer.
*/
TVM_DLL String scope() const;

TVM_DEFINE_OBJECT_REF_METHODS(Buffer, ObjectRef, BufferNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferNode);
};
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/topi/detail/extern.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ using namespace tvm::te;
inline Buffer DeclExternBuffer(Array<PrimExpr> shape, DataType dtype, std::string name) {
auto data = var(name, DataType::Handle());
auto elem_offset = PrimExpr();
return Buffer(data, dtype, shape, Array<PrimExpr>(), elem_offset, name, "", -1, 0, kDefault);
return Buffer(data, dtype, shape, Array<PrimExpr>(), elem_offset, name, -1, 0, kDefault);
}

/*!
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/script/special_stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ def match_buffer_region(
data=None,
strides=strides,
elem_offset=elem_offset,
scope=buffer_region.buffer.scope,
scope=buffer_region.buffer.scope(),
data_alignment=align,
offset_factor=offset_factor,
span=span,
Expand Down
10 changes: 9 additions & 1 deletion python/tvm/tir/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,15 @@ def vstore(self, begin, value):
begin = (begin,) if isinstance(begin, (int, PrimExpr)) else begin
return _ffi_api.BufferVStore(self, begin, value) # type: ignore

def scope(self):
"""Return the storage scope associated with this buffer.
Returns
-------
scope : str
The storage scope associated with this buffer.
"""
return _ffi_api.BufferStorageScope(self) # type: ignore


def decl_buffer(
shape,
Expand Down Expand Up @@ -260,7 +269,6 @@ def decl_buffer(
strides,
elem_offset,
name,
scope,
data_alignment,
offset_factor,
buffer_type,
Expand Down
2 changes: 1 addition & 1 deletion src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ tir::Buffer BufferWithOffsetAlignment(Array<PrimExpr> shape, DataType dtype, std
elem_offset = PrimExpr();
}

return tir::Buffer(data, dtype, shape, Array<PrimExpr>(), elem_offset, name, "", data_alignment,
return tir::Buffer(data, dtype, shape, Array<PrimExpr>(), elem_offset, name, data_alignment,
offset_factor, buffer_type);
}

Expand Down
4 changes: 2 additions & 2 deletions src/printer/tir_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,8 @@ Doc TIRTextPrinter::BufferNode2Doc(const BufferNode* buf, Doc doc) {
if (!is_zero(buf->elem_offset)) {
doc << ", elem_offset=" << Print(buf->elem_offset);
}
if (buf->scope != "global") {
doc << ", scope=" << Doc::StrLiteral(buf->scope);
if (GetRef<Buffer>(buf).scope() != "global") {
doc << ", scope=" << Doc::StrLiteral(GetRef<Buffer>(buf).scope());
}
if (buf->data_alignment != 128) {
doc << ", align=" << buf->data_alignment;
Expand Down
4 changes: 2 additions & 2 deletions src/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -302,8 +302,8 @@ Doc TVMScriptPrinter::AllocBufferDeclaration(const Buffer& buf) {
} else {
doc << ", elem_offset=" << Print(buf->elem_offset);
}
if (buf->scope != "global") {
doc << ", scope=" << Doc::StrLiteral(buf->scope);
if (buf.scope() != "global") {
doc << ", scope=" << Doc::StrLiteral(buf.scope());
}
if (buf->data_alignment != -1) {
doc << ", align=" << buf->data_alignment;
Expand Down
31 changes: 19 additions & 12 deletions src/tir/ir/buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ Buffer decl_buffer(Array<PrimExpr> shape, DataType dtype, String name, String st
Span span) {
DataType storage_dtype = (dtype == DataType::Bool() ? DataType::Int(8) : dtype);
return Buffer(Var(name, PointerType(PrimType(storage_dtype), storage_scope), span), dtype, shape,
Array<PrimExpr>(), PrimExpr(), name, "", 0, 0, kDefault, span);
Array<PrimExpr>(), PrimExpr(), name, 0, 0, kDefault, span);
}

// Split the given expression w.r.t the add operator
Expand Down Expand Up @@ -319,6 +319,15 @@ Stmt Buffer::vstore(Array<PrimExpr> begin, PrimExpr value) const {
}
}

String Buffer::scope() const {
const auto* ptr_type = (*this)->data->type_annotation.as<PointerTypeNode>();
ICHECK(ptr_type) << "Buffer variable is not of pointer type";
if (ptr_type->storage_scope.empty()) {
return "global";
}
return ptr_type->storage_scope;
}

Buffer Buffer::MakeStrideView() const {
if ((*this)->strides.size() != 0) return *this;
if ((*this)->shape.size() == 0) return *this;
Expand Down Expand Up @@ -358,7 +367,7 @@ Buffer Buffer::MakeSlice(Array<PrimExpr> begins, Array<PrimExpr> extents) const
return MakeStrideView().MakeSlice(begins, extents);
}
}
return Buffer(n->data, n->dtype, extents, strides, elem_offset, n->name + "_slice", n->scope,
return Buffer(n->data, n->dtype, extents, strides, elem_offset, n->name + "_slice",
n->data_alignment, 0, n->buffer_type);
}

Expand Down Expand Up @@ -391,8 +400,8 @@ PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lane
}

Buffer::Buffer(Var data, DataType dtype, Array<PrimExpr> shape, Array<PrimExpr> strides,
PrimExpr elem_offset, String name, String scope, int data_alignment,
int offset_factor, BufferType buffer_type, Span span) {
PrimExpr elem_offset, String name, int data_alignment, int offset_factor,
BufferType buffer_type, Span span) {
DataType storage_dtype = dtype;
// specially handle bool
if (storage_dtype == DataType::Bool()) {
Expand All @@ -409,10 +418,6 @@ Buffer::Buffer(Var data, DataType dtype, Array<PrimExpr> shape, Array<PrimExpr>
n->shape = std::move(shape);
n->strides = std::move(strides);
n->name = std::move(name);
if (scope.length() == 0) {
scope = "global";
}
n->scope = std::move(scope);
if (!elem_offset.defined()) {
elem_offset = make_const(n->DefaultIndexType(), 0);
}
Expand Down Expand Up @@ -444,11 +449,11 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
TVM_REGISTER_NODE_TYPE(BufferNode);

TVM_REGISTER_GLOBAL("tir.Buffer").set_body([](TVMArgs args, TVMRetValue* ret) {
ICHECK_EQ(args.size(), 11);
auto buffer_type = args[9].operator String();
ICHECK_EQ(args.size(), 10);
auto buffer_type = args[8].operator String();
BufferType type = (buffer_type == "auto_broadcast") ? kAutoBroadcast : kDefault;
*ret = Buffer(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], args[8],
type, args[10]);
*ret =
Buffer(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], type, args[9]);
});

TVM_REGISTER_GLOBAL("tir.BufferAccessPtr").set_body_method(&Buffer::access_ptr);
Expand All @@ -457,5 +462,7 @@ TVM_REGISTER_GLOBAL("tir.BufferVLoad").set_body_method(&Buffer::vload);

TVM_REGISTER_GLOBAL("tir.BufferVStore").set_body_method(&Buffer::vstore);

TVM_REGISTER_GLOBAL("tir.BufferStorageScope").set_body_method(&Buffer::scope);

} // namespace tir
} // namespace tvm
4 changes: 2 additions & 2 deletions src/tir/schedule/state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ Array<arith::IntSet> AnalyzeRegionUpperBound(const BufferRegion& region,
AsIntSet(LoopDomainOfSRefTreePath(
/*low_inclusive=*/dom_low_inclusive,
/*high_exclusive=*/dom_high_exclusive,
/*extra_relax_scope=*/runtime::StorageScope::Create(region->buffer->scope))));
/*extra_relax_scope=*/runtime::StorageScope::Create(region->buffer.scope()))));
}

/*!
Expand All @@ -67,7 +67,7 @@ Array<arith::IntSet> AnalyzeRegionLowerBound(const BlockRealize& realize,
LoopDomainOfSRefTreePath(
/*low_inclusive=*/dom_low_inclusive,
/*high_exclusive=*/dom_high_exclusive,
/*extra_relax_scope=*/runtime::StorageScope::Create(region->buffer->scope)),
/*extra_relax_scope=*/runtime::StorageScope::Create(region->buffer.scope())),
/*predicate=*/realize->predicate, /*analyzer=*/analyzer)) {
return result.value();
}
Expand Down
2 changes: 1 addition & 1 deletion src/tir/transforms/arg_binder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ void ArgBinder::BindArray(const Array<PrimExpr>& arg, const Array<PrimExpr>& val

void ArgBinder::BindBuffer(const Buffer& arg, const Buffer& value, const std::string& arg_name,
bool fuzzy_match) {
ICHECK_EQ(arg->scope, value->scope) << "Argument " << arg_name << " Buffer bind scope mismatch";
ICHECK_EQ(arg.scope(), value.scope()) << "Argument " << arg_name << " Buffer bind scope mismatch";
ICHECK_EQ(arg->dtype, value->dtype)
<< "Argument " << arg_name << " Buffer bind data type mismatch";
if (value->data_alignment % arg->data_alignment != 0) {
Expand Down
4 changes: 2 additions & 2 deletions src/tir/transforms/bf16_legalize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -323,8 +323,8 @@ class BF16LowerRewriter : public StmtExprMutator {
DataType dtype = DataType::UInt(16, oldbuf->dtype.lanes());
Var buffer_var = Var(oldbuf->data->name_hint, PointerType(PrimType(dtype)));
auto newbuf = Buffer(buffer_var, dtype, oldbuf->shape, oldbuf->strides, oldbuf->elem_offset,
oldbuf->name, oldbuf->scope, oldbuf->data_alignment,
oldbuf->offset_factor, oldbuf->buffer_type);
oldbuf->name, oldbuf->data_alignment, oldbuf->offset_factor,
oldbuf->buffer_type);
buffer_remap_[oldbuf] = newbuf;
var_remap_[oldbuf->data] = buffer_var;
changes.emplace_back(itr.first, newbuf);
Expand Down
2 changes: 1 addition & 1 deletion src/tir/transforms/compact_buffer_region.cc
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
std::unordered_map<const VarNode*, arith::IntSet> dom_map;
for (const ForNode* loop : ancestor_loops_) {
const VarNode* loop_var = loop->loop_var.get();
if (NeedRelaxThread(GetRef<For>(loop), runtime::StorageScope::Create(buffer->scope))) {
if (NeedRelaxThread(GetRef<For>(loop), runtime::StorageScope::Create(buffer.scope()))) {
dom_map[loop_var] = IntSetFromMinExtent(loop->min, loop->extent);
}
}
Expand Down
5 changes: 1 addition & 4 deletions src/tir/transforms/flatten_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,7 @@ class BufferFlattener : public StmtExprMutator {
}

static Stmt MakeAllocStmt(const Buffer& buffer, Stmt body) {
String storage_scope = buffer->scope;
if (storage_scope.empty()) {
storage_scope = "global";
}
String storage_scope = buffer.scope();
PrimExpr area = BufferArea(buffer);
body = Allocate(buffer->data, buffer->dtype, {area}, const_true(), std::move(body));
body = AttrStmt(buffer->data, attr::storage_scope, StringImm(storage_scope), std::move(body));
Expand Down
6 changes: 2 additions & 4 deletions src/tir/transforms/inject_copy_intrin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -148,11 +148,9 @@ class CopyIntrinInjector : public StmtMutator {
dst_strides.push_back(make_const(DataType::Int(32), 1));
}
Buffer dst = Buffer(store->buffer_var, store->value.dtype(), dst_shape, dst_strides,
store_strides[loop_var_size], store->buffer_var->name_hint,
GetStorageScope(store->buffer_var.get()), 0, 0, kDefault);
store_strides[loop_var_size], store->buffer_var->name_hint, 0, 0, kDefault);
Buffer src = Buffer(load->buffer_var, load->dtype, src_shape, src_strides, src_elem_offset,
load->buffer_var->name_hint, GetStorageScope(load->buffer_var.get()), 0, 0,
kDefault);
load->buffer_var->name_hint, 0, 0, kDefault);
*out = flower_copy_fromto_(src, dst, pad_before, pad_after, pad_value);
ICHECK(out->defined()) << "flower function did not return correct stmt";
return true;
Expand Down
4 changes: 2 additions & 2 deletions src/tir/transforms/storage_flatten.cc
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ class StorageFlattener : public StmtExprMutator {
auto new_var =
Var(op->buffer->data->name_hint, PointerType(ptr_type->element_type, skey.to_string()));
e.buffer = Buffer(new_var, op->buffer->dtype, shape, strides, PrimExpr(), op->buffer->name,
skey.to_string(), align, 0, kDefault);
align, 0, kDefault);

buf_map_[key] = e;
Stmt body = this->VisitStmt(op->body);
Expand All @@ -224,7 +224,7 @@ class StorageFlattener : public StmtExprMutator {
ret = Allocate(e.buffer->data, storage_type, shape,
make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body);
}
ret = AttrStmt(e.buffer->data, attr::storage_scope, StringImm(e.buffer->scope), ret);
ret = AttrStmt(e.buffer->data, attr::storage_scope, StringImm(skey.to_string()), ret);

if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) {
ret = AttrStmt(e.buffer->data, tir::attr::buffer_bound,
Expand Down
22 changes: 11 additions & 11 deletions vta/python/vta/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,21 +495,21 @@ def _inject_copy(src, dst, pad_before, pad_after, pad_value):
# FIXME: pad_value is ignored...
env = get_env()
_ = pad_value
if dst.scope == "global":
if dst.scope() == "global":
# Store
if pad_before or pad_after:
raise RuntimeError("Do not support copy into DRAM with pad")
if src.scope == env.acc_scope:
if src.scope() == env.acc_scope:
elem_width = env.OUT_WIDTH
elem_bytes = env.OUT_ELEM_BYTES
mem_type = env.dev.MEM_ID_OUT
data_type = "int%d" % env.OUT_WIDTH
task_qid = env.dev.QID_STORE_OUT
else:
raise RuntimeError("Do not support copy %s->dram" % (src.scope))
raise RuntimeError("Do not support copy %s->dram" % (src.scope()))
_check_compact(src)
x_size, y_size, x_stride, offset = _get_2d_pattern(
dst, elem_width, elem_bytes, data_type, src.scope, allow_fold=True
dst, elem_width, elem_bytes, data_type, src.scope(), allow_fold=True
)
irb = tvm.tir.ir_builder.create()
irb.scope_attr(env.dev.vta_axis, "coproc_scope", env.dev.get_task_qid(task_qid))
Expand All @@ -528,27 +528,27 @@ def _inject_copy(src, dst, pad_before, pad_after, pad_value):
)
)
return irb.get()
elif src.scope == "global":
if dst.scope == env.acc_scope:
elif src.scope() == "global":
if dst.scope() == env.acc_scope:
elem_width = env.ACC_WIDTH
elem_bytes = env.ACC_ELEM_BYTES
mem_type = env.dev.MEM_ID_ACC
data_type = "int%d" % env.ACC_WIDTH
task_qid = env.dev.QID_LOAD_OUT
elif dst.scope == env.inp_scope:
elif dst.scope() == env.inp_scope:
elem_width = env.INP_WIDTH
elem_bytes = env.INP_ELEM_BYTES
mem_type = env.dev.MEM_ID_INP
data_type = "int%d" % env.INP_WIDTH
task_qid = env.dev.QID_LOAD_INP
elif dst.scope == env.wgt_scope:
elif dst.scope() == env.wgt_scope:
elem_width = env.WGT_WIDTH
elem_bytes = env.WGT_ELEM_BYTES
mem_type = env.dev.MEM_ID_WGT
data_type = "int%d" % env.WGT_WIDTH
task_qid = env.dev.QID_LOAD_WGT
else:
raise RuntimeError("Do not support copy dram->%s" % (dst.scope))
raise RuntimeError("Do not support copy dram->%s" % (dst.scope()))
# collect pad statistics
if pad_before:
assert pad_after
Expand Down Expand Up @@ -586,7 +586,7 @@ def _inject_copy(src, dst, pad_before, pad_after, pad_value):

_check_compact(dst)
x_size, y_size, x_stride, offset = _get_2d_pattern(
src, elem_width, elem_bytes, data_type, dst.scope, allow_fold=allow_fold
src, elem_width, elem_bytes, data_type, dst.scope(), allow_fold=allow_fold
)

if data_type != src.dtype:
Expand Down Expand Up @@ -617,7 +617,7 @@ def _inject_copy(src, dst, pad_before, pad_after, pad_value):
return irb.get()

else:
raise RuntimeError("Do not support copy %s->%s" % (src.scope, dst.scope))
raise RuntimeError("Do not support copy %s->%s" % (src.scope(), dst.scope()))

return tvm.tir.transform.InjectCopyIntrin("dma_copy", _inject_copy)

Expand Down

0 comments on commit 1a1be09

Please sign in to comment.