diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h index 3d35d5241e11..76e5e3833f17 100644 --- a/include/tvm/runtime/data_type.h +++ b/include/tvm/runtime/data_type.h @@ -58,7 +58,7 @@ class DataType { kBFloat = kDLBfloat, kE4M3Float = 6U, kE5M2Float = 7U, - kE2M1Float = 8U, + kFloat4E2M1Fn = 8U, kCustomBegin = 129 }; /*! \brief default constructor */ @@ -88,7 +88,7 @@ class DataType { if (code == kE4M3Float || code == kE5M2Float) { ICHECK_EQ(bits, 8); } - if (code == kE2M1Float) { + if (code == kFloat4E2M1Fn) { ICHECK_EQ(bits, 4); } } @@ -131,12 +131,10 @@ class DataType { bits() == 8; } /*! \return whether type is a float4 type. */ - bool is_float4() const { return code() == DataType::kE2M1Float && bits() == 4; } + bool is_float4() const { return code() == DataType::kFloat4E2M1Fn && bits() == 4; } bool is_e4m3_float8() const { return (code() == DataType::kE4M3Float && bits() == 8); } - bool is_e5m2_float8() const { return (code() == DataType::kE5M2Float && bits() == 8); } - - bool is_e2m1_float4() const { return (code() == DataType::kE2M1Float && bits() == 4); } + bool is_float4_e2m1fn() const { return (code() == DataType::kFloat4E2M1Fn && bits() == 4); } /*! \return whether type is a float16 type. */ bool is_float16() const { return is_float() && bits() == 16; } /*! \return whether type is a bfloat16 type. */ @@ -262,11 +260,11 @@ class DataType { */ static DataType NVFloat8E5M2(int lanes = 1) { return DataType(kE5M2Float, 8, lanes); } /*! - * \brief Construct NV float4 e2m1 datatype. + * \brief Construct NV float4_e2m1fn datatype. * \param lanes The number of lanes * \return The constructed data type. */ - static DataType NVFloat4E2M1(int lanes = 1) { return DataType(kE2M1Float, 4, lanes); } + static DataType NVFloat4E2M1FN(int lanes = 1) { return DataType(kFloat4E2M1Fn, 4, lanes); } /*! * \brief Construct a bool type. * \param lanes The number of lanes. 
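The hunks above and below rename the FP4 type code to kFloat4E2M1Fn and make "float4_e2m1fn" its canonical string spelling. A minimal sketch of how that string is expected to round-trip through the Python DataType wrapper once this patch is in place (the 4-lane spelling is just one of the registered vector forms):

```python
from tvm._ffi.runtime_ctypes import DataType, DataTypeCode

scalar = DataType("float4_e2m1fn")
assert scalar.type_code == DataTypeCode.FLOAT4E2M1FN
assert scalar.bits == 4 and scalar.lanes == 1

# Vector forms append the lane count after an "x", as for every other dtype.
vec = DataType("float4_e2m1fnx4")
assert vec.lanes == 4

# __repr__ omits the redundant bit width for this type code, so the strings round-trip.
assert str(scalar) == "float4_e2m1fn" and str(vec) == "float4_e2m1fnx4"
```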
@@ -313,7 +311,7 @@ inline int GetVectorBytes(DataType dtype) { int data_bits = dtype.bits() * dtype.lanes(); // allow bool to exist if (dtype == DataType::Bool() || dtype == DataType::Int(4) || dtype == DataType::UInt(4) || - dtype == DataType::Int(1) || dtype == DataType::NVFloat4E2M1()) { + dtype == DataType::Int(1) || dtype == DataType::NVFloat4E2M1FN()) { return 1; } ICHECK_EQ(data_bits % 8, 0U) << "Need to load/store by multiple of bytes"; @@ -399,8 +397,8 @@ inline const char* DLDataTypeCode2Str(DLDataTypeCode type_code) { return "e4m3_float"; case DataType::kE5M2Float: return "e5m2_float"; - case DataType::kE2M1Float: - return "e2m1_float"; + case DataType::kFloat4E2M1Fn: + return "float4_e2m1fn"; default: LOG(FATAL) << "unknown type_code=" << static_cast<int>(type_code); } @@ -458,6 +456,18 @@ inline DLDataType String2DLDataType(std::string s) { } else if (s.substr(0, 4) == "uint") { t.code = kDLUInt; scan = s.c_str() + 4; + } else if (s.substr(0, 13) == "float4_e2m1fn") { + // Avoid being treated as "float" + t.code = DataType::kFloat4E2M1Fn; + t.bits = 4; + scan = s.c_str() + 13; + char* endpt = nullptr; + if (*scan == 'x') { + t.lanes = static_cast<uint16_t>(strtoul(scan + 1, &endpt, 10)); + scan = endpt; + } + ICHECK(scan == s.c_str() + s.length()) << "unknown type " << s; + return t; } else if (s.substr(0, 5) == "float") { t.code = kDLFloat; scan = s.c_str() + 5; @@ -482,10 +492,6 @@ inline DLDataType String2DLDataType(std::string s) { t.code = DataType::kE5M2Float; t.bits = 8; scan = s.c_str() + 10; - } else if (s.substr(0, 10) == "e2m1_float") { - t.code = DataType::kE2M1Float; - t.bits = 4; - scan = s.c_str() + 10; } else if (s.substr(0, 6) == "custom") { t.code = ParseCustomDatatype(s, &scan); } else { diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h index 5dd1a5c733c5..e78e0d51fdc1 100644 --- a/include/tvm/script/ir_builder/tir/ir.h +++ b/include/tvm/script/ir_builder/tir/ir.h @@ -505,7 +505,7 @@ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Int, DataType::Int); TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(E4M3Float8, DataType::NVFloat8E4M3); TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(E5M2Float8, DataType::NVFloat8E5M2); -TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(E2M1Float4, DataType::NVFloat4E2M1); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float4E2M1fn, DataType::NVFloat4E2M1FN); TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Boolean, DataType::Bool()); TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Void, DataType::Void()); diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index 263a4ff69fa6..3f4ceadd1d20 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -68,7 +68,7 @@ class DataTypeCode(object): BFLOAT = 4 E4M3Float = 6 E5M2Float = 7 - E2M1Float = 8 + FLOAT4E2M1FN = 8 class DataType(ctypes.Structure): @@ -83,7 +83,7 @@ class DataType(ctypes.Structure): DataTypeCode.BFLOAT: "bfloat", DataTypeCode.E4M3Float: "e4m3_float", DataTypeCode.E5M2Float: "e5m2_float", - DataTypeCode.E2M1Float: "e2m1_float", + DataTypeCode.FLOAT4E2M1FN: "float4_e2m1fn", } NUMPY2STR = { np.dtype(np.bool_): "bool", @@ -114,7 +114,7 @@ class DataType(ctypes.Structure): "uint64": {"type_code": DataTypeCode.UINT, "bits": 64, "lanes": 1}, "e4m3_float8": {"type_code": DataTypeCode.E4M3Float, "bits": 8, "lanes": 1}, "e5m2_float8": {"type_code": DataTypeCode.E5M2Float, "bits": 8, "lanes": 1}, - "e2m1_float4": {"type_code": DataTypeCode.E2M1Float, "bits": 4, "lanes": 1}, + "float4_e2m1fn": 
{"type_code": DataTypeCode.FLOAT4E2M1FN, "bits": 4, "lanes": 1}, "float16": {"type_code": DataTypeCode.FLOAT, "bits": 16, "lanes": 1}, "float32": {"type_code": DataTypeCode.FLOAT, "bits": 32, "lanes": 1}, "float64": {"type_code": DataTypeCode.FLOAT, "bits": 64, "lanes": 1}, @@ -155,6 +155,11 @@ def __init__(self, type_str): elif head.startswith("uint"): self.type_code = DataTypeCode.UINT head = head[4:] + elif head.startswith("float4_e2m1fn"): + # Avoid being treated as "float" + self.type_code = DataTypeCode.FLOAT4E2M1FN + bits = 4 + head = "" elif head.startswith("float"): self.type_code = DataTypeCode.FLOAT head = head[5:] @@ -171,9 +176,6 @@ def __init__(self, type_str): elif head.startswith("e5m2_float"): self.type_code = DataTypeCode.E5M2Float head = head[10:] - elif head.startswith("e2m1_float"): - self.type_code = DataTypeCode.E2M1Float - head = head[10:] elif head.startswith("custom"): # pylint: disable=import-outside-toplevel import tvm.runtime._ffi_api @@ -201,7 +203,12 @@ def __repr__(self): import tvm.runtime._ffi_api type_name = "custom[%s]" % tvm.runtime._ffi_api._datatype_get_type_name(self.type_code) - x = "%s%d" % (type_name, self.bits) + if self.type_code in [ + DataTypeCode.FLOAT4E2M1FN, + ]: + x = type_name + else: + x = "%s%d" % (type_name, self.bits) lanes_as_int = ctypes.c_int16(self.lanes).value if lanes_as_int > 1: x += "x%d" % self.lanes @@ -238,7 +245,7 @@ def itemsize(self): DataType.NUMPY2STR[np.dtype(ml_dtypes.bfloat16)] = "bfloat16" DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e4m3fn)] = "e4m3_float8" DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e5m2)] = "e5m2_float8" - DataType.NUMPY2STR[np.dtype(ml_dtypes.float4_e2m1fn)] = "e2m1_float4" + DataType.NUMPY2STR[np.dtype(ml_dtypes.float4_e2m1fn)] = "float4_e2m1fn" RPC_SESS_MASK = 128 diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py index 3514ee6168d1..47fcccf52bac 100644 --- a/python/tvm/runtime/ndarray.py +++ b/python/tvm/runtime/ndarray.py @@ -197,7 +197,9 @@ def copyfrom(self, source_array): source_array = np.ascontiguousarray( source_array, dtype="uint16" if dtype == "bfloat16" else dtype ) - if dtype.startswith("e2m1_float4"): + if self.dtype.startswith("float4_e2m1fn") and self.dtype != "float4_e2m1fn": + # float4_e2m1fn in numpy is not packed. + # So we need to pack the input data when converting to vectorized float4_e2m1fn type. data_bits = source_array.view(dtype="uint8") if data_bits.size % 2: data_bits = np.pad(data_bits, (0, 1), mode="constant", constant_values=0) @@ -261,22 +263,24 @@ def numpy(self): raise RuntimeError( "ml_dtypes is not installed, cannot convert e5m2_float8 array to numpy." ) - if dtype == "e2m1_float4": + if dtype == "float4_e2m1fn": if ml_dtypes is not None: dtype = ml_dtypes.float4_e2m1fn else: raise RuntimeError( - "ml_dtypes is not installed, cannot convert e2m1_float4 array to numpy." + "ml_dtypes is not installed, cannot convert float4_e2m1fn array to numpy." 
) np_arr = np.empty(shape, dtype=dtype) assert np_arr.flags["C_CONTIGUOUS"] data = np_arr.ctypes.data_as(ctypes.c_void_p) - if old_dtype.startswith("e2m1_float4"): + if old_dtype.startswith("float4_e2m1fn") and old_dtype != "float4_e2m1fn": nbytes = ctypes.c_size_t(np_arr.size * np_arr.dtype.itemsize // 2) else: nbytes = ctypes.c_size_t(np_arr.size * np_arr.dtype.itemsize) check_call(_LIB.TVMArrayCopyToBytes(self.handle, data, nbytes)) - if old_dtype == "int4" or old_dtype.startswith("e2m1_float4"): + if old_dtype == "int4" or ( + old_dtype.startswith("float4_e2m1fn") and old_dtype != "float4_e2m1fn" + ): length = np_arr.size np_arr = np_arr.view("int8") np_arr_ret = np.empty((length,), dtype="int8") diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 6cc19305e49f..c35df7a093ef 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -18,8 +18,8 @@ import functools import inspect -from numbers import Integral import sys +from numbers import Integral from typing import Any, Callable, Dict, List, Optional, Tuple, Union # isort: off @@ -29,8 +29,7 @@ import numpy as np # type: ignore -from tvm import tir -from tvm import ir +from tvm import ir, tir from tvm.ir import Type from tvm.ir.base import deprecated from tvm.runtime import String, convert, ndarray @@ -1457,12 +1456,13 @@ def func( e5m2_float8x32 = func_gen(("E5M2Float8x32")) e5m2_float8x64 = func_gen(("E5M2Float8x64")) -e2m1_float4 = func_gen(("E2M1Float4")) -e2m1_float4x4 = func_gen(("E2M1Float4x4")) -e2m1_float4x8 = func_gen(("E2M1Float4x8")) -e2m1_float4x16 = func_gen(("E2M1Float4x16")) -e2m1_float4x32 = func_gen(("E2M1Float4x32")) -e2m1_float4x64 = func_gen(("E2M1Float4x64")) +float4_e2m1fn = func_gen(("Float4E2M1fn")) +float4_e2m1fnx2 = func_gen(("Float4E2M1fnx2")) +float4_e2m1fnx4 = func_gen(("Float4E2M1fnx4")) +float4_e2m1fnx8 = func_gen(("Float4E2M1fnx8")) +float4_e2m1fnx16 = func_gen(("Float4E2M1fnx16")) +float4_e2m1fnx32 = func_gen(("Float4E2M1fnx32")) +float4_e2m1fnx64 = func_gen(("Float4E2M1fnx64")) # pylint: enable=invalid-name @@ -2013,37 +2013,38 @@ def wrapped(*args, **kwargs): "uint64x64", "e4m3_float8", "e5m2_float8", - "e2m1_float4", + "float4_e2m1fn", "float16", "float32", "float64", + "float4_e2m1fnx2", "e4m3_float8x4", "e5m2_float8x4", - "e2m1_float4x4", + "float4_e2m1fnx4", "float16x4", "float32x4", "float64x4", "e4m3_float8x8", "e5m2_float8x8", - "e2m1_float4x8", + "float4_e2m1fnx8", "float16x8", "float32x8", "float64x8", "e4m3_float8x16", "e5m2_float8x16", - "e2m1_float4x16", + "float4_e2m1fnx16", "float16x16", "float32x16", "float64x16", "e4m3_float8x32", "e5m2_float8x32", - "e2m1_float4x32", + "float4_e2m1fnx32", "float16x32", "float32x32", "float64x32", "e4m3_float8x64", "e5m2_float8x64", - "e2m1_float4x64", + "float4_e2m1fnx64", "float16x64", "float32x64", "float64x64", diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc index 1812d13b9c14..d87606532509 100644 --- a/src/runtime/ndarray.cc +++ b/src/runtime/ndarray.cc @@ -54,7 +54,7 @@ inline void VerifyDataType(DLDataType dtype) { return; else if (dtype.bits == 4 && dtype.code == kDLInt) return; - else if (dtype.bits == 4 && dtype.code == DataType::kE2M1Float) + else if (dtype.bits == 4 && dtype.code == DataType::kFloat4E2M1Fn) return; else ICHECK_EQ(dtype.bits % 8, 0); diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index e452e102bf1e..a73c9cb5b4ce 100644 --- a/src/script/ir_builder/tir/ir.cc +++ 
b/src/script/ir_builder/tir/ir.cc @@ -757,8 +757,8 @@ TVM_REGISTER_GLOBAL("script.ir_builder.tir.E5M2Float8").set_body_typed(E5M2Float TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.E4M3Float8", E4M3Float8); TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.E5M2Float8", E5M2Float8); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.E2M1Float4").set_body_typed(E2M1Float4); -TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.E2M1Float4", E2M1Float4); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Float4E2M1fn").set_body_typed(Float4E2M1fn); +TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float4E2M1fn", Float4E2M1fn); TVM_REGISTER_GLOBAL("script.ir_builder.tir.Boolean").set_body_typed(Boolean); TVM_REGISTER_GLOBAL("script.ir_builder.tir.Handle").set_body_typed(Handle); diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index f5f17c70ef9d..9c2ce0bbb26b 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -581,7 +581,7 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const { } } else if (dtype.code() == DataType::kE4M3Float || dtype.code() == DataType::kE5M2Float) { etype = llvm::Type::getInt8Ty(*ctx); - } else if (dtype.code() == DataType::kE2M1Float) { + } else if (dtype.code() == DataType::kFloat4E2M1Fn) { etype = llvm::Type::getIntNTy(*ctx, 4); } if (!dtype.is_scalar()) { diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 0e0971b8f86c..575f52e2257a 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -789,6 +789,11 @@ void CodeGenC::VisitExpr_(const BufferLoadNode* op, std::ostream& os) { // NOLI } } + if (value_dtype.is_float4_e2m1fn() && lanes != 1) { + // A float4_e2m1fn element has 4 bits, which is an incomplete byte. + // So we cannot vector load it. 
+ can_vector_load = false; + } if (can_vector_load) { std::string ref = GetVecLoad(op->dtype, op->buffer.get(), base.Eval()); HandleVolatileLoads(ref, op, os); @@ -839,7 +844,8 @@ void CodeGenC::VisitStmt_(const BufferStoreNode* op) { } else { arith::PVar<PrimExpr> base; - if (arith::ramp(base, 1, value_dtype.lanes()).Match(index_expr)) { + if (arith::ramp(base, 1, value_dtype.lanes()).Match(index_expr) && + !value_dtype.is_float4_e2m1fn()) { std::string value = this->PrintExpr(op->value); this->PrintVecStore(op->buffer.get(), value_dtype, base.Eval(), value); } else { diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 872a02436676..20b29750dc1b 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -82,11 +82,11 @@ std::string GetFP4Type(DataType type) { } else if (lanes == 4) { vec = "x4"; } else { - LOG(FATAL) << "Only support scalar and vector types of width (2, 4, 8) for FP8"; + LOG(FATAL) << "Only support scalar and vector types of width (2, 4) for FP4"; } stream << "__nv_fp4"; std::string suffix; - if (type.code() == DataType::kE2M1Float) { + if (type.code() == DataType::kFloat4E2M1Fn) { suffix = "_e2m1"; } else { LOG(FATAL) << "Unsupported FP8 type in CUDA codegen"; } @@ -196,7 +196,7 @@ std::string CodeGenCUDA::Finish() { decl_stream << "#include \n"; decl_stream << "#endif\n\n"; } - declare_vector_type_extensions(decl_stream, enable_fp16_, enable_fp8_); + declare_vector_type_extensions(decl_stream, enable_fp16_, enable_fp8_, enable_fp4_); if (enable_warp_shuffle_) { decl_stream << _cuda_warp_intrinsic_util; @@ -597,6 +597,9 @@ void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i, } ICHECK(!type_name.empty()); os << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]; + } else if (t.is_float4_e2m1fn()) { + os << "([](__nv_fp4_storage_t v) { __nv_fp4_e2m1 t; t.__x = v; return t; })((" << vec + << ".__x >> " << i * 4 << ") & 0xF)"; } else { os << vec << "." 
<< access[i]; } @@ -732,8 +735,8 @@ void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) { if (from_ty.is_scalar()) return CodeGenC::VisitExpr_(op, os); if (target_ty.code() == DataType::kE4M3Float || target_ty.code() == DataType::kE5M2Float || - target_ty.code() == DataType::kE2M1Float || from_ty.code() == DataType::kE4M3Float || - from_ty.code() == DataType::kE5M2Float || from_ty.code() == DataType::kE2M1Float) { + target_ty.code() == DataType::kFloat4E2M1Fn || from_ty.code() == DataType::kE4M3Float || + from_ty.code() == DataType::kE5M2Float || from_ty.code() == DataType::kFloat4E2M1Fn) { std::ostringstream val; val << "("; PrintType(target_ty, val); @@ -1036,8 +1039,8 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { var_idmap_[inverse_index_map->initial_indices[1].get()] = "local_id"; os << "for (int local_id = 0; local_id < 8; ++local_id) {\n"; - os << dst << "[" + this->PrintExpr(dst_ind) + "]" - << " = " << src << "[" << src_offset << " + local_id];\n"; + os << dst << "[" + this->PrintExpr(dst_ind) + "] = " << src << "[" << src_offset + << " + local_id];\n"; os << "}\n"; } else if (op->op.same_as(builtin::mma_fill())) { @@ -1155,6 +1158,82 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { stream << ": \"l\"((void*)(" << global_buffer << "+" << global_addr << ")), \"r\"((int)" << guard << ")\n"; stream << ");\n"; + } else if (op->op.same_as(builtin::reinterpret())) { + DataType tgt_dtype = op->dtype; + DataType src_dtype = op->args[0]->dtype; + PrimExpr value = op->args[0]; + + // Handle float4_e2m1fn reinterpret + if (!src_dtype.is_float4_e2m1fn() && !tgt_dtype.is_float4_e2m1fn()) { + return CodeGenC::VisitExpr_(op, os); + } + if (src_dtype == tgt_dtype || + tgt_dtype.lanes() * tgt_dtype.bits() == src_dtype.lanes() * src_dtype.bits()) { + return CodeGenC::VisitExpr_(op, os); + } + CHECK_EQ(tgt_dtype.lanes(), src_dtype.lanes()) + << "E2M1 float4 reinterpret expects source and target to have the same number of lanes. " + << "Source dtype: " << src_dtype << ", Target dtype: " << tgt_dtype; + CHECK_EQ(tgt_dtype.bytes(), src_dtype.bytes()) + << "E2M1 float4 reinterpret expects source and target to have the same number of bytes. " + << "Source dtype: " << src_dtype << ", Target dtype: " << tgt_dtype; + + int lanes = tgt_dtype.lanes(); + + int ssa_scope = BeginScope(); + if (lanes == 1) { + // The case of lane=1 is same as the normal reinterpret, + // except that we allow the src and dst dtype to have different number of bits. + std::string rhs = SSAGetID(PrintExpr(value), src_dtype); + os << "(*("; + this->PrintType(tgt_dtype, os); + os << " *)(&(" << rhs << ")))"; + } else if (lanes == 2) { + if (tgt_dtype.is_float4_e2m1fn()) { + // We view the source as an uint16, and then extract bits of two fp4 numbers, + // and finally reinterpret the result as fp4x2. 
+ value = tir::Call(DataType::UInt(16), tir::builtin::reinterpret(), {value}); + tir::Var temp_var("temp_var", DataType::UInt(16)); + value = tir::Let( + temp_var, value, + tir::Cast(DataType::UInt(8), (temp_var & IntImm(DataType::UInt(16), 0xF)) | + ((temp_var >> 4) & IntImm(DataType::UInt(16), 0xF0)))); + } else { + value = tir::Cast(DataType::UInt(16), + tir::Call(DataType::UInt(8), tir::builtin::reinterpret(), {value})); + tir::Var temp_var("temp_var", DataType::UInt(16)); + value = tir::Let(temp_var, value, + (temp_var & IntImm(DataType::UInt(16), 0xF)) | + ((temp_var & IntImm(DataType::UInt(16), 0xF0)) << 4)); + } + os << PrintExpr(tir::Call(tgt_dtype, tir::builtin::reinterpret(), {value})); + } else if (lanes == 4) { + if (tgt_dtype.is_float4_e2m1fn()) { + // We view the source as an uint32, and then extract bits of four fp4 numbers, + // and finally reinterpret the result as fp4x4. + value = tir::Call(DataType::UInt(32), tir::builtin::reinterpret(), {value}); + tir::Var temp_var("temp_var", DataType::UInt(32)); + value = tir::Let(temp_var, value, + tir::Cast(DataType::UInt(16), + (temp_var & IntImm(DataType::UInt(32), 0xF)) | + ((temp_var >> 4) & IntImm(DataType::UInt(32), 0xF0)) | + ((temp_var >> 8) & IntImm(DataType::UInt(32), 0xF00)) | + ((temp_var >> 12) & IntImm(DataType::UInt(32), 0xF000)))); + } else { + value = tir::Cast(DataType::UInt(32), + tir::Call(DataType::UInt(16), tir::builtin::reinterpret(), {value})); + tir::Var temp_var("temp_var", DataType::UInt(32)); + value = tir::Let(temp_var, value, + (temp_var & IntImm(DataType::UInt(32), 0xF)) | + ((temp_var & IntImm(DataType::UInt(32), 0xF0)) << 4) | + ((temp_var & IntImm(DataType::UInt(32), 0xF00)) << 8) | + ((temp_var & IntImm(DataType::UInt(32), 0xF000)) << 12)); + } + os << PrintExpr(tir::Call(tgt_dtype, tir::builtin::reinterpret(), {value})); + } else { + LOG(FATAL) << "Invalid number of lanes for float4_e2m1fn reinterpret: " << lanes; + } + EndScope(ssa_scope); } else { CodeGenC::VisitExpr_(op, os); } diff --git a/src/target/source/literal/cuda_half_t.h b/src/target/source/literal/cuda_half_t.h index abdf22df2616..86f2219fe8cb 100644 --- a/src/target/source/literal/cuda_half_t.h +++ b/src/target/source/literal/cuda_half_t.h @@ -385,8 +385,9 @@ static constexpr const char* _cuda_warp_intrinsic_util = R"( )"; -void declare_vector_type_extensions(std::ostringstream& stream, bool enable_fp16, bool enable_fp8) { - if (enable_fp16 || enable_fp8) { +void declare_vector_type_extensions(std::ostringstream& stream, bool enable_fp16, bool enable_fp8, + bool enable_fp4) { + if (enable_fp16 || enable_fp8 || enable_fp4) { stream << R"( struct __align__(8) half4 { __half x, y, z, w; @@ -455,6 +456,26 @@ struct __align__(8) half4 { result.__x = (a) | (b << 8) | (c << 16) | (d << 24); return result; } + )"; + } + if (enable_fp4) { + stream << R"( + __host__ __device__ explicit half4(const __nv_fp4x4_e2m1& fp4x4) { + __nv_fp4x2_storage_t lo_part, hi_part; + lo_part = static_cast<__nv_fp4x2_storage_t>(fp4x4.__x & 0xFF); + hi_part = static_cast<__nv_fp4x2_storage_t>((fp4x4.__x >> 8) & 0xFF); + __half2 lo_half2 = __half2(__nv_cvt_fp4x2_to_halfraw2(lo_part, __NV_E2M1)); + __half2 hi_half2 = __half2(__nv_cvt_fp4x2_to_halfraw2(hi_part, __NV_E2M1)); + x = reinterpret_cast<__half*>(&lo_half2)[0]; + y = reinterpret_cast<__half*>(&lo_half2)[1]; + z = reinterpret_cast<__half*>(&hi_half2)[0]; + w = reinterpret_cast<__half*>(&hi_half2)[1]; + } + __host__ __device__ explicit operator __nv_fp4x4_e2m1() const { + __half2 lo_half2 = 
*reinterpret_cast<const __half2*>(&x); + __half2 hi_half2 = *reinterpret_cast<const __half2*>(&z); + return __nv_fp4x4_e2m1(lo_half2, hi_half2); + } )"; } stream << R"( @@ -462,6 +483,20 @@ struct __align__(8) half4 { __host__ __device__ half4 make_half4(__half x, __half y, __half z, __half w) { return half4(x, y, z, w); } +)"; + } + if (enable_fp4) { + stream << R"( +__device__ __nv_fp4x2_e2m1 make___nv_fp4x2_e2m1(__nv_fp4_e2m1 x, __nv_fp4_e2m1 y) { + __nv_fp4x2_e2m1 result; + result.__x = (x.__x) | (y.__x << 4); + return result; +} +__device__ __nv_fp4x4_e2m1 make___nv_fp4x4_e2m1(__nv_fp4_e2m1 a, __nv_fp4_e2m1 b, __nv_fp4_e2m1 c, __nv_fp4_e2m1 d) { + __nv_fp4x4_e2m1 result; + result.__x = (static_cast<__nv_fp4x4_storage_t>(a.__x)) | (static_cast<__nv_fp4x4_storage_t>(b.__x) << 4) | (static_cast<__nv_fp4x4_storage_t>(c.__x) << 8) | (static_cast<__nv_fp4x4_storage_t>(d.__x) << 12); + return result; +} )"; } } diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index 039acf7e929d..3dab634f162e 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -425,8 +425,10 @@ PrimExpr cast(const DataType& t, PrimExpr value, Span span) { PrimExpr reinterpret(const DataType& t, PrimExpr value, Span span) { if (value.dtype() == t) return value; if (!t.is_scalable_vector() && !value.dtype().is_scalable_vector()) { - ICHECK(value.dtype().bits() * value.dtype().lanes() == t.bits() * t.lanes()) - << "Bitcast requires size match " << t << " vs " << value.dtype(); + ICHECK(value.dtype().bits() * value.dtype().lanes() == t.bits() * t.lanes() || + ((value.dtype().is_float4_e2m1fn() || t.is_float4_e2m1fn()) && + value.dtype().bytes() * value.dtype().lanes() == t.bytes() * t.lanes())) + << "Reinterpret requires size match " << t << " vs " << value.dtype(); } return tir::Call(t, tir::builtin::reinterpret(), {value}, span); } diff --git a/tests/python/codegen/test_target_codegen_cuda_fp4.py b/tests/python/codegen/test_target_codegen_cuda_fp4.py index f137e83cc942..46825826a9d1 100644 --- a/tests/python/codegen/test_target_codegen_cuda_fp4.py +++ b/tests/python/codegen/test_target_codegen_cuda_fp4.py @@ -15,12 +15,13 @@ # specific language governing permissions and limitations # under the License. 
+from itertools import product -import tvm -from tvm.script import tir as T import numpy as np + +import tvm import tvm.testing -from tvm.script import ir as I, relax as R, tir as T +from tvm.script import tir as T try: import ml_dtypes @@ -28,12 +29,12 @@ ml_dtypes = None native_dtype, promoted_dtype = tvm.testing.parameters( - ("e2m1_float4x2", "float32x2"), - ("e2m1_float4x2", "float16x2"), + ("float4_e2m1fnx2", "float32x2"), + ("float4_e2m1fnx2", "float16x2"), ) -@tvm.testing.requires_cuda_compute_version(9) +@tvm.testing.requires_cuda_compute_version(10) def test_e2m1_vector_conversions(native_dtype, promoted_dtype): vector_length = 64 @@ -63,7 +64,6 @@ def add( target = "cuda" fadd = tvm.build(sch.mod, target=target) - cuda_src = fadd.imported_modules[0].get_source() dev = tvm.device(target, 0) numpytype = "float4_e2m1fn" @@ -92,5 +92,124 @@ def add( ) +@tvm.testing.requires_cuda_compute_version(10) +def test_e2m1_schedule_vectorize(): + native_dtype = "float4_e2m1fn" + n = 128 + + dev = tvm.device("cuda", 0) + target = tvm.target.Target.from_device(dev) + for promoted_dtype, vector_length in product( + ["float16", "bfloat16", "float32"], + [1, 2, 4], + ): + + @T.prim_func + def add( + A: T.Buffer((n,), native_dtype), + B: T.Buffer((n,), native_dtype), + C: T.Buffer((n,), native_dtype), + ): + T.func_attr({"tir.noalias": T.bool(True)}) + for i in range(n): + with T.block("C"): + v_i = T.axis.spatial(n, i) + T.reads(A[v_i], B[v_i]) + T.writes(C[v_i]) + C[v_i] = T.Cast( + native_dtype, + T.Cast(promoted_dtype, A[v_i]) + T.Cast(promoted_dtype, B[v_i]), + ) + + sch = tvm.tir.Schedule(add) + block = sch.get_block("C") + b = sch.get_loops(block) + bx, tx, vec = sch.split(b[0], factors=[None, 32, vector_length]) + sch.bind(bx, "blockIdx.x") + sch.bind(tx, "threadIdx.x") + sch.vectorize(vec) + + fadd = tvm.build(sch.mod, target=target) + + numpytype = "float4_e2m1fn" + promoted_base_dtype = promoted_dtype + + a_np = np.random.uniform(low=-6, high=6, size=(n,)).astype(numpytype) + a = tvm.nd.empty(shape=(n,), dtype=native_dtype, device=dev) + a.copyfrom(a_np) + b_np = np.random.uniform(low=-6, high=6, size=(n,)).astype(numpytype) + b = tvm.nd.empty(shape=(n,), dtype=native_dtype, device=dev) + b.copyfrom(b_np) + c = tvm.nd.empty(shape=(n,), dtype=native_dtype, device=dev) + fadd(a, b, c) + + if promoted_base_dtype != "bfloat16": + tvm.testing.assert_allclose( + c.numpy().astype(promoted_base_dtype), (a_np + b_np).astype(promoted_base_dtype) + ) + else: + # assert_allclose with bfloat16 throws an error here. + # Thus we convert bfloat16 to float32 for comparison. 
+ tvm.testing.assert_allclose( + c.numpy().astype(promoted_base_dtype).astype("float32"), + (a_np + b_np).astype(promoted_base_dtype).astype("float32"), + ) + + +@tvm.testing.requires_cuda_compute_version(10) +def test_e2m1_reinterpret(): + n = 128 + + dev = tvm.device("cuda", 0) + target = tvm.target.Target.from_device(dev) + + def get_reinterpret_mod(src_dtype, dst_dtype, vector_length): + @T.prim_func + def reinterpret( + A: T.Buffer((n,), src_dtype), + B: T.Buffer((n,), dst_dtype), + ): + T.func_attr({"tir.noalias": T.bool(True)}) + for i in range(n): + with T.block("C"): + v_i = T.axis.spatial(n, i) + T.reads(A[v_i]) + T.writes(B[v_i]) + B[v_i] = T.reinterpret(dst_dtype, A[v_i]) + + sch = tvm.tir.Schedule(reinterpret) + block = sch.get_block("C") + b = sch.get_loops(block) + bx, tx, vec = sch.split(b[0], factors=[None, 32, vector_length]) + sch.bind(bx, "blockIdx.x") + sch.bind(tx, "threadIdx.x") + sch.vectorize(vec) + return sch.mod + + # Part 1. reinterpret float4_e2m1fn to uint8 + for vector_length in [1, 2, 4]: + mod = get_reinterpret_mod("float4_e2m1fn", "uint8", vector_length) + f = tvm.build(mod, target=target) + a_np = np.random.uniform(low=-6, high=6, size=(n,)).astype("float4_e2m1fn") + a = tvm.nd.empty(shape=(n,), dtype="float4_e2m1fn", device=dev) + a.copyfrom(a_np) + b = tvm.nd.empty(shape=(n,), dtype="uint8", device=dev) + f(a, b) + tvm.testing.assert_allclose(b.numpy(), a_np.view("uint8")) + + # Part 2. reinterpret uint8 to float4_e2m1fn + for vector_length in [1, 2, 4]: + mod = get_reinterpret_mod("uint8", "float4_e2m1fn", vector_length) + f = tvm.build(mod, target=target) + a_np = np.random.uniform(low=-6, high=6, size=(n,)).astype("uint8") + a = tvm.nd.empty(shape=(n,), dtype="uint8", device=dev) + a.copyfrom(a_np) + b = tvm.nd.empty(shape=(n,), dtype="float4_e2m1fn", device=dev) + f(a, b) + tvm.testing.assert_allclose( + b.numpy().astype("float32"), a_np.view("float4_e2m1fn").astype("float32") + ) + + if __name__ == "__main__": tvm.testing.main()
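On the TVMScript side, the ir_builder helpers follow the same renaming (T.float4_e2m1fn and its x2/x4/.../x64 variants), and the plain dtype string can be used directly in buffer annotations and casts. A small illustrative kernel in the style of the new tests, assuming a build that includes this patch:

```python
from tvm.script import tir as T


@T.prim_func
def cast_to_fp4(A: T.Buffer((128,), "float16"), B: T.Buffer((128,), "float4_e2m1fn")):
    T.func_attr({"tir.noalias": T.bool(True)})
    for i in range(128):
        with T.block("cast"):
            vi = T.axis.spatial(128, i)
            B[vi] = T.Cast("float4_e2m1fn", A[vi])
```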
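The ndarray.py changes let copyfrom() and numpy() translate between ml_dtypes.float4_e2m1fn arrays (one code per byte) and TVM NDArrays. A hedged sketch of the round-trip for the scalar dtype; it mirrors what the new CUDA tests do, shown here on CPU purely for illustration and assuming ml_dtypes is installed:

```python
import numpy as np
import ml_dtypes
import tvm

a_np = np.random.uniform(low=-6, high=6, size=(64,)).astype(ml_dtypes.float4_e2m1fn)
a = tvm.nd.empty((64,), dtype="float4_e2m1fn", device=tvm.cpu())
a.copyfrom(a_np)                  # scalar float4_e2m1fn keeps one code per byte
assert (a.numpy() == a_np).all()  # numpy() converts back through ml_dtypes
```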
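For vectorized dtypes such as float4_e2m1fnx2, copyfrom() additionally packs the byte-per-element ml_dtypes representation two codes per byte (and numpy() unpacks it again), which is why the copyfrom() hunk above pads odd-sized bit views. A host-side sketch of that packing; the low-nibble-first layout here is an assumption for illustration, not a statement of the exact layout the runtime uses:

```python
import numpy as np
import ml_dtypes

vals = np.arange(8, dtype=np.float32).astype(ml_dtypes.float4_e2m1fn)
codes = vals.view(np.uint8)                  # ml_dtypes stores one 4-bit code per byte
packed = (codes[0::2] & 0xF) | ((codes[1::2] & 0xF) << 4)  # two codes per byte

lo = packed & 0xF                            # unpacking splits the nibbles again
hi = (packed >> 4) & 0xF
restored = np.empty_like(codes)
restored[0::2], restored[1::2] = lo, hi
assert (restored == codes).all()
```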
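The reinterpret handling added to codegen_cuda.cc cannot fall back to the generic bitcast for multi-lane FP4, because a float4_e2m1fnx2 value occupies one byte while its uint8x2 counterpart occupies two, so the emitted expression gathers or scatters nibbles instead. A host-side Python model of the 2-lane shuffle, using the same masks and shifts as the lanes == 2 branch (the 4-lane branch repeats the pattern over 16 bits):

```python
def pack_fp4x2(u8x2_bits: int) -> int:
    # uint8x2 (16 bits, one FP4 code in each byte's low nibble) -> fp4x2 (one byte)
    return (u8x2_bits & 0xF) | ((u8x2_bits >> 4) & 0xF0)


def unpack_fp4x2(fp4x2_bits: int) -> int:
    # fp4x2 (one byte) -> uint8x2 (16 bits, one code per byte)
    return (fp4x2_bits & 0xF) | ((fp4x2_bits & 0xF0) << 4)


u8x2 = 0x0302                     # lane 0 = 0x2, lane 1 = 0x3
assert pack_fp4x2(u8x2) == 0x32
assert unpack_fp4x2(0x32) == u8x2
```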