Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Refactor] Remove scope attribute from Buffer class #8463

Merged
merged 6 commits into from
Jul 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions include/tvm/tir/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,6 @@ class BufferNode : public Object {
// Meta data
/*! \brief optional name of the buffer */
String name;
/*! \brief storage scope of the buffer, if other than global */
String scope;
/*! \brief Alignment requirement of data pointer in bytes. */
int data_alignment;
/*!
Expand All @@ -93,7 +91,6 @@ class BufferNode : public Object {
v->Visit("strides", &strides);
v->Visit("elem_offset", &elem_offset);
v->Visit("name", &name);
v->Visit("scope", &scope);
v->Visit("data_alignment", &data_alignment);
v->Visit("offset_factor", &offset_factor);
v->Visit("buffer_type", &buffer_type);
Expand All @@ -105,7 +102,7 @@ class BufferNode : public Object {
// in its semantics, skip name as name is not important.
return equal.DefEqual(data, other->data) && equal(dtype, other->dtype) &&
equal.DefEqual(shape, other->shape) && equal.DefEqual(strides, other->strides) &&
equal.DefEqual(elem_offset, other->elem_offset) && equal(scope, other->scope) &&
equal.DefEqual(elem_offset, other->elem_offset) &&
equal(data_alignment, other->data_alignment) && equal(buffer_type, other->buffer_type);
}

Expand All @@ -115,7 +112,6 @@ class BufferNode : public Object {
hash_reduce.DefHash(shape);
hash_reduce.DefHash(strides);
hash_reduce.DefHash(elem_offset);
hash_reduce(scope);
hash_reduce(data_alignment);
hash_reduce(buffer_type);
}
Expand All @@ -141,8 +137,8 @@ class Buffer : public ObjectRef {
// User can specify data_alignment and offset_factor to be 0
// A default value will be picked.
TVM_DLL Buffer(Var ptr, DataType dtype, Array<PrimExpr> shape, Array<PrimExpr> strides,
PrimExpr elem_offset, String name, String scope, int data_alignment,
int offset_factor, BufferType buffer_type, Span span = Span());
PrimExpr elem_offset, String name, int data_alignment, int offset_factor,
BufferType buffer_type, Span span = Span());

/*!
* \brief Return a new buffer that is equivalent with current one
Expand Down Expand Up @@ -182,6 +178,11 @@ class Buffer : public ObjectRef {
*/
TVM_DLL Stmt vstore(Array<PrimExpr> begin, PrimExpr value) const;

/*!
* \brief Return the storage scope associated with this buffer.
*/
TVM_DLL String scope() const;

TVM_DEFINE_OBJECT_REF_METHODS(Buffer, ObjectRef, BufferNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferNode);
};
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/topi/detail/extern.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ using namespace tvm::te;
inline Buffer DeclExternBuffer(Array<PrimExpr> shape, DataType dtype, std::string name) {
auto data = var(name, DataType::Handle());
auto elem_offset = PrimExpr();
return Buffer(data, dtype, shape, Array<PrimExpr>(), elem_offset, name, "", -1, 0, kDefault);
return Buffer(data, dtype, shape, Array<PrimExpr>(), elem_offset, name, -1, 0, kDefault);
}

