diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h
index 2e49a9c5185b..8f23e9e7f531 100644
--- a/include/tvm/ir/type.h
+++ b/include/tvm/ir/type.h
@@ -312,5 +312,34 @@ class FuncType : public Type {
   TVM_DEFINE_OBJECT_REF_METHODS(FuncType, Type, FuncTypeNode);
 };
 
+/*!
+ * \brief The type of tensor map.
+ * \sa TensorMapType
+ */
+class TensorMapTypeNode : public TypeNode {
+ public:
+  void VisitAttrs(AttrVisitor* v) { v->Visit("span", &span); }
+
+  bool SEqualReduce(const TensorMapTypeNode* other, SEqualReducer equal) const {
+    return equal(span, other->span);
+  }
+
+  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(span); }
+
+  static constexpr const char* _type_key = "TensorMapType";
+  TVM_DECLARE_FINAL_OBJECT_INFO(TensorMapTypeNode, TypeNode);
+};
+
+/*!
+ * \brief Managed reference to TensorMapTypeNode.
+ * \sa TensorMapTypeNode
+ */
+class TensorMapType : public Type {
+ public:
+  TVM_DLL TensorMapType(Span span = Span());
+
+  TVM_DEFINE_OBJECT_REF_METHODS_WITHOUT_DEFAULT_CONSTRUCTOR(TensorMapType, Type, TensorMapTypeNode);
+};
+
 }  // namespace tvm
 #endif  // TVM_IR_TYPE_H_
diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h
index febdac55d9aa..30b5bb3382f4 100644
--- a/include/tvm/script/ir_builder/tir/ir.h
+++ b/include/tvm/script/ir_builder/tir/ir.h
@@ -452,6 +452,8 @@ inline Var Handle(runtime::DataType dtype = runtime::DataType::Void(),
   return is_size_var ? tvm::tir::SizeVar("", type_annotation) : tvm::tir::Var("", type_annotation);
 }
 
+inline Var TensormapHandle() { return tvm::tir::Var("", PointerType(TensorMapType())); }
+
 #define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType)                                     \
   inline PrimExpr FuncName(Optional<PrimExpr> expr = std::nullopt, bool is_size_var = false) { \
     DataType dtype = DType;                                                                    \
diff --git a/python/tvm/ir/type.py b/python/tvm/ir/type.py
index 9ec8ef8fbd02..d0bf7014e27b 100644
--- a/python/tvm/ir/type.py
+++ b/python/tvm/ir/type.py
@@ -107,3 +107,19 @@ def __init__(self, arg_types, ret_type):
             arg_types,
             ret_type,
         )
+
+
+@tvm.ffi.register_object("TensorMapType")
+class TensorMapType(Type):
+    """TensorMapType used in the low-level TIR.
+
+    Parameters
+    ----------
+    span : tvm.ir.Span
+        The span information.
+    """
+
+    def __init__(self, span=None):
+        self.__init_handle_by_constructor__(
+            _ffi_api.TensorMapType, span  # pylint: disable=no-member
+        )
diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py
index c7589f4a19a6..5864de2cac77 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -1566,6 +1566,8 @@ def handle(
     res : PrimExpr
         The new tir.Var with type handle or casted expression with type handle.
""" + if dtype == "tensormap": + return _ffi_api.TensormapHandle() # type: ignore[attr-defined] # pylint: disable=no-member is_unknown_type = dtype is None if dtype is None: dtype = "void" diff --git a/src/ir/type.cc b/src/ir/type.cc index 8bc48a11141f..732840157e22 100644 --- a/src/ir/type.cc +++ b/src/ir/type.cc @@ -82,4 +82,16 @@ TVM_FFI_REGISTER_GLOBAL("ir.TupleType").set_body_typed([](Array fields) { return TupleType(fields); }); +TVM_FFI_REGISTER_GLOBAL("ir.TensorMapType").set_body_typed([](Span span) { + return TensorMapType(span); +}); + +TensorMapType::TensorMapType(Span span) { + ObjectPtr n = make_object(); + n->span = std::move(span); + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(TensorMapTypeNode); + } // namespace tvm diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc index 399312e19321..98a83f4ed7e8 100644 --- a/src/runtime/cuda/cuda_device_api.cc +++ b/src/runtime/cuda/cuda_device_api.cc @@ -357,5 +357,211 @@ TVM_DLL int GetCudaDeviceCount() { TVM_FFI_REGISTER_GLOBAL("runtime.GetCudaDeviceCount").set_body_typed(GetCudaDeviceCount); +/** + * \brief FFI wrapper for cuTensorMapEncodeTiled. + * + * This function registers a global function `runtime.cuTensorMapEncodeTiled` that can be + * called from other parts of the TVM runtime (e.g., Python). It wraps the CUDA Driver API + * function `cuTensorMapEncodeTiled`, which initializes a tensor map descriptor (CUtensorMap). + * + * \param tensor_map (handle): A `void*` pointer to the CUtensorMap object to be initialized. + * \param tensor_dtype (DataType): The TVM data type of the tensor. + * \param tensor_rank (int): The rank (number of dimensions) of the tensor. + * \param tensor_ptr (handle): A `void*` pointer to the start of the tensor in global memory. + * \param global_shape (int...): `tensor_rank` integer arguments for the global tensor dimensions. + * \param global_strides (int...): `tensor_rank - 1` integer arguments for the global tensor + * strides. The stride for the innermost dimension is not provided as it's assumed to be contiguous. + * \param shared_shape (int...): `tensor_rank` integer arguments for the shape of the tile (box) + * in shared memory. + * \param shared_strides (int...): `tensor_rank` integer arguments for the strides of the tile (box) + * in shared memory. + * \param interleaved_kind (int): An integer corresponding to the CUtensorMapInterleave enum. + * \param swizzle_kind (int): An integer corresponding to the CUtensorMapSwizzle enum. + * \param l2_promotion_kind (int): An integer corresponding to the CUtensorMapL2promotion enum. + * \param oob_fill_kind (int): An integer corresponding to the CUtensorMapFloatOOBfill enum. 
+ */
+TVM_FFI_REGISTER_GLOBAL("runtime.cuTensorMapEncodeTiled")
+    .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) {
+      CHECK_GE(args.size(), 4) << "cuTensorMapEncodeTiled expects at least 4 arguments";
+      size_t arg_cnt = 0;
+      CUtensorMap* tensor_map = static_cast<CUtensorMap*>(args[arg_cnt++].cast<void*>());
+      runtime::DataType tensor_dtype = args[arg_cnt++].cast<runtime::DataType>();
+      uint32_t tensor_rank = static_cast<uint32_t>(args[arg_cnt++].cast<int64_t>());
+      void* tensor_ptr = args[arg_cnt++].cast<void*>();
+
+      CHECK_EQ(args.size(), 4 + tensor_rank * 4 + 3)
+          << "cuTensorMapEncodeTiled expects " << 4 + tensor_rank * 4 + 3 << " arguments: "
+          << "tensor_map, tensor_dtype, tensor_rank, tensor_ptr, global_shape(" << tensor_rank
+          << "), global_strides(" << tensor_rank - 1 << "), shared_shape(" << tensor_rank
+          << "), shared_strides(" << tensor_rank << "), interleaved_kind, swizzle_kind"
+          << ", l2_promotion_kind, oob_fill_kind";
+
+      std::vector<cuuint64_t> global_shape(tensor_rank);
+      std::vector<cuuint64_t> global_strides(tensor_rank);
+      std::vector<uint32_t> shared_shape(tensor_rank);
+      std::vector<uint32_t> shared_strides(tensor_rank);
+      for (size_t i = 0; i < tensor_rank; ++i) {
+        global_shape[i] = static_cast<cuuint64_t>(args[arg_cnt++].cast<int64_t>());
+      }
+      for (size_t i = 0; i < tensor_rank - 1; ++i) {
+        global_strides[i] = static_cast<cuuint64_t>(args[arg_cnt++].cast<int64_t>());
+        CHECK_EQ(global_strides[i] % 16, 0) << "global strides must be multiple of 16";
+      }
+      for (size_t i = 0; i < tensor_rank; ++i) {
+        shared_shape[i] = static_cast<uint32_t>(args[arg_cnt++].cast<int64_t>());
+        CHECK_GE(shared_shape[i], 0) << "boxDim must be non-negative";
+        CHECK_LE(shared_shape[i], 256) << "boxDim must be less than or equal to 256";
+      }
+      for (size_t i = 0; i < tensor_rank; ++i) {
+        shared_strides[i] = static_cast<uint32_t>(args[arg_cnt++].cast<int64_t>());
+      }
+      auto interleaved_kind = static_cast<CUtensorMapInterleave>(args[arg_cnt++].cast<int>());
+      auto swizzle_kind = static_cast<CUtensorMapSwizzle>(args[arg_cnt++].cast<int>());
+      auto l2_promotion_kind = static_cast<CUtensorMapL2promotion>(args[arg_cnt++].cast<int>());
+      auto oob_fill_kind = static_cast<CUtensorMapFloatOOBfill>(args[arg_cnt++].cast<int>());
+
+      ICHECK_EQ(tensor_dtype.lanes(), 1)
+          << "Expect tensor_dtype to have lanes=1, but get " << tensor_dtype;
+      CUtensorMapDataType cu_dtype;
+      switch (tensor_dtype.code()) {
+        case DataType::kInt:
+          // signed int (tensor maps have no INT8 type; 8-bit falls back to UINT8)
+          switch (tensor_dtype.bits()) {
+            case 8:
+              cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT8;
+              break;
+            case 32:
+              cu_dtype = CU_TENSOR_MAP_DATA_TYPE_INT32;
+              break;
+            case 64:
+              cu_dtype = CU_TENSOR_MAP_DATA_TYPE_INT64;
+              break;
+            default:
+              LOG(FATAL) << "Unsupported data type " << runtime::DLDataTypeToString(tensor_dtype);
+          }
+          break;
+        case DataType::kUInt:
+          // unsigned int
+          switch (tensor_dtype.bits()) {
+            case 8:
+              cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT8;
+              break;
+            case 16:
+              cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT16;
+              break;
+            case 32:
+              cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT32;
+              break;
+            case 64:
+              cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT64;
+              break;
+            default:
+              LOG(FATAL) << "Unsupported data type " << runtime::DLDataTypeToString(tensor_dtype);
+          }
+          break;
+        case DataType::kFloat:
+          // float
+          switch (tensor_dtype.bits()) {
+            case 16:
+              cu_dtype = CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
+              break;
+            case 32:
+              cu_dtype = CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
+              break;
+            case 64:
+              cu_dtype = CU_TENSOR_MAP_DATA_TYPE_FLOAT64;
+              break;
+            default:
+              LOG(FATAL) << "Unsupported data type " << runtime::DLDataTypeToString(tensor_dtype);
+          }
+          break;
+        case DataType::kBFloat:
+          // bfloat
+          switch (tensor_dtype.bits()) {
+            case 16:
+              cu_dtype = CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
+              break;
+            default:
+              LOG(FATAL) << "Unsupported data type "
+                         << runtime::DLDataTypeToString(tensor_dtype);
+          }
+          break;
+        case DataType::kFloat8_e4m3fn:
+          // NV float8 e4m3
+          cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT8;
+          break;
+        case DataType::kFloat8_e5m2:
+          // NV float8 e5m2
+          cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT8;
+          break;
+        default:
+          LOG(FATAL) << "Unsupported data type " << runtime::DLDataTypeToString(tensor_dtype);
+      }
+
+      // sanity checks per cuTensorMapEncodeTiled requirements
+      // see
+      // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7
+      CHECK_EQ((reinterpret_cast<uintptr_t>(tensor_ptr) & 0b1111), 0);    // 16-byte alignment
+      CHECK_EQ((reinterpret_cast<uintptr_t>(tensor_map) & 0b111111), 0);  // 64-byte alignment
+      CHECK_LE(tensor_rank, 5) << "cuTensorMapEncodeTiled only supports up to 5D tensors";
+
+      if (swizzle_kind == CU_TENSOR_MAP_SWIZZLE_32B) {
+        CHECK_LE(shared_shape[0] * tensor_dtype.bytes(), 32)
+            << "CU_TENSOR_MAP_SWIZZLE_32B implies the bounding box inner dimension will be <= 32.";
+      } else if (swizzle_kind == CU_TENSOR_MAP_SWIZZLE_64B) {
+        CHECK_LE(shared_shape[0] * tensor_dtype.bytes(), 64)
+            << "CU_TENSOR_MAP_SWIZZLE_64B implies the bounding box inner dimension will be <= 64.";
+      } else if (swizzle_kind == CU_TENSOR_MAP_SWIZZLE_128B) {
+        CHECK_LE(shared_shape[0] * tensor_dtype.bytes(), 128)
+            << "CU_TENSOR_MAP_SWIZZLE_128B implies the bounding box inner dimension will be <= "
+               "128.";
+      }
+
+      const cuuint64_t* global_shape_ptr = global_shape.data();
+      const cuuint64_t* global_strides_ptr = global_strides.data();
+      const uint32_t* shared_shape_ptr = shared_shape.data();
+      const uint32_t* shared_strides_ptr = shared_strides.data();
+
+      CUresult res =
+          cuTensorMapEncodeTiled(tensor_map, cu_dtype, tensor_rank, tensor_ptr, global_shape_ptr,
+                                 global_strides_ptr, shared_shape_ptr, shared_strides_ptr,
+                                 interleaved_kind, swizzle_kind, l2_promotion_kind, oob_fill_kind);
+      if (res != CUDA_SUCCESS) {
+        const char* error_string = nullptr;
+        cuGetErrorString(res, &error_string);
+        std::cerr << "Error in cuTensorMapEncodeTiled: " << error_string << std::endl;
+        std::cerr << "cu_dtype: " << cu_dtype << "\n";
+        std::cerr << "TMA Desc Addr: " << tensor_map << "\n";
+        std::cerr << "TMA Interleave: " << interleaved_kind << "\n";
+        std::cerr << "TMA L2Promotion: " << l2_promotion_kind << "\n";
+        std::cerr << "TMA OOBFill: " << oob_fill_kind << "\n";
+        std::cerr << "SMEM Swizzle: " << swizzle_kind << "\n";
+        std::cerr << "tensor rank: " << tensor_rank << "\n";
+        std::cerr << "global prob shape: ";
+        for (size_t i = 0; i < tensor_rank; i++) {
+          std::cerr << global_shape[i] << " ";
+        }
+        std::cerr << "\n";
+        std::cerr << "global prob stride: ";
+        for (size_t i = 0; i < tensor_rank; i++) {
+          std::cerr << global_strides[i] << " ";
+        }
+        std::cerr << "\n";
+        std::cerr << "smem box shape: ";
+        for (size_t i = 0; i < tensor_rank; i++) {
+          std::cerr << shared_shape[i] << " ";
+        }
+        std::cerr << "\n";
+        std::cerr << "smem box stride: ";
+        for (size_t i = 0; i < tensor_rank; i++) {
+          std::cerr << shared_strides[i] << " ";
+        }
+        std::cerr << "\n";
+        LOG(FATAL) << "Error in cuTensorMapEncodeTiled: " << error_string;
+      }
+    });
+
 }  // namespace runtime
 }  // namespace tvm
diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc
index a29d303acf7f..6d69fde5cdba 100644
--- a/src/runtime/cuda/cuda_module.cc
+++ b/src/runtime/cuda/cuda_module.cc
@@ -266,7 +266,7 @@ ffi::Function CUDAModuleNode::GetFunction(const String& name,
   const FunctionInfo& info = it->second;
   CUDAWrappedFunc f;
   f.Init(this, sptr_to_self, name, info.arg_types.size(), info.launch_param_tags);
-  return PackFuncVoidAddr(f, info.arg_types);
+  return PackFuncVoidAddr(f, info.arg_types, info.arg_extra_tags);
 }
 
 Module CUDAModuleCreate(std::string data, std::string fmt,
diff --git a/src/runtime/file_utils.cc b/src/runtime/file_utils.cc
index 2aa377b9f8bd..513efbd9fbed 100644
--- a/src/runtime/file_utils.cc
+++ b/src/runtime/file_utils.cc
@@ -45,6 +45,11 @@ void FunctionInfo::Save(dmlc::JSONWriter* writer) const {
   writer->WriteObjectKeyValue("name", name);
   writer->WriteObjectKeyValue("arg_types", sarg_types);
   writer->WriteObjectKeyValue("launch_param_tags", launch_param_tags);
+  std::vector<int> iarg_extra_tags(arg_extra_tags.size());
+  for (size_t i = 0; i < arg_extra_tags.size(); ++i) {
+    iarg_extra_tags[i] = static_cast<int>(arg_extra_tags[i]);
+  }
+  writer->WriteObjectKeyValue("arg_extra_tags", iarg_extra_tags);
   writer->EndObject();
 }
 
@@ -56,6 +61,12 @@ void FunctionInfo::Load(dmlc::JSONReader* reader) {
   helper.DeclareOptionalField("launch_param_tags", &launch_param_tags);
   helper.DeclareOptionalField("thread_axis_tags",
                               &launch_param_tags);  // for backward compatibility
+  std::vector<int> iarg_extra_tags;
+  helper.DeclareOptionalField("arg_extra_tags", &iarg_extra_tags);
   helper.ReadAllFields(reader);
+  arg_extra_tags.resize(iarg_extra_tags.size());
+  for (size_t i = 0; i < iarg_extra_tags.size(); ++i) {
+    arg_extra_tags[i] = static_cast<ArgExtraTags>(iarg_extra_tags[i]);
+  }
   arg_types.resize(sarg_types.size());
   for (size_t i = 0; i < arg_types.size(); ++i) {
@@ -67,12 +78,14 @@ void FunctionInfo::Save(dmlc::Stream* writer) const {
   writer->Write(name);
   writer->Write(arg_types);
   writer->Write(launch_param_tags);
+  writer->Write(arg_extra_tags);
 }
 
 bool FunctionInfo::Load(dmlc::Stream* reader) {
   if (!reader->Read(&name)) return false;
  if (!reader->Read(&arg_types)) return false;
   if (!reader->Read(&launch_param_tags)) return false;
+  if (!reader->Read(&arg_extra_tags)) return false;
   return true;
 }
diff --git a/src/runtime/meta_data.h b/src/runtime/meta_data.h
index 51120c1f9efb..8acefecaad8a 100644
--- a/src/runtime/meta_data.h
+++ b/src/runtime/meta_data.h
@@ -59,6 +59,9 @@ struct FunctionInfo {
   std::vector<DLDataType> arg_types;
   std::vector<std::string> launch_param_tags;
 
+  enum class ArgExtraTags : int { kNone = 0, kTensorMap = 1 };
+  std::vector<ArgExtraTags> arg_extra_tags;
+
   void Save(dmlc::JSONWriter* writer) const;
   void Load(dmlc::JSONReader* reader);
   void Save(dmlc::Stream* writer) const;
diff --git a/src/runtime/pack_args.h b/src/runtime/pack_args.h
index 0068db51d522..8929f90b0f09 100644
--- a/src/runtime/pack_args.h
+++ b/src/runtime/pack_args.h
@@ -64,12 +64,15 @@ union ArgUnion64 {
  *
  * \param f with signiture (ffi::PackedArgs args, ffi::Any* rv, void* void_args)
  * \param arg_types The arguments type information.
+ * \param arg_extra_tags Extra tags for the arguments (e.g. tensor map).
  * \tparam F the function type
  *
  * \return The wrapped packed function.
  */
 template <typename F>
-inline ffi::Function PackFuncVoidAddr(F f, const std::vector<DLDataType>& arg_types);
+inline ffi::Function PackFuncVoidAddr(
+    F f, const std::vector<DLDataType>& arg_types,
+    const std::vector<FunctionInfo::ArgExtraTags>& arg_extra_tags = {});
 
 /*!
  * \brief Create a packed function that from function only packs buffer arguments.
  *
@@ -130,7 +133,8 @@ enum ArgConvertCode {
   INT64_TO_UINT32,
   FLOAT64_TO_FLOAT32,
   FLOAT64_TO_FLOAT64,
-  HANDLE_TO_HANDLE
+  HANDLE_TO_HANDLE,
+  HANDLE_TO_TENSORMAP
 };
 
 inline ArgConvertCode GetArgConvertCode(DLDataType t) {
@@ -183,6 +187,10 @@ inline ffi::Function PackFuncVoidAddr_(F f, const std::vector<ArgConvertCode>& c
         addr[i] = &(holder[i]);
         break;
       }
+      case HANDLE_TO_TENSORMAP: {
+        addr[i] = raw_args[i].v_ptr;
+        break;
+      }
     }
   }
   f(args, ret, addr);
@@ -222,7 +230,8 @@ inline ffi::Function PackFuncNonBufferArg_(F f, int base,
        holder[i].v_float32[0] = static_cast<float>(raw_args[base + i].v_float64);
        break;
      }
-     case HANDLE_TO_HANDLE: {
+     case HANDLE_TO_HANDLE:
+     case HANDLE_TO_TENSORMAP: {
        LOG(FATAL) << "not reached";
        break;
      }
@@ -284,6 +293,7 @@ inline ffi::Function PackFuncPackedArgAligned_(F f, const std::vector<ArgConvertCode>&
 template <typename F>
-inline ffi::Function PackFuncVoidAddr(F f, const std::vector<DLDataType>& arg_types) {
+inline ffi::Function PackFuncVoidAddr(
+    F f, const std::vector<DLDataType>& arg_types,
+    const std::vector<FunctionInfo::ArgExtraTags>& arg_extra_tags) {
   std::vector<ArgConvertCode> codes(arg_types.size());
   for (size_t i = 0; i < arg_types.size(); ++i) {
-    codes[i] = detail::GetArgConvertCode(arg_types[i]);
+    if (arg_extra_tags.size() > i && arg_extra_tags[i] == FunctionInfo::ArgExtraTags::kTensorMap) {
+      codes[i] = detail::HANDLE_TO_TENSORMAP;
+    } else {
+      codes[i] = detail::GetArgConvertCode(arg_types[i]);
+    }
   }
   size_t num_void_args = arg_types.size();
   // specialization
diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc
index 79e14feee47d..7ef970fa0971 100644
--- a/src/script/ir_builder/tir/ir.cc
+++ b/src/script/ir_builder/tir/ir.cc
@@ -830,6 +830,7 @@ TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float4E2M1FN", Float4E2M1FN
 TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Boolean").set_body_typed(Boolean);
 TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Handle").set_body_typed(Handle);
+TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.TensormapHandle").set_body_typed(TensormapHandle);
 TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Void").set_body_typed(Void);
 
 TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.min")
diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tir/expr.cc
index 549247449e33..d2f02f7908b9 100644
--- a/src/script/printer/tir/expr.cc
+++ b/src/script/printer/tir/expr.cc
@@ -37,19 +37,23 @@ ExprDoc PrintVarCreation(const tir::Var& var, const ObjectPath& var_p, const IRD
   }
 
   if (const auto* ptr_type = type.as<PointerTypeNode>()) {
-    const auto* prim_type = ptr_type->element_type.as<PrimTypeNode>();
-    ICHECK(prim_type);
-    ExprDoc element_type =
-        LiteralDoc::DataType(prim_type->dtype, type_p->Attr("element_type")->Attr("dtype"));
-    rhs = TIR(d, "handle");
-    rhs->source_paths.push_back(var_p->Attr("dtype"));
-    if (ptr_type->storage_scope == "") {
-      rhs = rhs->Call({element_type}, kwargs_keys, kwargs_values);
-    } else {
-      rhs = rhs->Call({element_type,
-                       LiteralDoc::Str(ptr_type->storage_scope,  //
-                                       type_p->Attr("storage_scope"))},
-                      kwargs_keys, kwargs_values);
+    if (const auto* prim_type = ptr_type->element_type.as<PrimTypeNode>()) {
+      ExprDoc element_type =
+          LiteralDoc::DataType(prim_type->dtype, type_p->Attr("element_type")->Attr("dtype"));
+      rhs = TIR(d, "handle");
+      rhs->source_paths.push_back(var_p->Attr("dtype"));
+      if (ptr_type->storage_scope == "") {
+        rhs = rhs->Call({element_type}, kwargs_keys, kwargs_values);
+      } else {
+        rhs = rhs->Call({element_type,
+                         LiteralDoc::Str(ptr_type->storage_scope,  //
+                                         type_p->Attr("storage_scope"))},
+                        kwargs_keys, kwargs_values);
+      }
+    } else if (ptr_type->element_type->IsInstance<TensorMapTypeNode>()) {
+      rhs = TIR(d, "handle")
+                ->Call({LiteralDoc::Str("tensormap", type_p->Attr("element_type")->Attr("dtype"))},
+                       {}, {});
     }
   } else {
     rhs = TIR(d, DType2Str(var->dtype));
   }
diff --git a/src/target/build_common.h b/src/target/build_common.h
index 70f15d091ed2..fda7e2e67c0f 100644
--- a/src/target/build_common.h
+++ b/src/target/build_common.h
@@ -49,6 +49,16 @@ inline std::unordered_map<std::string, runtime::FunctionInfo> ExtractFuncInfo(co
     runtime::FunctionInfo info;
     for (size_t i = 0; i < f->params.size(); ++i) {
       info.arg_types.push_back(f->params[i].dtype());
+      auto is_tensormap = [](const tir::Var& var) -> bool {
+        const auto* type = var->type_annotation.as<PointerTypeNode>();
+        if (type == nullptr) {
+          return false;
+        }
+        return type->element_type.as<TensorMapTypeNode>() != nullptr;
+      };
+      info.arg_extra_tags.push_back(is_tensormap(f->params[i])
+                                        ? runtime::FunctionInfo::ArgExtraTags::kTensorMap
+                                        : runtime::FunctionInfo::ArgExtraTags::kNone);
     }
     if (auto opt = f->GetAttr<Array<String>>(tir::attr::kKernelLaunchParams)) {
       for (const auto& tag : opt.value()) {
diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc
index 76e825ab7546..b16617e3d6bc 100644
--- a/src/target/llvm/codegen_cpu.cc
+++ b/src/target/llvm/codegen_cpu.cc
@@ -1037,6 +1037,10 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) {
       return builder_->CreateAlloca(t_tvm_ffi_any_, num);
     } else if (type == "array") {
      return builder_->CreateAlloca(t_tvm_array_, num);
+    } else if (type == "tensormap") {
+      auto* alloca = builder_->CreateAlloca(t_tvm_tensormap_, num);
+      alloca->setAlignment(llvm::Align(64));
+      return alloca;
     } else {
       LOG(FATAL) << "Unknown stack alloca type " << type;
     }
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index e9bcfa97fd01..45dafa85b939 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -154,6 +154,8 @@ void CodeGenLLVM::Init(const std::string& module_name, LLVMTarget* llvm_target,
   t_int32_ = llvm::Type::getInt32Ty(*ctx);
   t_int64_ = llvm::Type::getInt64Ty(*ctx);
   t_float64_ = llvm::Type::getDoubleTy(*ctx);
+  // CUtensorMap is a 128-byte struct, so we use a 128-byte array to represent it.
+  t_tvm_tensormap_ = llvm::ArrayType::get(t_char_, 128);
   // meta data
   md_very_likely_branch_ = md_builder_->createBranchWeights(1 << 20, 1);
   md_tbaa_root_ = md_builder_->createTBAARoot("tvm-tbaa");
@@ -620,11 +622,15 @@ llvm::Type* CodeGenLLVM::GetLLVMType(const Type& type) const {
       if (primtype->dtype.is_void() || primtype->dtype.code() >= DataType::kCustomBegin) {
         return t_void_p_;
       }
+    } else if (ptr->element_type->IsInstance<TensorMapTypeNode>()) {
+      return t_tvm_tensormap_->getPointerTo();
     }
     // TODO(tvm-team) consider put storage scope into the pointer type.
     return llvmGetPointerTo(GetLLVMType(ptr->element_type), GetGlobalAddressSpace());
   } else if (IsVoidType(type)) {
     return t_void_;
+  } else if (type->IsInstance<TensorMapTypeNode>()) {
+    return t_tvm_tensormap_;
   } else {
     LOG(FATAL) << "Type " << type << " does not have a corresponding LLVM Type";
   }
@@ -2292,7 +2298,7 @@ llvm::DIType* CodeGenLLVM::GetDebugType(const Type& ty_tir) {
   return GetDebugType(ty_tir, GetLLVMType(ty_tir));
 }
 llvm::DIType* CodeGenLLVM::GetDebugType(const Type& ty_tir, llvm::Type* ty_llvm) {
-  if (ty_llvm == nullptr || ty_llvm == t_void_) {
+  if (ty_llvm == nullptr || ty_llvm == t_void_ || ty_llvm == t_tvm_tensormap_) {
     return nullptr;
   } else if (ty_llvm->isPointerTy()) {
diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h
index f7e4e819030e..e1667b637578 100644
--- a/src/target/llvm/codegen_llvm.h
+++ b/src/target/llvm/codegen_llvm.h
@@ -540,6 +540,7 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
   llvm::Type* t_int32_{nullptr};
   llvm::Type* t_int64_{nullptr};
   llvm::Type* t_float64_{nullptr};
+  llvm::ArrayType* t_tvm_tensormap_{nullptr};
   // meta data
   llvm::MDNode* md_very_likely_branch_{nullptr};
   llvm::MDNode* md_tbaa_root_{nullptr};
diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc
index 344c0857c4d4..11f0eaf1ba7b 100644
--- a/src/target/source/codegen_c.cc
+++ b/src/target/source/codegen_c.cc
@@ -93,10 +93,24 @@ void CodeGenC::PrintFunctionSignature(const String& function_name, const PrimFun
       PrintStorageScope(it->second, os);
     }
 
-    PrintType(GetType(v), os);
+    auto is_tensormap_ptr = [&]() -> bool {
+      if (auto* ptr = v->type_annotation.as<PointerTypeNode>()) {
+        return ptr->element_type.as<TensorMapTypeNode>() != nullptr;
+      }
+      return false;
+    };
+    if (is_tensormap_ptr()) {
+      os << "const __grid_constant__ CUtensorMap";
+    } else {
+      PrintType(GetType(v), os);
+    }
 
     bool no_alias = func->HasNonzeroAttr(tir::attr::kNoAlias);
     bool is_handle = v.dtype().is_handle();
+    auto* ptr = v->type_annotation.as<PointerTypeNode>();
+    if (ptr && ptr->element_type.as<TensorMapTypeNode>()) {
+      is_handle = false;
+    }
     if (no_alias && is_handle) {
       PrintRestrict(v, os);
     }
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index d4e1b785b866..21fbc20f473d 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -182,6 +182,8 @@ void CodeGenCUDA::PrintExtraAttrs(const PrimFunc& f, std::ostream& os) {
 }
 
 std::string CodeGenCUDA::Finish() {
+  decl_stream << "#include <cuda.h>\n";
+
   if (enable_fp16_) {
     decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)\n";
     decl_stream << "#include <cuda_fp16.h>\n";
@@ -194,7 +196,6 @@ std::string CodeGenCUDA::Finish() {
     decl_stream << _cuda_half_t_def;
     decl_stream << "#endif\n\n";
-    decl_stream << "#include <cuda.h>\n";
     decl_stream << _cuda_half_util;
   }
diff --git a/tests/python/codegen/test_target_codegen_cuda.py b/tests/python/codegen/test_target_codegen_cuda.py
index e96217034f4b..063ed0469b50 100644
--- a/tests/python/codegen/test_target_codegen_cuda.py
+++ b/tests/python/codegen/test_target_codegen_cuda.py
@@ -746,5 +746,36 @@ def func(A: T.Buffer((4,), "uint32"), B: T.Buffer((4,), "uint8")) -> None:
     tvm.compile(func, target="cuda")
 
 
+@tvm.testing.requires_cuda
+@tvm.testing.requires_cuda_compute_version(9)
+def test_cuda_tensormap():
+    # fmt: off
+    @T.prim_func
+    def main(A_ptr: T.handle):
+        A = T.match_buffer(A_ptr, (16, 16), dtype="float32", align=16)
+
+        A_map: T.handle("tensormap") = T.tvm_stack_alloca("tensormap", 1)
+        T.call_packed("runtime.cuTensorMapInit", A_map, "float32", 2, A.data,
+                      16, 16, 64, 16, 16, 1, 1, 0, 0, 0, 0)
+
+        for blockIdx in T.thread_binding(1, thread="blockIdx.x"):
+            for threadIdx in T.thread_binding(128, thread="threadIdx.x"):
+                if threadIdx == 0:
+                    A[0, 0] = T.reinterpret("float64", A_map)
+    # fmt: on
+
+    mod = tvm.IRModule({"main": main})
+    mod = tvm.compile(mod, target="cuda")
+    assert (
+        """
+extern "C" __global__ void __launch_bounds__(128) main_kernel(float* __restrict__ A, const __grid_constant__ CUtensorMap A_map) {
+  if (((int)threadIdx.x) == 0) {
+    A[0] = ((float)(*(double *)(&(A_map))));
+  }
+}""".strip()
+        in mod.mod.imported_modules[0].get_source()
+    )
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py b/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py
index f620610f3977..1858c00e8662 100644
--- a/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py
+++ b/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py
@@ -239,7 +239,8 @@ def test_inject_async_copy_barrier():
    tvm.testing.assert_allclose(B_nd.numpy(), A_np)
 
 
-expected_cuda_script = r"""__forceinline__ __device__ unsigned int
+expected_cuda_script = r"""#include <cuda.h>
+__forceinline__ __device__ unsigned int
 cast_smem_ptr_to_int(const void* const smem_ptr)
 {
   unsigned int smem_int;
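For reviewers, the sketch below shows how the pieces above fit together end to end. It is modeled directly on `test_cuda_tensormap` in this diff and only adds explanatory comments: the `T.handle("tensormap")` annotation, the packed-function name `runtime.cuTensorMapInit`, and the literal shape/stride arguments are taken from that test, while the surrounding driver code is an illustrative assumption rather than part of the patch. Running it end to end requires a CUDA device of compute capability 9.0 or newer.

```python
# Commented restatement of test_cuda_tensormap; assumptions are noted inline.
import tvm
from tvm.script import tir as T


@T.prim_func
def main(a: T.handle):
    A = T.match_buffer(a, (16, 16), dtype="float32", align=16)

    # Host side: reserve stack space for the 128-byte CUtensorMap descriptor.
    # The new "tensormap" case in CodeGenCPU::CreateIntrinsic gives it 64-byte alignment.
    A_map: T.handle("tensormap") = T.tvm_stack_alloca("tensormap", 1)
    # Fill the descriptor through a packed call. The argument layout mirrors the
    # runtime.cuTensorMapEncodeTiled doc comment: dtype, rank, data pointer,
    # global shape (16, 16), outer global stride in bytes (64 = 16 * 4),
    # shared-memory box shape (16, 16), box strides (1, 1), then the
    # interleave / swizzle / L2 promotion / OOB-fill enums (all 0, i.e. "none").
    T.call_packed("runtime.cuTensorMapInit", A_map, "float32", 2, A.data,
                  16, 16, 64, 16, 16, 1, 1, 0, 0, 0, 0)

    # Device side: A_map has type PointerType(TensorMapType()), so codegen_c.cc
    # emits it as a `const __grid_constant__ CUtensorMap` kernel parameter.
    for bx in T.thread_binding(1, thread="blockIdx.x"):
        for tx in T.thread_binding(128, thread="threadIdx.x"):
            if tx == 0:
                A[0, 0] = T.reinterpret("float64", A_map)


# Compiling for CUDA shows the CUtensorMap parameter in the generated kernel source.
mod = tvm.compile(tvm.IRModule({"main": main}), target="cuda")
print(mod.mod.imported_modules[0].get_source())
```

The printed source should contain the same `const __grid_constant__ CUtensorMap A_map` parameter that the test above asserts on.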