Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[IR] Add storage scope to PointerType #8017

Merged
merged 2 commits into from
May 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions include/tvm/ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,15 @@ class PointerTypeNode : public TypeNode {
* \brief The type of the element which the pointer points to.
*/
Type element_type;
/*!
* \brief The storage scope of the pointer
*/
String storage_scope;

void VisitAttrs(AttrVisitor* v) { v->Visit("element_type", &element_type); }
void VisitAttrs(AttrVisitor* v) {
v->Visit("element_type", &element_type);
v->Visit("storage_scope", &storage_scope);
}

bool SEqualReduce(const PointerTypeNode* other, SEqualReducer equal) const {
return equal(element_type, other->element_type);
Expand All @@ -175,8 +182,9 @@ class PointerType : public Type {
/*!
* \brief Constructor
* \param element_type The type of the element which the pointer points to.
* \param storage_scope The storage scope into which the pointer addresses
*/
TVM_DLL explicit PointerType(Type element_type);
TVM_DLL explicit PointerType(Type element_type, String storage_scope = "");

TVM_DEFINE_OBJECT_REF_METHODS(PointerType, Type, PointerTypeNode);
};
Expand Down
7 changes: 5 additions & 2 deletions python/tvm/ir/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,13 @@ class PointerType(Type):
----------
element_type : tvm.ir.Type
The type of pointer's element.

storage_scope : str
The storage scope into which the pointer addresses.
"""

def __init__(self, element_type):
self.__init_handle_by_constructor__(_ffi_api.PointerType, element_type)
def __init__(self, element_type, storage_scope=""):
self.__init_handle_by_constructor__(_ffi_api.PointerType, element_type, storage_scope)


@tvm._ffi.register_object("TypeVar")
Expand Down
13 changes: 9 additions & 4 deletions src/ir/type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,21 +43,26 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << node->dtype;
});

PointerType::PointerType(Type element_type) {
PointerType::PointerType(Type element_type, String storage_scope) {
ObjectPtr<PointerTypeNode> n = make_object<PointerTypeNode>();
n->element_type = std::move(element_type);
n->storage_scope = std::move(storage_scope);
data_ = std::move(n);
}

TVM_REGISTER_NODE_TYPE(PointerTypeNode);

TVM_REGISTER_GLOBAL("ir.PointerType").set_body_typed([](Type element_type) {
return PointerType(element_type);
});
TVM_REGISTER_GLOBAL("ir.PointerType")
.set_body_typed([](Type element_type, String storage_scope = "") {
return PointerType(element_type, storage_scope);
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<PointerTypeNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const PointerTypeNode*>(ref.get());
if (!node->storage_scope.empty()) {
p->stream << node->storage_scope << " ";
}
p->Print(node->element_type);
p->stream << '*';
});
Expand Down
2 changes: 1 addition & 1 deletion src/ir/type_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ Type TypeMutator::VisitType_(const PointerTypeNode* op) {
if (element_type.same_as(op->element_type)) {
return GetRef<Type>(op);
} else {
return PointerType(element_type);
return PointerType(element_type, op->storage_scope);
}
}

Expand Down
6 changes: 5 additions & 1 deletion src/printer/tir_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,11 @@ Doc TIRTextPrinter::VisitType_(const PrimTypeNode* node) {

Doc TIRTextPrinter::VisitType_(const PointerTypeNode* node) {
Doc doc;
doc << "Pointer(" << Print(node->element_type) << ")";
doc << "Pointer(";
if (!node->storage_scope.empty()) {
doc << node->storage_scope << " ";
}
doc << Print(node->element_type) << ")";
return doc;
}

Expand Down
6 changes: 5 additions & 1 deletion src/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -764,7 +764,11 @@ Doc TVMScriptPrinter::VisitType_(const PrimTypeNode* node) {

Doc TVMScriptPrinter::VisitType_(const PointerTypeNode* node) {
Doc doc;
doc << "ty.Ptr[" << Print(node->element_type) << "]";
doc << "ty.Ptr[";
if (!node->storage_scope.empty()) {
doc << node->storage_scope << " ";
}
doc << Print(node->element_type) << "]";
return doc;
}

Expand Down
9 changes: 9 additions & 0 deletions tests/python/unittest/test_tir_constructor.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,15 @@ def test_stmt_constructor():
assert x.buffer_var == buffer_var
assert x.body == nop

storage_scope = "global.texture"
buffer_var = tvm.tir.Var("buf", tvm.ir.PointerType(tvm.ir.PrimType("float32"), storage_scope))
x = tvm.tir.Allocate(buffer_var, "float32", [10], tvm.tir.const(1, "uint1"), nop)
assert isinstance(x, tvm.tir.Allocate)
assert x.dtype == "float32"
assert x.buffer_var == buffer_var
assert x.buffer_var.type_annotation.storage_scope == storage_scope
assert x.body == nop

x = tvm.tir.AttrStmt(buffer_var, "xyz", 1, nop)
assert isinstance(x, tvm.tir.AttrStmt)
assert x.node == buffer_var
Expand Down
12 changes: 12 additions & 0 deletions tests/python/unittest/test_tir_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,17 @@ def test_vars():
assert isinstance(ptype.element_type, tvm.ir.PrimType)


def test_scoped_storage_vars():
dtype = "float"
storage_scope = "global.texture"
ptype = tvm.ir.PointerType(tvm.ir.PrimType(dtype), storage_scope)
x = tvm.tir.Var("xyz", ptype)
assert x.dtype == "handle"
assert x.type_annotation == ptype
assert x.type_annotation.storage_scope == storage_scope
assert isinstance(ptype.element_type, tvm.ir.PrimType)


def test_buffer_load_store():
b = tvm.tir.decl_buffer((10,), "float32")
x = tvm.tir.BufferLoad(b, [0])
Expand Down Expand Up @@ -460,6 +471,7 @@ def test_block_blockrealize():
test_intimm_cond()
test_buffer_load_store()
test_vars()
test_scoped_storage_var()
test_prim_func()
test_cast()
test_attr()
Expand Down