/*!
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/script/special_stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ def match_buffer_region(
data=None,
strides=strides,
elem_offset=elem_offset,
scope=buffer_region.buffer.scope,
scope=buffer_region.buffer.scope(),
data_alignment=align,
offset_factor=offset_factor,
span=span,
Expand Down
10 changes: 9 additions & 1 deletion python/tvm/tir/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,15 @@ def vstore(self, begin, value):
begin = (begin,) if isinstance(begin, (int, PrimExpr)) else begin
return _ffi_api.BufferVStore(self, begin, value) # type: ignore

def scope(self):
"""Return the storage scope associated with this buffer.
Returns
-------
scope : str
The storage scope associated with this buffer.
"""
return _ffi_api.BufferStorageScope(self) # type: ignore


def decl_buffer(
shape,
Expand Down Expand Up @@ -260,7 +269,6 @@ def decl_buffer(
strides,
elem_offset,
name,
scope,
data_alignment,
offset_factor,
buffer_type,
Expand Down
2 changes: 1 addition & 1 deletion src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ tir::Buffer BufferWithOffsetAlignment(Array<PrimExpr> shape, DataType dtype, std
elem_offset = PrimExpr();
}

return tir::Buffer(data, dtype, shape, Array<PrimExpr>(), elem_offset, name, "", data_alignment,
return tir::Buffer(data, dtype, shape, Array<PrimExpr>(), elem_offset, name, data_alignment,
offset_factor, buffer_type);
}

Expand Down
4 changes: 2 additions & 2 deletions src/printer/tir_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,8 @@ Doc TIRTextPrinter::BufferNode2Doc(const BufferNode* buf, Doc doc) {
if (!is_zero(buf->elem_offset)) {
doc << ", elem_offset=" << Print(buf->elem_offset);
}
if (buf->scope != "global") {
doc << ", scope=" << Doc::StrLiteral(buf->scope);
if (GetRef<Buffer>(buf).scope() != "global") {
doc << ", scope=" << Doc::StrLiteral(GetRef<Buffer>(buf).scope());
}
if (buf->data_alignment != 128) {
doc << ", align=" << buf->data_alignment;
Expand Down
4 changes: 2 additions & 2 deletions src/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -302,8 +302,8 @@ Doc TVMScriptPrinter::AllocBufferDeclaration(const Buffer& buf) {
} else {
doc << ", elem_offset=" << Print(buf->elem_offset);
}
if (buf->scope != "global") {
doc << ", scope=" << Doc::StrLiteral(buf->scope);
if (buf.scope() != "global") {
doc << ", scope=" << Doc::StrLiteral(buf.scope());
}
if (buf->data_alignment != -1) {
doc << ", align=" << buf->data_alignment;
Expand Down
31 changes: 19 additions & 12 deletions src/tir/ir/buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ Buffer decl_buffer(Array<PrimExpr> shape, DataType dtype, String name, String st
Span span) {
DataType storage_dtype = (dtype == DataType::Bool() ? DataType::Int(8) : dtype);
return Buffer(Var(name, PointerType(PrimType(storage_dtype), storage_scope), span), dtype, shape,
Array<PrimExpr>(), PrimExpr(), name, "", 0, 0, kDefault, span);
Array<PrimExpr>(), PrimExpr(), name, 0, 0, kDefault, span);
}

// Split the given expression w.r.t the add operator
Expand Down Expand Up @@ -312,6 +312,15 @@ Stmt Buffer::vstore(Array<PrimExpr> begin, PrimExpr value) const {
}
}

String Buffer::scope() const {
const auto* ptr_type = (*this)->data->type_annotation.as<PointerTypeNode>();
ICHECK(ptr_type) << "Buffer variable is not of pointer type";
if (ptr_type->storage_scope.empty()) {
return "global";
}
return ptr_type->storage_scope;
}

Buffer Buffer::MakeStrideView() const {
if ((*this)->strides.size() != 0) return *this;
if ((*this)->shape.size() == 0) return *this;
Expand Down Expand Up @@ -351,7 +360,7 @@ Buffer Buffer::MakeSlice(Array<PrimExpr> begins, Array<PrimExpr> extents) const
return MakeStrideView().MakeSlice(begins, extents);
}
}
return Buffer(n->data, n->dtype, extents, strides, elem_offset, n->name + "_slice", n->scope,
return Buffer(n->data, n->dtype, extents, strides, elem_offset, n->name + "_slice",
n->data_alignment, 0, n->buffer_type);
}

Expand Down Expand Up @@ -384,8 +393,8 @@ PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lane
}

Buffer::Buffer(Var data, DataType dtype, Array<PrimExpr> shape, Array<PrimExpr> strides,
PrimExpr elem_offset, String name, String scope, int data_alignment,
int offset_factor, BufferType buffer_type, Span span) {
PrimExpr elem_offset, String name, int data_alignment, int offset_factor,
BufferType buffer_type, Span span) {
DataType storage_dtype = dtype;
// specially handle bool
if (storage_dtype == DataType::Bool()) {
Expand All @@ -402,10 +411,6 @@ Buffer::Buffer(Var data, DataType dtype, Array<PrimExpr> shape, Array<PrimExpr>
n->shape = std::move(shape);
n->strides = std::move(strides);
n->name = std::move(name);
if (scope.length() == 0) {
scope = "global";
}
n->scope = std::move(scope);
if (!elem_offset.defined()) {
elem_offset = make_const(n->DefaultIndexType(), 0);
}
Expand Down Expand Up @@ -437,11 +442,11 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
TVM_REGISTER_NODE_TYPE(BufferNode);

TVM_REGISTER_GLOBAL("tir.Buffer").set_body([](TVMArgs args, TVMRetValue* ret) {
ICHECK_EQ(args.size(), 11);
auto buffer_type = args[9].operator String();
ICHECK_EQ(args.size(), 10);
auto buffer_type = args[8].operator String();
BufferType type = (buffer_type == "auto_broadcast") ? kAutoBroadcast : kDefault;
*ret = Buffer(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], args[8],
type, args[10]);
*ret =
Buffer(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], type, args[9]);
});

TVM_REGISTER_GLOBAL("tir.BufferAccessPtr").set_body_method(&Buffer::access_ptr);
Expand All @@ -450,5 +455,7 @@ TVM_REGISTER_GLOBAL("tir.BufferVLoad").set_body_method(&Buffer::vload);

TVM_REGISTER_GLOBAL("tir.BufferVStore").set_body_method(&Buffer::vstore);

TVM_REGISTER_GLOBAL("tir.BufferStorageScope").set_body_method(&Buffer::scope);

} // namespace tir
} // namespace tvm
4 changes: 2 additions & 2 deletions src/tir/schedule/state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ Array<arith::IntSet> AnalyzeRegionUpperBound(const BufferRegion& region,
AsIntSet(LoopDomainOfSRefTreePath(
/*low_inclusive=*/dom_low_inclusive,
/*high_exclusive=*/dom_high_exclusive,
/*extra_relax_scope=*/runtime::StorageScope::Create(region->buffer->scope))));
/*extra_relax_scope=*/runtime::StorageScope::Create(region->buffer.scope()))));
}

