Skip to content

Commit

Permalink
[TIR] Enforce buffer pointer var type to be consistent with dtype. (a…
Browse files Browse the repository at this point in the history
…pache#6317)

Now that we have type_annotation in tir::Var.
We should make sure that the type annotation to be consistent with the dtype
in Buffer declaration and Allocation.

This change allows future passes to directly use the content type information via type_annotation.

This PR turns on the enforcement on Buffer and also fixed a few cases for Allocate.
A follow up PR need to fix a few more cases in the hybrid script parsing
before everything can be made consistent.
  • Loading branch information
tqchen authored and Trevor Morris committed Aug 26, 2020
1 parent b467c88 commit 8a0c61c
Show file tree
Hide file tree
Showing 8 changed files with 115 additions and 83 deletions.
17 changes: 17 additions & 0 deletions include/tvm/tir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,23 @@ TVM_DECLARE_INTRIN_BINARY(hypot);
TVM_DECLARE_INTRIN_BINARY(ldexp);

namespace tir {

/*!
* \brief Check if type is a pointer to a runtime element type.
* \param type The type to be checked.
* \param element_type The corresponding element type.
* \return The check results
*/
inline bool IsPointerType(const Type& type, const DataType& element_type) {
if (!type.defined()) return false;
if (const auto* ptr_type = type.as<PointerTypeNode>()) {
if (const auto* prim_type = ptr_type->element_type.as<PrimTypeNode>()) {
return prim_type->dtype == element_type;
}
}
return false;
}

/*!
* \brief Make a const value with certain data type.
* \param t The target type.
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/tir/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from tvm._ffi.base import string_types
from tvm.runtime import Object, convert
from tvm.ir import PrimExpr
from tvm.ir import PrimExpr, PointerType, PrimType
from . import _ffi_api


Expand Down Expand Up @@ -241,7 +241,7 @@ def decl_buffer(shape,
shape_dtype = shape[0].dtype if hasattr(shape[0], "dtype") else "int32"
elem_offset = Var('%s_elem_offset' % name, shape_dtype)
if data is None:
data = Var(name, "handle")
data = Var(name, PointerType(PrimType(dtype)))
return _ffi_api.Buffer(
data, dtype, shape, strides, elem_offset, name, scope,
data_alignment, offset_factor, buffer_type)
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/tir/ir_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"""Developer API of IR node builder make function."""
from tvm._ffi.base import string_types
from tvm.runtime import ObjectGeneric, DataType, convert, const
from tvm.ir import container as _container
from tvm.ir import container as _container, PointerType, PrimType

from . import stmt as _stmt
from . import expr as _expr
Expand Down Expand Up @@ -325,7 +325,7 @@ def allocate(self, dtype, shape, name="buf", scope=None):
buffer : BufferVar
The buffer var representing the buffer.
"""
buffer_var = _expr.Var(name, dtype="handle")
buffer_var = _expr.Var(name, PointerType(PrimType(dtype)))
if not isinstance(shape, (list, tuple, _container.Array)):
shape = [shape]
if scope:
Expand Down
2 changes: 1 addition & 1 deletion src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ Target DefaultTargetHost(Target target) {

tir::Buffer BufferWithOffsetAlignment(Array<PrimExpr> shape, DataType dtype, std::string name,
int data_alignment, int offset_factor, bool compact) {
auto data = tir::Var(name, DataType::Handle());
auto data = tir::Var(name, PointerType(PrimType(dtype)));
bool has_any = false;
if (!compact) {
for (const auto& it : shape) {
Expand Down
5 changes: 5 additions & 0 deletions src/tir/ir/buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -383,9 +383,14 @@ PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lane
Buffer::Buffer(Var data, DataType dtype, Array<PrimExpr> shape, Array<PrimExpr> strides,
PrimExpr elem_offset, String name, String scope, int data_alignment,
int offset_factor, BufferType buffer_type) {
CHECK(IsPointerType(data->type_annotation, dtype))
<< "Buffer data field expect to have the right pointer type annotation"
<< " annotation=" << data->type_annotation << ", dtype=" << dtype;

auto n = make_object<BufferNode>();
n->data = std::move(data);
n->dtype = dtype;

n->shape = std::move(shape);
n->strides = std::move(strides);
n->name = std::move(name);
Expand Down
3 changes: 3 additions & 0 deletions src/tir/ir/stmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
// Allocate
Allocate::Allocate(Var buffer_var, DataType dtype, Array<PrimExpr> extents, PrimExpr condition,
Stmt body) {
// TODO(tvm-team): Add invariant check to make sure
// IsPointerPType(buffer_var->type_annotation, dtype)
// once we fix the allocate hybrid script printing.
for (size_t i = 0; i < extents.size(); ++i) {
CHECK(extents[i].defined());
CHECK(extents[i].dtype().is_scalar());
Expand Down
157 changes: 82 additions & 75 deletions src/tir/transforms/bf16_legalize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -172,14 +172,11 @@ uint16_t RoundToNearestEven(float src) {
* Lower cast between bf16 and fp32
* Lower bf16 FloatImm to int16
*/
class BF16LowerRewriter : StmtExprMutator {
class BF16LowerRewriter : public StmtExprMutator {
public:
BF16LowerRewriter() {}

std::unordered_map<const BufferNode*, Buffer> buffer_remap;
std::unordered_map<const VarNode*, Var> var_remap;

Stmt operator()(Stmt s) { return VisitStmt(s); }
using StmtExprMutator::operator();

PrimExpr VisitExpr_(const CastNode* op) final {
auto op_val = StmtExprMutator::VisitExpr(op->value);
Expand All @@ -190,7 +187,6 @@ class BF16LowerRewriter : StmtExprMutator {
auto uint32_v = Cast(uint32_dtype, op_val);
// to be endian invariant.
return Call(op->dtype, builtin::reinterpret(), {uint32_v << 16});

} else if (op->dtype.is_bfloat16()) {
// if is cast_to_bf16, check if op->value is fp32
CHECK(op->value->dtype.is_float() && op->value->dtype.bits() == 32);
Expand All @@ -209,104 +205,104 @@ class BF16LowerRewriter : StmtExprMutator {
}

PrimExpr VisitExpr_(const VarNode* op) final {
auto itr = var_remap.find(op);
if (itr != var_remap.end()) {
Var var = GetRef<Var>(op);

auto itr = var_remap_.find(var);
if (itr != var_remap_.end()) {
return itr->second;
} else {
return std::move(var);
}
if (op->dtype.is_bfloat16()) {
CHECK(!op->type_annotation.defined());
auto ret = Var(op->name_hint, op->dtype);
var_remap[op] = ret;
return std::move(ret);
}
return StmtExprMutator::VisitExpr_(op);
}

Stmt VisitStmt_(const AllocateNode* op) final {
Stmt node_holder;
const AllocateNode* newop;
if (op->dtype.is_bfloat16()) {
auto v = Allocate(op->buffer_var, DataType::UInt(16, op->dtype.lanes()), op->extents,
op->condition, op->body);
node_holder = v;
newop = static_cast<const AllocateNode*>(v.operator->());
DataType dtype = DataType::UInt(16, op->dtype.lanes());
Var buffer_var = Var(op->buffer_var->name_hint, PointerType(PrimType(dtype)));
var_remap_[op->buffer_var] = buffer_var;
return VisitStmt(Allocate(buffer_var, dtype, op->extents, op->condition, op->body));
} else {
newop = op;
return StmtExprMutator::VisitStmt_(op);
}
return StmtExprMutator::VisitStmt_(newop);
}

Stmt VisitStmt_(const BufferStoreNode* op) final {
auto itr = buffer_remap.find(op->buffer.operator->());
const BufferStoreNode* newop;
BufferStore newop_holder;
if (itr != buffer_remap.end()) {
newop_holder = BufferStore(itr->second, op->value, op->indices);
newop = newop_holder.operator->();
Stmt ret = StmtExprMutator::VisitStmt_(op);
op = ret.as<BufferStoreNode>();

auto it = buffer_remap_.find(op->buffer);
if (it != buffer_remap_.end()) {
return BufferStore(it->second, op->value, op->indices);
} else {
newop = op;
return ret;
}
return StmtExprMutator::VisitStmt_(newop);
}

Stmt VisitStmt_(const AttrStmtNode* op) final {
const AttrStmtNode* newop = op;
Stmt newop_holder;
if (auto buffer = op->node.as<BufferNode>()) {
auto itr = buffer_remap.find(buffer);
if (itr != buffer_remap.end()) {
newop_holder = AttrStmt(itr->second, op->attr_key, op->value, op->body);
newop = newop_holder.as<AttrStmtNode>();
Stmt ret = StmtExprMutator::VisitStmt_(op);
op = ret.as<AttrStmtNode>();

if (auto* buffer = op->node.as<BufferNode>()) {
auto it = buffer_remap_.find(GetRef<Buffer>(buffer));
if (it != buffer_remap_.end()) {
return AttrStmt(it->second, op->attr_key, op->value, op->body);
}
} else if (auto buffer = op->node.as<VarNode>()) {
auto itr = var_remap.find(buffer);
if (itr != var_remap.end()) {
newop_holder = AttrStmt(itr->second, op->attr_key, op->value, op->body);
newop = newop_holder.as<AttrStmtNode>();
} else if (auto* var = op->node.as<VarNode>()) {
auto it = var_remap_.find(GetRef<Var>(var));
if (it != var_remap_.end()) {
return AttrStmt(it->second, op->attr_key, op->value, op->body);
}
}
return StmtExprMutator::VisitStmt_(newop);
return ret;
}

Stmt VisitStmt_(const BufferRealizeNode* op) final {
auto itr = buffer_remap.find(op->buffer.operator->());
const BufferRealizeNode* newop;
Stmt newop_holder;
if (itr != buffer_remap.end()) {
auto v = BufferRealize(itr->second, op->bounds, op->condition, op->body);
newop_holder = v;
newop = v.operator->();
Stmt ret = StmtExprMutator::VisitStmt_(op);
op = ret.as<BufferRealizeNode>();

auto it = buffer_remap_.find(op->buffer);
if (it != buffer_remap_.end()) {
return BufferRealize(it->second, op->bounds, op->condition, op->body);
} else {
newop = op;
return ret;
}
}

Stmt VisitStmt_(const StoreNode* op) final {
// NOTE: we do not explicit recursivly mutate op->buffer_var
Stmt ret = StmtExprMutator::VisitStmt_(op);
op = ret.as<StoreNode>();

auto it = var_remap_.find(op->buffer_var);
if (it != var_remap_.end()) {
return Store(it->second, op->value, op->index, op->predicate);
} else {
return ret;
}
return StmtExprMutator::VisitStmt_(newop);
}

PrimExpr VisitExpr_(const BufferLoadNode* op) final {
auto itr = buffer_remap.find(op->buffer.operator->());
const BufferLoadNode* newop;
BufferLoad newop_holder;
if (itr != buffer_remap.end()) {
newop_holder = BufferLoad(itr->second, op->indices);
newop = newop_holder.operator->();
PrimExpr ret = StmtExprMutator::VisitExpr_(op);
op = ret.as<BufferLoadNode>();

auto it = buffer_remap_.find(op->buffer);
if (it != buffer_remap_.end()) {
return BufferLoad(it->second, op->indices);
} else {
newop = op;
return ret;
}
return StmtExprMutator::VisitExpr_(newop);
}

PrimExpr VisitExpr_(const LoadNode* op) final {
bool is_bf16 = false;
PrimExpr ret = StmtExprMutator::VisitExpr_(op);
op = ret.as<LoadNode>();

if (op->dtype.is_bfloat16()) {
is_bf16 = true;
}
PrimExpr index = this->VisitExpr(op->index);
PrimExpr predicate = this->VisitExpr(op->predicate);
if (index.same_as(op->index) && predicate.same_as(op->predicate) && !is_bf16) {
return GetRef<PrimExpr>(op);
auto it = var_remap_.find(op->buffer_var);
CHECK(it != var_remap_.end()) << "bfloat* var needs to be remapped";
return Load(DataType::UInt(16, op->dtype.lanes()), it->second, op->index, op->predicate);
} else {
return Load(is_bf16 ? DataType::UInt(16, op->dtype.lanes()) : op->dtype, op->buffer_var,
index, predicate);
return ret;
}
}

Expand All @@ -320,20 +316,31 @@ class BF16LowerRewriter : StmtExprMutator {

void AlterBuffers(PrimFuncNode* op) {
std::vector<std::pair<Var, Buffer>> changes;

for (auto& itr : op->buffer_map) {
auto oldbuf = itr.second;
if (oldbuf->dtype.is_bfloat16()) {
auto newbuf = Buffer(oldbuf->data, DataType::UInt(16, oldbuf->dtype.lanes()), oldbuf->shape,
oldbuf->strides, oldbuf->elem_offset, oldbuf->name, oldbuf->scope,
oldbuf->data_alignment, oldbuf->offset_factor, oldbuf->buffer_type);
buffer_remap[oldbuf.operator->()] = newbuf;
DataType dtype = DataType::UInt(16, oldbuf->dtype.lanes());
Var buffer_var = Var(oldbuf->data->name_hint, PointerType(PrimType(dtype)));
auto newbuf = Buffer(buffer_var, dtype, oldbuf->shape, oldbuf->strides, oldbuf->elem_offset,
oldbuf->name, oldbuf->scope, oldbuf->data_alignment,
oldbuf->offset_factor, oldbuf->buffer_type);
buffer_remap_[oldbuf] = newbuf;
var_remap_[oldbuf->data] = buffer_var;
changes.emplace_back(itr.first, newbuf);
} else {
changes.emplace_back(itr);
}
}
if (buffer_remap.size() != 0) {

if (buffer_remap_.size() != 0) {
op->buffer_map = Map<Var, Buffer>(changes.begin(), changes.end());
}
}

private:
std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_remap_;
std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual> var_remap_;
};

namespace transform {
Expand Down
6 changes: 3 additions & 3 deletions src/tir/transforms/storage_flatten.cc
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,9 @@ class StorageFlattener : public StmtExprMutator {
strides = Array<PrimExpr>(rstrides.rbegin(), rstrides.rend());
}

e.buffer =
Buffer(Var(op->buffer->data->name_hint, DataType::Handle()), op->buffer->dtype, shape,
strides, PrimExpr(), op->buffer->name, skey.to_string(), align, 0, kDefault);
e.buffer = Buffer(Var(op->buffer->data->name_hint, op->buffer->data->type_annotation),
op->buffer->dtype, shape, strides, PrimExpr(), op->buffer->name,
skey.to_string(), align, 0, kDefault);

buf_map_[key] = e;
Stmt body = this->VisitStmt(op->body);
Expand Down

0 comments on commit 8a0c61c

Please sign in to comment.