Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR] Enforce buffer pointer var type to be consistent with dtype. #6317

Merged
merged 1 commit into from
Aug 21, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions include/tvm/tir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,23 @@ TVM_DECLARE_INTRIN_BINARY(hypot);
TVM_DECLARE_INTRIN_BINARY(ldexp);

namespace tir {

/*!
* \brief Check if type is a pointer to a runtime element type.
* \param type The type to be checked.
* \param element_type The corresponding element type.
* \return The check results
*/
inline bool IsPointerType(const Type& type, const DataType& element_type) {
if (!type.defined()) return false;
if (const auto* ptr_type = type.as<PointerTypeNode>()) {
if (const auto* prim_type = ptr_type->element_type.as<PrimTypeNode>()) {
return prim_type->dtype == element_type;
}
}
return false;
}

/*!
* \brief Make a const value with certain data type.
* \param t The target type.
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/tir/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from tvm._ffi.base import string_types
from tvm.runtime import Object, convert
from tvm.ir import PrimExpr
from tvm.ir import PrimExpr, PointerType, PrimType
from . import _ffi_api


Expand Down Expand Up @@ -241,7 +241,7 @@ def decl_buffer(shape,
shape_dtype = shape[0].dtype if hasattr(shape[0], "dtype") else "int32"
elem_offset = Var('%s_elem_offset' % name, shape_dtype)
if data is None:
data = Var(name, "handle")
data = Var(name, PointerType(PrimType(dtype)))
return _ffi_api.Buffer(
data, dtype, shape, strides, elem_offset, name, scope,
data_alignment, offset_factor, buffer_type)
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/tir/ir_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"""Developer API of IR node builder make function."""
from tvm._ffi.base import string_types
from tvm.runtime import ObjectGeneric, DataType, convert, const
from tvm.ir import container as _container
from tvm.ir import container as _container, PointerType, PrimType

from . import stmt as _stmt
from . import expr as _expr
Expand Down Expand Up @@ -325,7 +325,7 @@ def allocate(self, dtype, shape, name="buf", scope=None):
buffer : BufferVar
The buffer var representing the buffer.
"""
buffer_var = _expr.Var(name, dtype="handle")
buffer_var = _expr.Var(name, PointerType(PrimType(dtype)))
if not isinstance(shape, (list, tuple, _container.Array)):
shape = [shape]
if scope:
Expand Down
2 changes: 1 addition & 1 deletion src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ Target DefaultTargetHost(Target target) {

tir::Buffer BufferWithOffsetAlignment(Array<PrimExpr> shape, DataType dtype, std::string name,
int data_alignment, int offset_factor, bool compact) {
auto data = tir::Var(name, DataType::Handle());
auto data = tir::Var(name, PointerType(PrimType(dtype)));
bool has_any = false;
if (!compact) {
for (const auto& it : shape) {
Expand Down
5 changes: 5 additions & 0 deletions src/tir/ir/buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -383,9 +383,14 @@ PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lane
Buffer::Buffer(Var data, DataType dtype, Array<PrimExpr> shape, Array<PrimExpr> strides,
PrimExpr elem_offset, String name, String scope, int data_alignment,
int offset_factor, BufferType buffer_type) {
CHECK(IsPointerType(data->type_annotation, dtype))
<< "Buffer data field expect to have the right pointer type annotation"
<< " annotation=" << data->type_annotation << ", dtype=" << dtype;

auto n = make_object<BufferNode>();
n->data = std::move(data);
n->dtype = dtype;

n->shape = std::move(shape);
n->strides = std::move(strides);
n->name = std::move(name);
Expand Down
3 changes: 3 additions & 0 deletions src/tir/ir/stmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
// Allocate
Allocate::Allocate(Var buffer_var, DataType dtype, Array<PrimExpr> extents, PrimExpr condition,
Stmt body) {
// TODO(tvm-team): Add invariant check to make sure
// IsPointerType(buffer_var->type_annotation, dtype)
// once we fix the allocate hybrid script printing.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will it be addressed soon in a follow-up PR for hybrid script?

Copy link
Member Author

@tqchen tqchen Aug 21, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I believe we can send another PR once the hybrid script printing part is fixed.

for (size_t i = 0; i < extents.size(); ++i) {
CHECK(extents[i].defined());
CHECK(extents[i].dtype().is_scalar());
Expand Down
157 changes: 82 additions & 75 deletions src/tir/transforms/bf16_legalize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -172,14 +172,11 @@ uint16_t RoundToNearestEven(float src) {
* Lower cast between bf16 and fp32
* Lower bf16 FloatImm to int16
*/
class BF16LowerRewriter : StmtExprMutator {
class BF16LowerRewriter : public StmtExprMutator {
public:
BF16LowerRewriter() {}

std::unordered_map<const BufferNode*, Buffer> buffer_remap;
std::unordered_map<const VarNode*, Var> var_remap;

Stmt operator()(Stmt s) { return VisitStmt(s); }
using StmtExprMutator::operator();

PrimExpr VisitExpr_(const CastNode* op) final {
auto op_val = StmtExprMutator::VisitExpr(op->value);
Expand All @@ -190,7 +187,6 @@ class BF16LowerRewriter : StmtExprMutator {
auto uint32_v = Cast(uint32_dtype, op_val);
// to be endian invariant.
return Call(op->dtype, builtin::reinterpret(), {uint32_v << 16});

} else if (op->dtype.is_bfloat16()) {
// if is cast_to_bf16, check if op->value is fp32
CHECK(op->value->dtype.is_float() && op->value->dtype.bits() == 32);
Expand All @@ -209,104 +205,104 @@ class BF16LowerRewriter : StmtExprMutator {
}

PrimExpr VisitExpr_(const VarNode* op) final {
auto itr = var_remap.find(op);
if (itr != var_remap.end()) {
Var var = GetRef<Var>(op);

auto itr = var_remap_.find(var);
if (itr != var_remap_.end()) {
return itr->second;
} else {
return std::move(var);
}
if (op->dtype.is_bfloat16()) {
CHECK(!op->type_annotation.defined());
auto ret = Var(op->name_hint, op->dtype);
var_remap[op] = ret;
return std::move(ret);
}
return StmtExprMutator::VisitExpr_(op);
}

Stmt VisitStmt_(const AllocateNode* op) final {
Stmt node_holder;
const AllocateNode* newop;
if (op->dtype.is_bfloat16()) {
auto v = Allocate(op->buffer_var, DataType::UInt(16, op->dtype.lanes()), op->extents,
op->condition, op->body);
node_holder = v;
newop = static_cast<const AllocateNode*>(v.operator->());
DataType dtype = DataType::UInt(16, op->dtype.lanes());
Var buffer_var = Var(op->buffer_var->name_hint, PointerType(PrimType(dtype)));
var_remap_[op->buffer_var] = buffer_var;
return VisitStmt(Allocate(buffer_var, dtype, op->extents, op->condition, op->body));
} else {
newop = op;
return StmtExprMutator::VisitStmt_(op);
}
return StmtExprMutator::VisitStmt_(newop);
}

Stmt VisitStmt_(const BufferStoreNode* op) final {
auto itr = buffer_remap.find(op->buffer.operator->());
const BufferStoreNode* newop;
BufferStore newop_holder;
if (itr != buffer_remap.end()) {
newop_holder = BufferStore(itr->second, op->value, op->indices);
newop = newop_holder.operator->();
Stmt ret = StmtExprMutator::VisitStmt_(op);
op = ret.as<BufferStoreNode>();

auto it = buffer_remap_.find(op->buffer);
if (it != buffer_remap_.end()) {
return BufferStore(it->second, op->value, op->indices);
} else {
newop = op;
return ret;
}
return StmtExprMutator::VisitStmt_(newop);
}

Stmt VisitStmt_(const AttrStmtNode* op) final {
const AttrStmtNode* newop = op;
Stmt newop_holder;
if (auto buffer = op->node.as<BufferNode>()) {
auto itr = buffer_remap.find(buffer);
if (itr != buffer_remap.end()) {
newop_holder = AttrStmt(itr->second, op->attr_key, op->value, op->body);
newop = newop_holder.as<AttrStmtNode>();
Stmt ret = StmtExprMutator::VisitStmt_(op);
op = ret.as<AttrStmtNode>();

if (auto* buffer = op->node.as<BufferNode>()) {
auto it = buffer_remap_.find(GetRef<Buffer>(buffer));
if (it != buffer_remap_.end()) {
return AttrStmt(it->second, op->attr_key, op->value, op->body);
}
} else if (auto buffer = op->node.as<VarNode>()) {
auto itr = var_remap.find(buffer);
if (itr != var_remap.end()) {
newop_holder = AttrStmt(itr->second, op->attr_key, op->value, op->body);
newop = newop_holder.as<AttrStmtNode>();
} else if (auto* var = op->node.as<VarNode>()) {
auto it = var_remap_.find(GetRef<Var>(var));
if (it != var_remap_.end()) {
return AttrStmt(it->second, op->attr_key, op->value, op->body);
}
}
return StmtExprMutator::VisitStmt_(newop);
return ret;
}

Stmt VisitStmt_(const BufferRealizeNode* op) final {
auto itr = buffer_remap.find(op->buffer.operator->());
const BufferRealizeNode* newop;
Stmt newop_holder;
if (itr != buffer_remap.end()) {
auto v = BufferRealize(itr->second, op->bounds, op->condition, op->body);
newop_holder = v;
newop = v.operator->();
Stmt ret = StmtExprMutator::VisitStmt_(op);
op = ret.as<BufferRealizeNode>();

auto it = buffer_remap_.find(op->buffer);
if (it != buffer_remap_.end()) {
return BufferRealize(it->second, op->bounds, op->condition, op->body);
} else {
newop = op;
return ret;
}
}

Stmt VisitStmt_(const StoreNode* op) final {
// NOTE: we do not explicitly recursively mutate op->buffer_var
Stmt ret = StmtExprMutator::VisitStmt_(op);
op = ret.as<StoreNode>();

auto it = var_remap_.find(op->buffer_var);
if (it != var_remap_.end()) {
return Store(it->second, op->value, op->index, op->predicate);
} else {
return ret;
}
return StmtExprMutator::VisitStmt_(newop);
}

PrimExpr VisitExpr_(const BufferLoadNode* op) final {
auto itr = buffer_remap.find(op->buffer.operator->());
const BufferLoadNode* newop;
BufferLoad newop_holder;
if (itr != buffer_remap.end()) {
newop_holder = BufferLoad(itr->second, op->indices);
newop = newop_holder.operator->();
PrimExpr ret = StmtExprMutator::VisitExpr_(op);
op = ret.as<BufferLoadNode>();

auto it = buffer_remap_.find(op->buffer);
if (it != buffer_remap_.end()) {
return BufferLoad(it->second, op->indices);
} else {
newop = op;
return ret;
}
return StmtExprMutator::VisitExpr_(newop);
}

PrimExpr VisitExpr_(const LoadNode* op) final {
bool is_bf16 = false;
PrimExpr ret = StmtExprMutator::VisitExpr_(op);
op = ret.as<LoadNode>();

if (op->dtype.is_bfloat16()) {
is_bf16 = true;
}
PrimExpr index = this->VisitExpr(op->index);
PrimExpr predicate = this->VisitExpr(op->predicate);
if (index.same_as(op->index) && predicate.same_as(op->predicate) && !is_bf16) {
return GetRef<PrimExpr>(op);
auto it = var_remap_.find(op->buffer_var);
CHECK(it != var_remap_.end()) << "bfloat* var needs to be remapped";
return Load(DataType::UInt(16, op->dtype.lanes()), it->second, op->index, op->predicate);
} else {
return Load(is_bf16 ? DataType::UInt(16, op->dtype.lanes()) : op->dtype, op->buffer_var,
index, predicate);
return ret;
}
}

Expand All @@ -320,20 +316,31 @@ class BF16LowerRewriter : StmtExprMutator {

void AlterBuffers(PrimFuncNode* op) {
std::vector<std::pair<Var, Buffer>> changes;

for (auto& itr : op->buffer_map) {
auto oldbuf = itr.second;
if (oldbuf->dtype.is_bfloat16()) {
auto newbuf = Buffer(oldbuf->data, DataType::UInt(16, oldbuf->dtype.lanes()), oldbuf->shape,
oldbuf->strides, oldbuf->elem_offset, oldbuf->name, oldbuf->scope,
oldbuf->data_alignment, oldbuf->offset_factor, oldbuf->buffer_type);
buffer_remap[oldbuf.operator->()] = newbuf;
DataType dtype = DataType::UInt(16, oldbuf->dtype.lanes());
Var buffer_var = Var(oldbuf->data->name_hint, PointerType(PrimType(dtype)));
auto newbuf = Buffer(buffer_var, dtype, oldbuf->shape, oldbuf->strides, oldbuf->elem_offset,
oldbuf->name, oldbuf->scope, oldbuf->data_alignment,
oldbuf->offset_factor, oldbuf->buffer_type);
buffer_remap_[oldbuf] = newbuf;
var_remap_[oldbuf->data] = buffer_var;
changes.emplace_back(itr.first, newbuf);
} else {
changes.emplace_back(itr);
}
}
if (buffer_remap.size() != 0) {

if (buffer_remap_.size() != 0) {
op->buffer_map = Map<Var, Buffer>(changes.begin(), changes.end());
}
}

private:
std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_remap_;
std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual> var_remap_;
};

namespace transform {
Expand Down
6 changes: 3 additions & 3 deletions src/tir/transforms/storage_flatten.cc
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,9 @@ class StorageFlattener : public StmtExprMutator {
strides = Array<PrimExpr>(rstrides.rbegin(), rstrides.rend());
}

e.buffer =
Buffer(Var(op->buffer->data->name_hint, DataType::Handle()), op->buffer->dtype, shape,
strides, PrimExpr(), op->buffer->name, skey.to_string(), align, 0, kDefault);
e.buffer = Buffer(Var(op->buffer->data->name_hint, op->buffer->data->type_annotation),
op->buffer->dtype, shape, strides, PrimExpr(), op->buffer->name,
skey.to_string(), align, 0, kDefault);

buf_map_[key] = e;
Stmt body = this->VisitStmt(op->body);
Expand Down