/*!
Expand All @@ -67,7 +67,7 @@ Array<arith::IntSet> AnalyzeRegionLowerBound(const BlockRealize& realize,
LoopDomainOfSRefTreePath(
/*low_inclusive=*/dom_low_inclusive,
/*high_exclusive=*/dom_high_exclusive,
/*extra_relax_scope=*/runtime::StorageScope::Create(region->buffer->scope)),
/*extra_relax_scope=*/runtime::StorageScope::Create(region->buffer.scope())),
/*predicate=*/realize->predicate, /*analyzer=*/analyzer)) {
return result.value();
}
Expand Down
2 changes: 1 addition & 1 deletion src/tir/transforms/arg_binder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ void ArgBinder::BindArray(const Array<PrimExpr>& arg, const Array<PrimExpr>& val

void ArgBinder::BindBuffer(const Buffer& arg, const Buffer& value, const std::string& arg_name,
bool fuzzy_match) {
ICHECK_EQ(arg->scope, value->scope) << "Argument " << arg_name << " Buffer bind scope mismatch";
ICHECK_EQ(arg.scope(), value.scope()) << "Argument " << arg_name << " Buffer bind scope mismatch";
ICHECK_EQ(arg->dtype, value->dtype)
<< "Argument " << arg_name << " Buffer bind data type mismatch";
if (value->data_alignment % arg->data_alignment != 0) {
Expand Down
4 changes: 2 additions & 2 deletions src/tir/transforms/bf16_legalize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -323,8 +323,8 @@ class BF16LowerRewriter : public StmtExprMutator {
DataType dtype = DataType::UInt(16, oldbuf->dtype.lanes());
Var buffer_var = Var(oldbuf->data->name_hint, PointerType(PrimType(dtype)));
auto newbuf = Buffer(buffer_var, dtype, oldbuf->shape, oldbuf->strides, oldbuf->elem_offset,
oldbuf->name, oldbuf->scope, oldbuf->data_alignment,
oldbuf->offset_factor, oldbuf->buffer_type);
oldbuf->name, oldbuf->data_alignment, oldbuf->offset_factor,
oldbuf->buffer_type);
buffer_remap_[oldbuf] = newbuf;
var_remap_[oldbuf->data] = buffer_var;
changes.emplace_back(itr.first, newbuf);
Expand Down
2 changes: 1 addition & 1 deletion src/tir/transforms/compact_buffer_region.cc
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
std::unordered_map<const VarNode*, arith::IntSet> dom_map;
for (const ForNode* loop : ancestor_loops_) {
const VarNode* loop_var = loop->loop_var.get();
if (NeedRelaxThread(GetRef<For>(loop), runtime::StorageScope::Create(buffer->scope))) {
if (NeedRelaxThread(GetRef<For>(loop), runtime::StorageScope::Create(buffer.scope()))) {
dom_map[loop_var] = IntSetFromMinExtent(loop->min, loop->extent);
}
}
Expand Down
5 changes: 1 addition & 4 deletions src/tir/transforms/flatten_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,7 @@ class BufferFlattener : public StmtExprMutator {
}

static Stmt MakeAllocStmt(const Buffer& buffer, Stmt body) {
String storage_scope = buffer->scope;
if (storage_scope.empty()) {
storage_scope = "global";
}
String storage_scope = buffer.scope();
PrimExpr area = BufferArea(buffer);
body = Allocate(buffer->data, buffer->dtype, {area}, const_true(), std::move(body));
body = AttrStmt(buffer->data, attr::storage_scope, StringImm(storage_scope), std::move(body));
Expand Down
6 changes: 2 additions & 4 deletions src/tir/transforms/inject_copy_intrin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -148,11 +148,9 @@ class CopyIntrinInjector : public StmtMutator {
dst_strides.push_back(make_const(DataType::Int(32), 1));
}
Buffer dst = Buffer(store->buffer_var, store->value.dtype(), dst_shape, dst_strides,
store_strides[loop_var_size], store->buffer_var->name_hint,
GetStorageScope(store->buffer_var.get()), 0, 0, kDefault);
store_strides[loop_var_size], store->buffer_var->name_hint, 0, 0, kDefault);
Buffer src = Buffer(load->buffer_var, load->dtype, src_shape, src_strides, src_elem_offset,
load->buffer_var->name_hint, GetStorageScope(load->buffer_var.get()), 0, 0,
kDefault);
load->buffer_var->name_hint, 0, 0, kDefault);
*out = flower_copy_fromto_(src, dst, pad_before, pad_after, pad_value);
ICHECK(out->defined()) << "flower function did not return correct stmt";
return true;
Expand Down
4 changes: 2 additions & 2 deletions src/tir/transforms/storage_flatten.cc
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ class StorageFlattener : public StmtExprMutator {
auto new_var =
Var(op->buffer->data->name_hint, PointerType(ptr_type->element_type, skey.to_string()));
e.buffer = Buffer(new_var, op->buffer->dtype, shape, strides, PrimExpr(), op->buffer->name,
skey.to_string(), align, 0, kDefault);
align, 0, kDefault);

buf_map_[key] = e;
Stmt body = this->VisitStmt(op->body);
Expand All @@ -224,7 +224,7 @@ class StorageFlattener : public StmtExprMutator {
ret = Allocate(e.buffer->data, storage_type, shape,
make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body);
}
ret = AttrStmt(e.buffer->data, attr::storage_scope, StringImm(e.buffer->scope), ret);
ret = AttrStmt(e.buffer->data, attr::storage_scope, StringImm(skey.to_string()), ret);

if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) {
ret = AttrStmt(e.buffer->data, tir::attr::buffer_bound,
Expand Down
22 changes: 11 additions & 11 deletions vta/python/vta/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,21 +495,21 @@ def _inject_copy(src, dst, pad_before, pad_after, pad_value):
# FIXME: pad_value is ignored...
env = get_env()
_ = pad_value
if dst.scope == "global":
if dst.scope() == "global":
# Store
if pad_before or pad_after:
raise RuntimeError("Do not support copy into DRAM with pad")
if src.scope == env.acc_scope:
if src.scope() == env.acc_scope:
elem_width = env.OUT_WIDTH
elem_bytes = env.OUT_ELEM_BYTES
mem_type = env.dev.MEM_ID_OUT
data_type = "int%d" % env.OUT_WIDTH
task_qid = env.dev.QID_STORE_OUT
else:
raise RuntimeError("Do not support copy %s->dram" % (src.scope))
raise RuntimeError("Do not support copy %s->dram" % (src.scope()))
_check_compact(src)
x_size, y_size, x_stride, offset = _get_2d_pattern(
dst, elem_width, elem_bytes, data_type, src.scope, allow_fold=True
dst, elem_width, elem_bytes, data_type, src.scope(), allow_fold=True
)
irb = tvm.tir.ir_builder.create()
irb.scope_attr(env.dev.vta_axis, "coproc_scope", env.dev.get_task_qid(task_qid))
Expand All @@ -528,27 +528,27 @@ def _inject_copy(src, dst, pad_before, pad_after, pad_value):
)
)
return irb.get()
elif src.scope == "global":
if dst.scope == env.acc_scope:
elif src.scope() == "global":
if dst.scope() == env.acc_scope:
elem_width = env.ACC_WIDTH
elem_bytes = env.ACC_ELEM_BYTES
mem_type = env.dev.MEM_ID_ACC
data_type = "int%d" % env.ACC_WIDTH
task_qid = env.dev.QID_LOAD_OUT
elif dst.scope == env.inp_scope:
elif dst.scope() == env.inp_scope:
elem_width = env.INP_WIDTH
elem_bytes = env.INP_ELEM_BYTES
mem_type = env.dev.MEM_ID_INP
data_type = "int%d" % env.INP_WIDTH
task_qid = env.dev.QID_LOAD_INP
elif dst.scope == env.wgt_scope:
elif dst.scope() == env.wgt_scope:
elem_width = env.WGT_WIDTH
elem_bytes = env.WGT_ELEM_BYTES
mem_type = env.dev.MEM_ID_WGT
data_type = "int%d" % env.WGT_WIDTH
task_qid = env.dev.QID_LOAD_WGT
else:
raise RuntimeError("Do not support copy dram->%s" % (dst.scope))
raise RuntimeError("Do not support copy dram->%s" % (dst.scope()))
# collect pad statistics
if pad_before:
assert pad_after
Expand Down Expand Up @@ -586,7 +586,7 @@ def _inject_copy(src, dst, pad_before, pad_after, pad_value):

_check_compact(dst)
x_size, y_size, x_stride, offset = _get_2d_pattern(
src, elem_width, elem_bytes, data_type, dst.scope, allow_fold=allow_fold
src, elem_width, elem_bytes, data_type, dst.scope(), allow_fold=allow_fold
)

if data_type != src.dtype:
Expand Down Expand Up @@ -617,7 +617,7 @@ def _inject_copy(src, dst, pad_before, pad_after, pad_value):
return irb.get()

else:
raise RuntimeError("Do not support copy %s->%s" % (src.scope, dst.scope))
raise RuntimeError("Do not support copy %s->%s" % (src.scope(), dst.scope()))

return tvm.tir.transform.InjectCopyIntrin("dma_copy", _inject_copy)

Expand Down