diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index 68ca2663ede9..93a54b044fba 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -617,6 +617,23 @@ TVM_DECLARE_INTRIN_BINARY(hypot); TVM_DECLARE_INTRIN_BINARY(ldexp); namespace tir { + +/*! + * \brief Check if type is a pointer to a runtime element type. + * \param type The type to be checked. + * \param element_type The corresponding element type. + * \return The check results + */ +inline bool IsPointerType(const Type& type, const DataType& element_type) { + if (!type.defined()) return false; + if (const auto* ptr_type = type.as()) { + if (const auto* prim_type = ptr_type->element_type.as()) { + return prim_type->dtype == element_type; + } + } + return false; +} + /*! * \brief Make a const value with certain data type. * \param t The target type. diff --git a/python/tvm/tir/buffer.py b/python/tvm/tir/buffer.py index 11bfb4c55921..bd7672a52d9a 100644 --- a/python/tvm/tir/buffer.py +++ b/python/tvm/tir/buffer.py @@ -20,7 +20,7 @@ from tvm._ffi.base import string_types from tvm.runtime import Object, convert -from tvm.ir import PrimExpr +from tvm.ir import PrimExpr, PointerType, PrimType from . import _ffi_api @@ -241,7 +241,7 @@ def decl_buffer(shape, shape_dtype = shape[0].dtype if hasattr(shape[0], "dtype") else "int32" elem_offset = Var('%s_elem_offset' % name, shape_dtype) if data is None: - data = Var(name, "handle") + data = Var(name, PointerType(PrimType(dtype))) return _ffi_api.Buffer( data, dtype, shape, strides, elem_offset, name, scope, data_alignment, offset_factor, buffer_type) diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index 20180d1be45d..b313e58a03af 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -17,7 +17,7 @@ """Developer API of IR node builder make function.""" from tvm._ffi.base import string_types from tvm.runtime import ObjectGeneric, DataType, convert, const -from tvm.ir import container as _container +from tvm.ir import container as _container, PointerType, PrimType from . import stmt as _stmt from . import expr as _expr @@ -325,7 +325,7 @@ def allocate(self, dtype, shape, name="buf", scope=None): buffer : BufferVar The buffer var representing the buffer. """ - buffer_var = _expr.Var(name, dtype="handle") + buffer_var = _expr.Var(name, PointerType(PrimType(dtype))) if not isinstance(shape, (list, tuple, _container.Array)): shape = [shape] if scope: diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 142bdfc70dce..14aa4fc56e2e 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -69,7 +69,7 @@ Target DefaultTargetHost(Target target) { tir::Buffer BufferWithOffsetAlignment(Array shape, DataType dtype, std::string name, int data_alignment, int offset_factor, bool compact) { - auto data = tir::Var(name, DataType::Handle()); + auto data = tir::Var(name, PointerType(PrimType(dtype))); bool has_any = false; if (!compact) { for (const auto& it : shape) { diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index 00e3335633ec..d33f2ddf698a 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -383,9 +383,14 @@ PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lane Buffer::Buffer(Var data, DataType dtype, Array shape, Array strides, PrimExpr elem_offset, String name, String scope, int data_alignment, int offset_factor, BufferType buffer_type) { + CHECK(IsPointerType(data->type_annotation, dtype)) + << "Buffer data field expect to have the right pointer type annotation" + << " annotation=" << data->type_annotation << ", dtype=" << dtype; + auto n = make_object(); n->data = std::move(data); n->dtype = dtype; + n->shape = std::move(shape); n->strides = std::move(strides); n->name = std::move(name); diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 296f49207cce..d9e1df46e8fa 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -263,6 +263,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // Allocate Allocate::Allocate(Var buffer_var, DataType dtype, Array extents, PrimExpr condition, Stmt body) { + // TODO(tvm-team): Add invariant check to make sure + // IsPointerPType(buffer_var->type_annotation, dtype) + // once we fix the allocate hybrid script printing. for (size_t i = 0; i < extents.size(); ++i) { CHECK(extents[i].defined()); CHECK(extents[i].dtype().is_scalar()); diff --git a/src/tir/transforms/bf16_legalize.cc b/src/tir/transforms/bf16_legalize.cc index 4a44b85684b2..97c96edc6ca7 100644 --- a/src/tir/transforms/bf16_legalize.cc +++ b/src/tir/transforms/bf16_legalize.cc @@ -172,14 +172,11 @@ uint16_t RoundToNearestEven(float src) { * Lower cast between bf16 and fp32 * Lower bf16 FloatImm to int16 */ -class BF16LowerRewriter : StmtExprMutator { +class BF16LowerRewriter : public StmtExprMutator { public: BF16LowerRewriter() {} - std::unordered_map buffer_remap; - std::unordered_map var_remap; - - Stmt operator()(Stmt s) { return VisitStmt(s); } + using StmtExprMutator::operator(); PrimExpr VisitExpr_(const CastNode* op) final { auto op_val = StmtExprMutator::VisitExpr(op->value); @@ -190,7 +187,6 @@ class BF16LowerRewriter : StmtExprMutator { auto uint32_v = Cast(uint32_dtype, op_val); // to be endian invariant. return Call(op->dtype, builtin::reinterpret(), {uint32_v << 16}); - } else if (op->dtype.is_bfloat16()) { // if is cast_to_bf16, check if op->value is fp32 CHECK(op->value->dtype.is_float() && op->value->dtype.bits() == 32); @@ -209,104 +205,104 @@ class BF16LowerRewriter : StmtExprMutator { } PrimExpr VisitExpr_(const VarNode* op) final { - auto itr = var_remap.find(op); - if (itr != var_remap.end()) { + Var var = GetRef(op); + + auto itr = var_remap_.find(var); + if (itr != var_remap_.end()) { return itr->second; + } else { + return std::move(var); } - if (op->dtype.is_bfloat16()) { - CHECK(!op->type_annotation.defined()); - auto ret = Var(op->name_hint, op->dtype); - var_remap[op] = ret; - return std::move(ret); - } - return StmtExprMutator::VisitExpr_(op); } Stmt VisitStmt_(const AllocateNode* op) final { - Stmt node_holder; - const AllocateNode* newop; if (op->dtype.is_bfloat16()) { - auto v = Allocate(op->buffer_var, DataType::UInt(16, op->dtype.lanes()), op->extents, - op->condition, op->body); - node_holder = v; - newop = static_cast(v.operator->()); + DataType dtype = DataType::UInt(16, op->dtype.lanes()); + Var buffer_var = Var(op->buffer_var->name_hint, PointerType(PrimType(dtype))); + var_remap_[op->buffer_var] = buffer_var; + return VisitStmt(Allocate(buffer_var, dtype, op->extents, op->condition, op->body)); } else { - newop = op; + return StmtExprMutator::VisitStmt_(op); } - return StmtExprMutator::VisitStmt_(newop); } Stmt VisitStmt_(const BufferStoreNode* op) final { - auto itr = buffer_remap.find(op->buffer.operator->()); - const BufferStoreNode* newop; - BufferStore newop_holder; - if (itr != buffer_remap.end()) { - newop_holder = BufferStore(itr->second, op->value, op->indices); - newop = newop_holder.operator->(); + Stmt ret = StmtExprMutator::VisitStmt_(op); + op = ret.as(); + + auto it = buffer_remap_.find(op->buffer); + if (it != buffer_remap_.end()) { + return BufferStore(it->second, op->value, op->indices); } else { - newop = op; + return ret; } - return StmtExprMutator::VisitStmt_(newop); } Stmt VisitStmt_(const AttrStmtNode* op) final { - const AttrStmtNode* newop = op; - Stmt newop_holder; - if (auto buffer = op->node.as()) { - auto itr = buffer_remap.find(buffer); - if (itr != buffer_remap.end()) { - newop_holder = AttrStmt(itr->second, op->attr_key, op->value, op->body); - newop = newop_holder.as(); + Stmt ret = StmtExprMutator::VisitStmt_(op); + op = ret.as(); + + if (auto* buffer = op->node.as()) { + auto it = buffer_remap_.find(GetRef(buffer)); + if (it != buffer_remap_.end()) { + return AttrStmt(it->second, op->attr_key, op->value, op->body); } - } else if (auto buffer = op->node.as()) { - auto itr = var_remap.find(buffer); - if (itr != var_remap.end()) { - newop_holder = AttrStmt(itr->second, op->attr_key, op->value, op->body); - newop = newop_holder.as(); + } else if (auto* var = op->node.as()) { + auto it = var_remap_.find(GetRef(var)); + if (it != var_remap_.end()) { + return AttrStmt(it->second, op->attr_key, op->value, op->body); } } - return StmtExprMutator::VisitStmt_(newop); + return ret; } Stmt VisitStmt_(const BufferRealizeNode* op) final { - auto itr = buffer_remap.find(op->buffer.operator->()); - const BufferRealizeNode* newop; - Stmt newop_holder; - if (itr != buffer_remap.end()) { - auto v = BufferRealize(itr->second, op->bounds, op->condition, op->body); - newop_holder = v; - newop = v.operator->(); + Stmt ret = StmtExprMutator::VisitStmt_(op); + op = ret.as(); + + auto it = buffer_remap_.find(op->buffer); + if (it != buffer_remap_.end()) { + return BufferRealize(it->second, op->bounds, op->condition, op->body); } else { - newop = op; + return ret; + } + } + + Stmt VisitStmt_(const StoreNode* op) final { + // NOTE: we do not explicit recursivly mutate op->buffer_var + Stmt ret = StmtExprMutator::VisitStmt_(op); + op = ret.as(); + + auto it = var_remap_.find(op->buffer_var); + if (it != var_remap_.end()) { + return Store(it->second, op->value, op->index, op->predicate); + } else { + return ret; } - return StmtExprMutator::VisitStmt_(newop); } PrimExpr VisitExpr_(const BufferLoadNode* op) final { - auto itr = buffer_remap.find(op->buffer.operator->()); - const BufferLoadNode* newop; - BufferLoad newop_holder; - if (itr != buffer_remap.end()) { - newop_holder = BufferLoad(itr->second, op->indices); - newop = newop_holder.operator->(); + PrimExpr ret = StmtExprMutator::VisitExpr_(op); + op = ret.as(); + + auto it = buffer_remap_.find(op->buffer); + if (it != buffer_remap_.end()) { + return BufferLoad(it->second, op->indices); } else { - newop = op; + return ret; } - return StmtExprMutator::VisitExpr_(newop); } PrimExpr VisitExpr_(const LoadNode* op) final { - bool is_bf16 = false; + PrimExpr ret = StmtExprMutator::VisitExpr_(op); + op = ret.as(); + if (op->dtype.is_bfloat16()) { - is_bf16 = true; - } - PrimExpr index = this->VisitExpr(op->index); - PrimExpr predicate = this->VisitExpr(op->predicate); - if (index.same_as(op->index) && predicate.same_as(op->predicate) && !is_bf16) { - return GetRef(op); + auto it = var_remap_.find(op->buffer_var); + CHECK(it != var_remap_.end()) << "bfloat* var needs to be remapped"; + return Load(DataType::UInt(16, op->dtype.lanes()), it->second, op->index, op->predicate); } else { - return Load(is_bf16 ? DataType::UInt(16, op->dtype.lanes()) : op->dtype, op->buffer_var, - index, predicate); + return ret; } } @@ -320,20 +316,31 @@ class BF16LowerRewriter : StmtExprMutator { void AlterBuffers(PrimFuncNode* op) { std::vector> changes; + for (auto& itr : op->buffer_map) { auto oldbuf = itr.second; if (oldbuf->dtype.is_bfloat16()) { - auto newbuf = Buffer(oldbuf->data, DataType::UInt(16, oldbuf->dtype.lanes()), oldbuf->shape, - oldbuf->strides, oldbuf->elem_offset, oldbuf->name, oldbuf->scope, - oldbuf->data_alignment, oldbuf->offset_factor, oldbuf->buffer_type); - buffer_remap[oldbuf.operator->()] = newbuf; + DataType dtype = DataType::UInt(16, oldbuf->dtype.lanes()); + Var buffer_var = Var(oldbuf->data->name_hint, PointerType(PrimType(dtype))); + auto newbuf = Buffer(buffer_var, dtype, oldbuf->shape, oldbuf->strides, oldbuf->elem_offset, + oldbuf->name, oldbuf->scope, oldbuf->data_alignment, + oldbuf->offset_factor, oldbuf->buffer_type); + buffer_remap_[oldbuf] = newbuf; + var_remap_[oldbuf->data] = buffer_var; changes.emplace_back(itr.first, newbuf); + } else { + changes.emplace_back(itr); } } - if (buffer_remap.size() != 0) { + + if (buffer_remap_.size() != 0) { op->buffer_map = Map(changes.begin(), changes.end()); } } + + private: + std::unordered_map buffer_remap_; + std::unordered_map var_remap_; }; namespace transform { diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 8eb43f8ebc84..7475bf6d2f8e 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -200,9 +200,9 @@ class StorageFlattener : public StmtExprMutator { strides = Array(rstrides.rbegin(), rstrides.rend()); } - e.buffer = - Buffer(Var(op->buffer->data->name_hint, DataType::Handle()), op->buffer->dtype, shape, - strides, PrimExpr(), op->buffer->name, skey.to_string(), align, 0, kDefault); + e.buffer = Buffer(Var(op->buffer->data->name_hint, op->buffer->data->type_annotation), + op->buffer->dtype, shape, strides, PrimExpr(), op->buffer->name, + skey.to_string(), align, 0, kDefault); buf_map_[key] = e; Stmt body = this->VisitStmt(op->body);