From 4e7559785c796f0a9c2386ce744d494bf13cb3d2 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Fri, 21 Aug 2020 10:53:39 -0700 Subject: [PATCH] [TIR] Enforce buffer pointer var type to be consistent with dtype. (#6317) Now that we have type_annotation in tir::Var, we should make sure that the type annotation is consistent with the dtype in Buffer declaration and Allocation. This change allows future passes to directly use the content type information via type_annotation. This PR turns on the enforcement on Buffer and also fixes a few cases for Allocate. A follow-up PR needs to fix a few more cases in the hybrid script parsing before everything can be made consistent. --- include/tvm/tir/op.h | 17 +++ python/tvm/tir/buffer.py | 4 +- python/tvm/tir/ir_builder.py | 4 +- src/driver/driver_api.cc | 2 +- src/tir/ir/buffer.cc | 5 + src/tir/ir/stmt.cc | 3 + src/tir/transforms/bf16_legalize.cc | 157 ++++++++++++++------------ src/tir/transforms/storage_flatten.cc | 6 +- 8 files changed, 115 insertions(+), 83 deletions(-) diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index 68ca2663ede94..93a54b044fba8 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -617,6 +617,23 @@ TVM_DECLARE_INTRIN_BINARY(hypot); TVM_DECLARE_INTRIN_BINARY(ldexp); namespace tir { + +/*! + * \brief Check if type is a pointer to a runtime element type. + * \param type The type to be checked. + * \param element_type The corresponding element type. + * \return The check results + */ +inline bool IsPointerType(const Type& type, const DataType& element_type) { + if (!type.defined()) return false; + if (const auto* ptr_type = type.as()) { + if (const auto* prim_type = ptr_type->element_type.as()) { + return prim_type->dtype == element_type; + } + } + return false; +} + /*! * \brief Make a const value with certain data type. * \param t The target type. 
diff --git a/python/tvm/tir/buffer.py b/python/tvm/tir/buffer.py index 11bfb4c55921c..bd7672a52d9a1 100644 --- a/python/tvm/tir/buffer.py +++ b/python/tvm/tir/buffer.py @@ -20,7 +20,7 @@ from tvm._ffi.base import string_types from tvm.runtime import Object, convert -from tvm.ir import PrimExpr +from tvm.ir import PrimExpr, PointerType, PrimType from . import _ffi_api @@ -241,7 +241,7 @@ def decl_buffer(shape, shape_dtype = shape[0].dtype if hasattr(shape[0], "dtype") else "int32" elem_offset = Var('%s_elem_offset' % name, shape_dtype) if data is None: - data = Var(name, "handle") + data = Var(name, PointerType(PrimType(dtype))) return _ffi_api.Buffer( data, dtype, shape, strides, elem_offset, name, scope, data_alignment, offset_factor, buffer_type) diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index 20180d1be45d5..b313e58a03aff 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -17,7 +17,7 @@ """Developer API of IR node builder make function.""" from tvm._ffi.base import string_types from tvm.runtime import ObjectGeneric, DataType, convert, const -from tvm.ir import container as _container +from tvm.ir import container as _container, PointerType, PrimType from . import stmt as _stmt from . import expr as _expr @@ -325,7 +325,7 @@ def allocate(self, dtype, shape, name="buf", scope=None): buffer : BufferVar The buffer var representing the buffer. 
""" - buffer_var = _expr.Var(name, dtype="handle") + buffer_var = _expr.Var(name, PointerType(PrimType(dtype))) if not isinstance(shape, (list, tuple, _container.Array)): shape = [shape] if scope: diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 142bdfc70dcec..14aa4fc56e2e1 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -69,7 +69,7 @@ Target DefaultTargetHost(Target target) { tir::Buffer BufferWithOffsetAlignment(Array shape, DataType dtype, std::string name, int data_alignment, int offset_factor, bool compact) { - auto data = tir::Var(name, DataType::Handle()); + auto data = tir::Var(name, PointerType(PrimType(dtype))); bool has_any = false; if (!compact) { for (const auto& it : shape) { diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index 00e3335633ecd..d33f2ddf698af 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -383,9 +383,14 @@ PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lane Buffer::Buffer(Var data, DataType dtype, Array shape, Array strides, PrimExpr elem_offset, String name, String scope, int data_alignment, int offset_factor, BufferType buffer_type) { + CHECK(IsPointerType(data->type_annotation, dtype)) + << "Buffer data field expect to have the right pointer type annotation" + << " annotation=" << data->type_annotation << ", dtype=" << dtype; + auto n = make_object(); n->data = std::move(data); n->dtype = dtype; + n->shape = std::move(shape); n->strides = std::move(strides); n->name = std::move(name); diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 296f49207cce3..d9e1df46e8fa9 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -263,6 +263,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // Allocate Allocate::Allocate(Var buffer_var, DataType dtype, Array extents, PrimExpr condition, Stmt body) { + // TODO(tvm-team): Add invariant check to make sure + // IsPointerPType(buffer_var->type_annotation, dtype) + // once we fix the allocate 
hybrid script printing. for (size_t i = 0; i < extents.size(); ++i) { CHECK(extents[i].defined()); CHECK(extents[i].dtype().is_scalar()); diff --git a/src/tir/transforms/bf16_legalize.cc b/src/tir/transforms/bf16_legalize.cc index 4a44b85684b26..97c96edc6ca77 100644 --- a/src/tir/transforms/bf16_legalize.cc +++ b/src/tir/transforms/bf16_legalize.cc @@ -172,14 +172,11 @@ uint16_t RoundToNearestEven(float src) { * Lower cast between bf16 and fp32 * Lower bf16 FloatImm to int16 */ -class BF16LowerRewriter : StmtExprMutator { +class BF16LowerRewriter : public StmtExprMutator { public: BF16LowerRewriter() {} - std::unordered_map buffer_remap; - std::unordered_map var_remap; - - Stmt operator()(Stmt s) { return VisitStmt(s); } + using StmtExprMutator::operator(); PrimExpr VisitExpr_(const CastNode* op) final { auto op_val = StmtExprMutator::VisitExpr(op->value); @@ -190,7 +187,6 @@ class BF16LowerRewriter : StmtExprMutator { auto uint32_v = Cast(uint32_dtype, op_val); // to be endian invariant. 
return Call(op->dtype, builtin::reinterpret(), {uint32_v << 16}); - } else if (op->dtype.is_bfloat16()) { // if is cast_to_bf16, check if op->value is fp32 CHECK(op->value->dtype.is_float() && op->value->dtype.bits() == 32); @@ -209,104 +205,104 @@ class BF16LowerRewriter : StmtExprMutator { } PrimExpr VisitExpr_(const VarNode* op) final { - auto itr = var_remap.find(op); - if (itr != var_remap.end()) { + Var var = GetRef(op); + + auto itr = var_remap_.find(var); + if (itr != var_remap_.end()) { return itr->second; + } else { + return std::move(var); } - if (op->dtype.is_bfloat16()) { - CHECK(!op->type_annotation.defined()); - auto ret = Var(op->name_hint, op->dtype); - var_remap[op] = ret; - return std::move(ret); - } - return StmtExprMutator::VisitExpr_(op); } Stmt VisitStmt_(const AllocateNode* op) final { - Stmt node_holder; - const AllocateNode* newop; if (op->dtype.is_bfloat16()) { - auto v = Allocate(op->buffer_var, DataType::UInt(16, op->dtype.lanes()), op->extents, - op->condition, op->body); - node_holder = v; - newop = static_cast(v.operator->()); + DataType dtype = DataType::UInt(16, op->dtype.lanes()); + Var buffer_var = Var(op->buffer_var->name_hint, PointerType(PrimType(dtype))); + var_remap_[op->buffer_var] = buffer_var; + return VisitStmt(Allocate(buffer_var, dtype, op->extents, op->condition, op->body)); } else { - newop = op; + return StmtExprMutator::VisitStmt_(op); } - return StmtExprMutator::VisitStmt_(newop); } Stmt VisitStmt_(const BufferStoreNode* op) final { - auto itr = buffer_remap.find(op->buffer.operator->()); - const BufferStoreNode* newop; - BufferStore newop_holder; - if (itr != buffer_remap.end()) { - newop_holder = BufferStore(itr->second, op->value, op->indices); - newop = newop_holder.operator->(); + Stmt ret = StmtExprMutator::VisitStmt_(op); + op = ret.as(); + + auto it = buffer_remap_.find(op->buffer); + if (it != buffer_remap_.end()) { + return BufferStore(it->second, op->value, op->indices); } else { - newop = op; + return 
ret; } - return StmtExprMutator::VisitStmt_(newop); } Stmt VisitStmt_(const AttrStmtNode* op) final { - const AttrStmtNode* newop = op; - Stmt newop_holder; - if (auto buffer = op->node.as()) { - auto itr = buffer_remap.find(buffer); - if (itr != buffer_remap.end()) { - newop_holder = AttrStmt(itr->second, op->attr_key, op->value, op->body); - newop = newop_holder.as(); + Stmt ret = StmtExprMutator::VisitStmt_(op); + op = ret.as(); + + if (auto* buffer = op->node.as()) { + auto it = buffer_remap_.find(GetRef(buffer)); + if (it != buffer_remap_.end()) { + return AttrStmt(it->second, op->attr_key, op->value, op->body); } - } else if (auto buffer = op->node.as()) { - auto itr = var_remap.find(buffer); - if (itr != var_remap.end()) { - newop_holder = AttrStmt(itr->second, op->attr_key, op->value, op->body); - newop = newop_holder.as(); + } else if (auto* var = op->node.as()) { + auto it = var_remap_.find(GetRef(var)); + if (it != var_remap_.end()) { + return AttrStmt(it->second, op->attr_key, op->value, op->body); } } - return StmtExprMutator::VisitStmt_(newop); + return ret; } Stmt VisitStmt_(const BufferRealizeNode* op) final { - auto itr = buffer_remap.find(op->buffer.operator->()); - const BufferRealizeNode* newop; - Stmt newop_holder; - if (itr != buffer_remap.end()) { - auto v = BufferRealize(itr->second, op->bounds, op->condition, op->body); - newop_holder = v; - newop = v.operator->(); + Stmt ret = StmtExprMutator::VisitStmt_(op); + op = ret.as(); + + auto it = buffer_remap_.find(op->buffer); + if (it != buffer_remap_.end()) { + return BufferRealize(it->second, op->bounds, op->condition, op->body); } else { - newop = op; + return ret; + } + } + + Stmt VisitStmt_(const StoreNode* op) final { + // NOTE: we do not explicit recursivly mutate op->buffer_var + Stmt ret = StmtExprMutator::VisitStmt_(op); + op = ret.as(); + + auto it = var_remap_.find(op->buffer_var); + if (it != var_remap_.end()) { + return Store(it->second, op->value, op->index, op->predicate); + } 
else { + return ret; } - return StmtExprMutator::VisitStmt_(newop); } PrimExpr VisitExpr_(const BufferLoadNode* op) final { - auto itr = buffer_remap.find(op->buffer.operator->()); - const BufferLoadNode* newop; - BufferLoad newop_holder; - if (itr != buffer_remap.end()) { - newop_holder = BufferLoad(itr->second, op->indices); - newop = newop_holder.operator->(); + PrimExpr ret = StmtExprMutator::VisitExpr_(op); + op = ret.as(); + + auto it = buffer_remap_.find(op->buffer); + if (it != buffer_remap_.end()) { + return BufferLoad(it->second, op->indices); } else { - newop = op; + return ret; } - return StmtExprMutator::VisitExpr_(newop); } PrimExpr VisitExpr_(const LoadNode* op) final { - bool is_bf16 = false; + PrimExpr ret = StmtExprMutator::VisitExpr_(op); + op = ret.as(); + if (op->dtype.is_bfloat16()) { - is_bf16 = true; - } - PrimExpr index = this->VisitExpr(op->index); - PrimExpr predicate = this->VisitExpr(op->predicate); - if (index.same_as(op->index) && predicate.same_as(op->predicate) && !is_bf16) { - return GetRef(op); + auto it = var_remap_.find(op->buffer_var); + CHECK(it != var_remap_.end()) << "bfloat* var needs to be remapped"; + return Load(DataType::UInt(16, op->dtype.lanes()), it->second, op->index, op->predicate); } else { - return Load(is_bf16 ? 
DataType::UInt(16, op->dtype.lanes()) : op->dtype, op->buffer_var, - index, predicate); + return ret; } } @@ -320,20 +316,31 @@ class BF16LowerRewriter : StmtExprMutator { void AlterBuffers(PrimFuncNode* op) { std::vector> changes; + for (auto& itr : op->buffer_map) { auto oldbuf = itr.second; if (oldbuf->dtype.is_bfloat16()) { - auto newbuf = Buffer(oldbuf->data, DataType::UInt(16, oldbuf->dtype.lanes()), oldbuf->shape, - oldbuf->strides, oldbuf->elem_offset, oldbuf->name, oldbuf->scope, - oldbuf->data_alignment, oldbuf->offset_factor, oldbuf->buffer_type); - buffer_remap[oldbuf.operator->()] = newbuf; + DataType dtype = DataType::UInt(16, oldbuf->dtype.lanes()); + Var buffer_var = Var(oldbuf->data->name_hint, PointerType(PrimType(dtype))); + auto newbuf = Buffer(buffer_var, dtype, oldbuf->shape, oldbuf->strides, oldbuf->elem_offset, + oldbuf->name, oldbuf->scope, oldbuf->data_alignment, + oldbuf->offset_factor, oldbuf->buffer_type); + buffer_remap_[oldbuf] = newbuf; + var_remap_[oldbuf->data] = buffer_var; changes.emplace_back(itr.first, newbuf); + } else { + changes.emplace_back(itr); } } - if (buffer_remap.size() != 0) { + + if (buffer_remap_.size() != 0) { op->buffer_map = Map(changes.begin(), changes.end()); } } + + private: + std::unordered_map buffer_remap_; + std::unordered_map var_remap_; }; namespace transform { diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 8eb43f8ebc847..7475bf6d2f8ee 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -200,9 +200,9 @@ class StorageFlattener : public StmtExprMutator { strides = Array(rstrides.rbegin(), rstrides.rend()); } - e.buffer = - Buffer(Var(op->buffer->data->name_hint, DataType::Handle()), op->buffer->dtype, shape, - strides, PrimExpr(), op->buffer->name, skey.to_string(), align, 0, kDefault); + e.buffer = Buffer(Var(op->buffer->data->name_hint, op->buffer->data->type_annotation), + op->buffer->dtype, shape, strides, 
PrimExpr(), op->buffer->name, + skey.to_string(), align, 0, kDefault); buf_map_[key] = e; Stmt body = this->VisitStmt(op->body);