diff --git a/ci/jenkins/docker-images.ini b/ci/jenkins/docker-images.ini index 6d3f78190f6a..0626ff2b5fa4 100644 --- a/ci/jenkins/docker-images.ini +++ b/ci/jenkins/docker-images.ini @@ -18,8 +18,8 @@ # This data file is read during when Jenkins runs job to determine docker images. [jenkins] ci_arm: tlcpack/ci-arm:20250226-223225-63bc315f -ci_cpu: tlcpack/ci_cpu:20250226-223225-63bc315f -ci_gpu: tlcpack/ci-gpu:20250226-223225-63bc315f +ci_cpu: tlcpack/ci_cpu:20250226-223225-63bc315f_patch +ci_gpu: tlcpack/ci-gpu:20250226-223225-63bc315f_patch ci_hexagon: tlcpack/ci-hexagon:20250226-223225-63bc315f ci_i386: tlcpack/ci-i386:20250226-223225-63bc315f ci_lint: tlcpack/ci-lint:20250226-223225-63bc315f diff --git a/ci/jenkins/unity_jenkinsfile.groovy b/ci/jenkins/unity_jenkinsfile.groovy index 4cfe96937f62..8dee0c7f8c1a 100755 --- a/ci/jenkins/unity_jenkinsfile.groovy +++ b/ci/jenkins/unity_jenkinsfile.groovy @@ -30,8 +30,8 @@ import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // NOTE: these lines are scanned by docker/dev_common.sh. Please update the regex as needed. --> -ci_gpu = 'tlcpack/ci-gpu:20250226-223225-63bc315f' -ci_cpu = 'tlcpack/ci-cpu:20250226-223225-63bc315f' +ci_gpu = 'tlcpack/ci-gpu:20250226-223225-63bc315f_patch' +ci_cpu = 'tlcpack/ci-cpu:20250226-223225-63bc315f_patch' // <--- End of regex-scanned config. // Parameters to allow overriding (in Jenkins UI), the images diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h index c49fde1746bc..3d35d5241e11 100644 --- a/include/tvm/runtime/data_type.h +++ b/include/tvm/runtime/data_type.h @@ -58,6 +58,7 @@ class DataType { kBFloat = kDLBfloat, kE4M3Float = 6U, kE5M2Float = 7U, + kE2M1Float = 8U, kCustomBegin = 129 }; /*! \brief default constructor */ @@ -87,6 +88,9 @@ class DataType { if (code == kE4M3Float || code == kE5M2Float) { ICHECK_EQ(bits, 8); } + if (code == kE2M1Float) { + ICHECK_EQ(bits, 4); + } } /*! \return The type code. */ int code() const { return static_cast(data_.code); } @@ -126,9 +130,13 @@ class DataType { code() == DataType::kE5M2Float) && bits() == 8; } + /*! \return whether type is a float4 type. */ + bool is_float4() const { return code() == DataType::kE2M1Float && bits() == 4; } bool is_e4m3_float8() const { return (code() == DataType::kE4M3Float && bits() == 8); } bool is_e5m2_float8() const { return (code() == DataType::kE5M2Float && bits() == 8); } + + bool is_e2m1_float4() const { return (code() == DataType::kE2M1Float && bits() == 4); } /*! \return whether type is a float16 type. */ bool is_float16() const { return is_float() && bits() == 16; } /*! \return whether type is a bfloat16 type. */ @@ -253,6 +261,12 @@ class DataType { * \return The constructed data type. */ static DataType NVFloat8E5M2(int lanes = 1) { return DataType(kE5M2Float, 8, lanes); } + /*! + * \brief Construct NV float4 e2m1 datatype. + * \param lanes The number of lanes + * \return The constructed data type. + */ + static DataType NVFloat4E2M1(int lanes = 1) { return DataType(kE2M1Float, 4, lanes); } /*! * \brief Construct a bool type. * \param lanes The number of lanes. @@ -299,7 +313,7 @@ inline int GetVectorBytes(DataType dtype) { int data_bits = dtype.bits() * dtype.lanes(); // allow bool to exist if (dtype == DataType::Bool() || dtype == DataType::Int(4) || dtype == DataType::UInt(4) || - dtype == DataType::Int(1)) { + dtype == DataType::Int(1) || dtype == DataType::NVFloat4E2M1()) { return 1; } ICHECK_EQ(data_bits % 8, 0U) << "Need to load/store by multiple of bytes"; @@ -385,6 +399,8 @@ inline const char* DLDataTypeCode2Str(DLDataTypeCode type_code) { return "e4m3_float"; case DataType::kE5M2Float: return "e5m2_float"; + case DataType::kE2M1Float: + return "e2m1_float"; default: LOG(FATAL) << "unknown type_code=" << static_cast(type_code); } @@ -466,6 +482,10 @@ inline DLDataType String2DLDataType(std::string s) { t.code = DataType::kE5M2Float; t.bits = 8; scan = s.c_str() + 10; + } else if (s.substr(0, 10) == "e2m1_float") { + t.code = DataType::kE2M1Float; + t.bits = 4; + scan = s.c_str() + 10; } else if (s.substr(0, 6) == "custom") { t.code = ParseCustomDatatype(s, &scan); } else { diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h index 380c2fcce25d..5dd1a5c733c5 100644 --- a/include/tvm/script/ir_builder/tir/ir.h +++ b/include/tvm/script/ir_builder/tir/ir.h @@ -505,6 +505,8 @@ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Int, DataType::Int); TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(E4M3Float8, DataType::NVFloat8E4M3); TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(E5M2Float8, DataType::NVFloat8E5M2); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(E2M1Float4, DataType::NVFloat4E2M1); + TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Boolean, DataType::Bool()); TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Void, DataType::Void()); diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index d06bb779d0bb..e98eb46be919 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -940,7 +940,7 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value, Span span = Span()) return LargeUIntImm(t, static_cast(low), static_cast(high), span); } } - if (t.is_float() || t.is_bfloat16() || t.is_float8()) + if (t.is_float() || t.is_bfloat16() || t.is_float8() || t.is_float4()) return FloatImm(t, static_cast(value), span); // For now, we store const scalar values of custom datatypes within doubles; later, during the // datatypes lowering pass, we will lower the value to its true representation in the format diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index f79df1644e28..263a4ff69fa6 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -68,6 +68,7 @@ class DataTypeCode(object): BFLOAT = 4 E4M3Float = 6 E5M2Float = 7 + E2M1Float = 8 class DataType(ctypes.Structure): @@ -82,6 +83,7 @@ class DataType(ctypes.Structure): DataTypeCode.BFLOAT: "bfloat", DataTypeCode.E4M3Float: "e4m3_float", DataTypeCode.E5M2Float: "e5m2_float", + DataTypeCode.E2M1Float: "e2m1_float", } NUMPY2STR = { np.dtype(np.bool_): "bool", @@ -112,6 +114,7 @@ class DataType(ctypes.Structure): "uint64": {"type_code": DataTypeCode.UINT, "bits": 64, "lanes": 1}, "e4m3_float8": {"type_code": DataTypeCode.E4M3Float, "bits": 8, "lanes": 1}, "e5m2_float8": {"type_code": DataTypeCode.E5M2Float, "bits": 8, "lanes": 1}, + "e2m1_float4": {"type_code": DataTypeCode.E2M1Float, "bits": 4, "lanes": 1}, "float16": {"type_code": DataTypeCode.FLOAT, "bits": 16, "lanes": 1}, "float32": {"type_code": DataTypeCode.FLOAT, "bits": 32, "lanes": 1}, "float64": {"type_code": DataTypeCode.FLOAT, "bits": 64, "lanes": 1}, @@ -168,6 +171,9 @@ def __init__(self, type_str): elif head.startswith("e5m2_float"): self.type_code = DataTypeCode.E5M2Float head = head[10:] + elif head.startswith("e2m1_float"): + self.type_code = DataTypeCode.E2M1Float + head = head[10:] elif head.startswith("custom"): # pylint: disable=import-outside-toplevel import tvm.runtime._ffi_api @@ -232,6 +238,7 @@ def itemsize(self): DataType.NUMPY2STR[np.dtype(ml_dtypes.bfloat16)] = "bfloat16" DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e4m3fn)] = "e4m3_float8" DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e5m2)] = "e5m2_float8" + DataType.NUMPY2STR[np.dtype(ml_dtypes.float4_e2m1fn)] = "e2m1_float4" RPC_SESS_MASK = 128 diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py index be35bf631943..d12ddf883cf4 100644 --- a/python/tvm/contrib/nvcc.py +++ b/python/tvm/contrib/nvcc.py @@ -445,3 +445,19 @@ def have_fp8(compute_version): if major >= 9: return True return False + + +@tvm._ffi.register_func("tvm.contrib.nvcc.supports_fp4") +def have_fp4(compute_version): + """Whether fp4 support is provided in the specified compute capability or not + + Parameters + ---------- + compute_version : str + GPU capability + """ + major, minor = parse_compute_version(compute_version) + # fp4 is suppored in Blackwell (10.0) or later architectures. + if major == 10 and minor == 0: + return True + return False diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py index 082a28c7e204..3514ee6168d1 100644 --- a/python/tvm/runtime/ndarray.py +++ b/python/tvm/runtime/ndarray.py @@ -197,6 +197,13 @@ def copyfrom(self, source_array): source_array = np.ascontiguousarray( source_array, dtype="uint16" if dtype == "bfloat16" else dtype ) + if dtype.startswith("e2m1_float4"): + data_bits = source_array.view(dtype="uint8") + if data_bits.size % 2: + data_bits = np.pad(data_bits, (0, 1), mode="constant", constant_values=0) + data_bits = data_bits.reshape(-1, 2) + packed = ((data_bits[:, 0] & 0x0F) << 4) | (data_bits[:, 1] & 0x0F) + source_array = packed.astype(np.int8) assert source_array.flags["C_CONTIGUOUS"] data = source_array.ctypes.data_as(ctypes.c_void_p) nbytes = ctypes.c_size_t(source_array.size * source_array.dtype.itemsize) @@ -254,20 +261,32 @@ def numpy(self): raise RuntimeError( "ml_dtypes is not installed, cannot convert e5m2_float8 array to numpy." ) + if dtype == "e2m1_float4": + if ml_dtypes is not None: + dtype = ml_dtypes.float4_e2m1fn + else: + raise RuntimeError( + "ml_dtypes is not installed, cannot convert e2m1_float4 array to numpy." + ) np_arr = np.empty(shape, dtype=dtype) assert np_arr.flags["C_CONTIGUOUS"] data = np_arr.ctypes.data_as(ctypes.c_void_p) - nbytes = ctypes.c_size_t(np_arr.size * np_arr.dtype.itemsize) + if old_dtype.startswith("e2m1_float4"): + nbytes = ctypes.c_size_t(np_arr.size * np_arr.dtype.itemsize // 2) + else: + nbytes = ctypes.c_size_t(np_arr.size * np_arr.dtype.itemsize) check_call(_LIB.TVMArrayCopyToBytes(self.handle, data, nbytes)) - if old_dtype == "int4": + if old_dtype == "int4" or old_dtype.startswith("e2m1_float4"): length = np_arr.size + np_arr = np_arr.view("int8") np_arr_ret = np.empty((length,), dtype="int8") np_arr = np_arr.reshape((length,)) old_index = np.bitwise_and(np_arr, 0x0F) even_index = np.bitwise_and(np_arr >> 4, 0x0F) np_arr_ret[1::2] = old_index[0 : length // 2] np_arr_ret[0::2] = even_index[0 : length // 2] - return np_arr_ret.reshape(shape) + return np_arr_ret.reshape(shape).view(dtype) + return np_arr def copyto(self, target, mem_scope=None): diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index da0e2954e83b..6cc19305e49f 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1457,6 +1457,14 @@ def func( e5m2_float8x32 = func_gen(("E5M2Float8x32")) e5m2_float8x64 = func_gen(("E5M2Float8x64")) +e2m1_float4 = func_gen(("E2M1Float4")) +e2m1_float4x4 = func_gen(("E2M1Float4x4")) +e2m1_float4x8 = func_gen(("E2M1Float4x8")) +e2m1_float4x16 = func_gen(("E2M1Float4x16")) +e2m1_float4x32 = func_gen(("E2M1Float4x32")) +e2m1_float4x64 = func_gen(("E2M1Float4x64")) + + # pylint: enable=invalid-name @@ -2005,31 +2013,37 @@ def wrapped(*args, **kwargs): "uint64x64", "e4m3_float8", "e5m2_float8", + "e2m1_float4", "float16", "float32", "float64", "e4m3_float8x4", "e5m2_float8x4", + "e2m1_float4x4", "float16x4", "float32x4", "float64x4", "e4m3_float8x8", "e5m2_float8x8", + "e2m1_float4x8", "float16x8", "float32x8", "float64x8", "e4m3_float8x16", "e5m2_float8x16", + "e2m1_float4x16", "float16x16", "float32x16", "float64x16", "e4m3_float8x32", "e5m2_float8x32", + "e2m1_float4x32", "float16x32", "float32x32", "float64x32", "e4m3_float8x64", "e5m2_float8x64", + "e2m1_float4x64", "float16x64", "float32x64", "float64x64", diff --git a/src/ir/expr.cc b/src/ir/expr.cc index ded046eafc5d..766abf3483c7 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -110,7 +110,7 @@ TVM_REGISTER_NODE_TYPE(IntImmNode); FloatImm::FloatImm(DataType dtype, double value, Span span) { ICHECK_EQ(dtype.lanes(), 1) << "ValueError: FloatImm can only take scalar."; - ICHECK(dtype.is_float() || dtype.is_bfloat16() || dtype.is_float8() || + ICHECK(dtype.is_float() || dtype.is_bfloat16() || dtype.is_float8() || dtype.is_float4() || dtype.code() >= DataType::kCustomBegin) << "ValueError: FloatImm supports only float, but " << dtype << " was supplied."; @@ -137,6 +137,11 @@ FloatImm::FloatImm(DataType dtype, double value, Span span) { << dtype; ICHECK_LE(value, bound) << "ValueError: Literal vaule " << value << " exceeds maximum of " << dtype; + } else if (dtype.is_float4()) { + ICHECK_GE(value, -support::kMaxE2M1) + << "ValueError: Literal value " << value << " exceeds minimum of " << dtype; + ICHECK_LE(value, support::kMaxE2M1) + << "ValueError: Literal value " << value << " exceeds maximum of " << dtype; } } ObjectPtr node = make_object(); diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc index c2cf5f388a21..1812d13b9c14 100644 --- a/src/runtime/ndarray.cc +++ b/src/runtime/ndarray.cc @@ -28,6 +28,7 @@ #include #include "runtime_base.h" +#include "tvm/runtime/data_type.h" extern "C" { // C-mangled dlpack deleter. @@ -53,6 +54,8 @@ inline void VerifyDataType(DLDataType dtype) { return; else if (dtype.bits == 4 && dtype.code == kDLInt) return; + else if (dtype.bits == 4 && dtype.code == DataType::kE2M1Float) + return; else ICHECK_EQ(dtype.bits % 8, 0); } diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index 17353561ee54..e452e102bf1e 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -757,6 +757,9 @@ TVM_REGISTER_GLOBAL("script.ir_builder.tir.E5M2Float8").set_body_typed(E5M2Float TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.E4M3Float8", E4M3Float8); TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.E5M2Float8", E5M2Float8); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.E2M1Float4").set_body_typed(E2M1Float4); +TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.E2M1Float4", E2M1Float4); + TVM_REGISTER_GLOBAL("script.ir_builder.tir.Boolean").set_body_typed(Boolean); TVM_REGISTER_GLOBAL("script.ir_builder.tir.Handle").set_body_typed(Handle); TVM_REGISTER_GLOBAL("script.ir_builder.tir.Void").set_body_typed(Void); diff --git a/src/support/scalars.h b/src/support/scalars.h index 05763f8044bf..b229a6b3380f 100644 --- a/src/support/scalars.h +++ b/src/support/scalars.h @@ -69,6 +69,9 @@ constexpr double kMaxE4M3 = 448; // See https://arxiv.org/pdf/2209.05433.pdf constexpr double kMaxE5M2 = 57344; +// 2^2 * (1 + 1/2) +constexpr double kMaxE2M1 = 6.0; + } // namespace support } // namespace tvm diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 6c051fc939cf..f5f17c70ef9d 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -581,6 +581,8 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const { } } else if (dtype.code() == DataType::kE4M3Float || dtype.code() == DataType::kE5M2Float) { etype = llvm::Type::getInt8Ty(*ctx); + } else if (dtype.code() == DataType::kE2M1Float) { + etype = llvm::Type::getIntNTy(*ctx, 4); } if (!dtype.is_scalar()) { #if TVM_LLVM_VERSION >= 110 diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 9f68cd8d669a..0e0971b8f86c 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -240,7 +240,7 @@ std::string CodeGenC::GetBufferRef(DataType t, const BufferNode* buffer, PrimExp } std::string index_str = PrintExpr(index); - if (t.bits() == 4 || (t.bits() == 1 && t.is_int())) { + if ((t.bits() == 4 && !t.is_float4()) || (t.bits() == 1 && t.is_int())) { // This is a special case, because CodegenCUDA::PrintType() // returns "int" for bool and for 4-bit integers. In most cases, // we divide by the number of lanes to determine the index. diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 040051825119..872a02436676 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -71,6 +71,30 @@ std::string GetFP8Type(DataType type) { return stream.str(); } +std::string GetFP4Type(DataType type) { + std::stringstream stream; + int32_t lanes = type.lanes(); + std::string vec; + if (type.is_scalar()) { + vec = ""; + } else if (lanes == 2) { + vec = "x2"; + } else if (lanes == 4) { + vec = "x4"; + } else { + LOG(FATAL) << "Only support scalar and vector types of width (2, 4, 8) for FP8"; + } + stream << "__nv_fp4"; + std::string suffix; + if (type.code() == DataType::kE2M1Float) { + suffix = "_e2m1"; + } else { + LOG(FATAL) << "Unsupported FP8 type in CUDA codegen"; + } + stream << vec << suffix; + return stream.str(); +} + CodeGenCUDA::CodeGenCUDA() { restrict_keyword_ = "__restrict__"; } void CodeGenCUDA::Init(bool output_ssa) { @@ -133,7 +157,11 @@ std::string CodeGenCUDA::Finish() { decl_stream << "#else\n"; decl_stream << _cuda_half_t_def; decl_stream << "#endif\n\n"; + + decl_stream << "#include \n"; + decl_stream << "#if (CUDA_VERSION <12080)\n"; decl_stream << _cuda_half_util; + decl_stream << "#endif\n"; } if (enable_bf16_) { @@ -163,6 +191,11 @@ std::string CodeGenCUDA::Finish() { decl_stream << "struct fp8_e5x16_t {\n fp8_e5_t data[16]; \n};\n"; decl_stream << "#endif\n\n"; } + if (enable_fp4_) { + decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)\n"; + decl_stream << "#include \n"; + decl_stream << "#endif\n\n"; + } declare_vector_type_extensions(decl_stream, enable_fp16_, enable_fp8_); if (enable_warp_shuffle_) { @@ -314,6 +347,14 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) os << "uint" << t.lanes() / 4; } return; + } else if (t.is_float4()) { + enable_fp4_ = true; + if (t.lanes() <= 4) { + os << GetFP4Type(t); + } else { + fail = true; + } + return; } else if (t == DataType::Bool()) { os << "bool"; return; @@ -691,7 +732,8 @@ void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) { if (from_ty.is_scalar()) return CodeGenC::VisitExpr_(op, os); if (target_ty.code() == DataType::kE4M3Float || target_ty.code() == DataType::kE5M2Float || - from_ty.code() == DataType::kE4M3Float || from_ty.code() == DataType::kE5M2Float) { + target_ty.code() == DataType::kE2M1Float || from_ty.code() == DataType::kE4M3Float || + from_ty.code() == DataType::kE5M2Float || from_ty.code() == DataType::kE2M1Float) { std::ostringstream val; val << "("; PrintType(target_ty, val); @@ -1273,7 +1315,7 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO return; } - if (op->dtype.is_float8()) { + if (op->dtype.is_float8() || op->dtype.is_float4()) { int lanes = op->dtype.lanes(); ICHECK(lanes == 1 || lanes == 2 || lanes == 4); std::string v = PrintExpr(op->value); @@ -1388,7 +1430,7 @@ inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p) return; } // Type code is kE5M2Float or kE4M4Float - if (op->dtype.is_float8()) { + if (op->dtype.is_float8() || op->dtype.is_float4()) { p->PrintType(op->dtype, os); os << '(' << std::scientific << op->value << 'f' << ')'; return; diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h index 7fe818b6b4fb..ed5709ac12be 100644 --- a/src/target/source/codegen_cuda.h +++ b/src/target/source/codegen_cuda.h @@ -42,8 +42,8 @@ class CodeGenCUDA final : public CodeGenC { void Init(bool output_ssa); std::string Finish(); bool need_include_path() { - return (enable_fp16_ || enable_bf16_ || enable_int8_ || enable_fp8_ || need_math_constants_h_ || - need_mma_h_); + return (enable_fp16_ || enable_bf16_ || enable_int8_ || enable_fp8_ || enable_fp4_ || + need_math_constants_h_ || need_mma_h_); } // override behavior void PrintFuncPrefix(std::ostream& os) final; @@ -96,6 +96,8 @@ class CodeGenCUDA final : public CodeGenC { bool enable_bf16_{false}; // whether enable fp8 bool enable_fp8_{false}; + // whether enable fp4 + bool enable_fp4_{false}; // whether enable int8 bool enable_int8_{false}; // whether enable warp shuffle intrinsics diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index dad4ea98d614..039acf7e929d 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -201,6 +201,12 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) { // NOLINT(*) } else if (ltype.is_float8() && !rtype.is_float8()) { // Cast int->float8 for rhs when lhs is a float8 rhs = cast(ltype, rhs); + } else if (!ltype.is_float4() && rtype.is_float4()) { + // Cast int->float4 for lhs when rhs is a float4 + lhs = cast(rtype, lhs); + } else if (ltype.is_float4() && !rtype.is_float4()) { + // Cast int->float4 for rhs when lhs is a float4 + rhs = cast(ltype, rhs); } else if ((ltype.is_int() && rtype.is_int()) || (ltype.is_uint() && rtype.is_uint())) { // Promote int to higher bits e.g. int8 + int16 --> int16 + int16 if (ltype.bits() < rtype.bits()) { @@ -272,6 +278,8 @@ PrimExpr max_value(const DataType& dtype, Span span) { } else if (dtype.code() == DataType::TypeCode::kE4M3Float) { return FloatImm(dtype, 448.0, span); } + } else if (dtype.is_float4()) { + return FloatImm(dtype, 6.0, span); } LOG(FATAL) << "Cannot decide max_value for type" << dtype; } @@ -313,6 +321,8 @@ PrimExpr min_value(const DataType& dtype, Span span) { } else if (dtype.code() == DataType::TypeCode::kE4M3Float) { return FloatImm(dtype, -448.0, span); } + } else if (dtype.is_float4()) { + return FloatImm(dtype, -6.0, span); } LOG(FATAL) << "Cannot decide min_value for type" << dtype; } diff --git a/src/tir/transforms/dtype_conversion.cc b/src/tir/transforms/dtype_conversion.cc index de94cf647387..dfb0a5a63114 100644 --- a/src/tir/transforms/dtype_conversion.cc +++ b/src/tir/transforms/dtype_conversion.cc @@ -39,7 +39,7 @@ PrimExpr DTypeConversion(PrimExpr src_value, DataType tgt_dtype, RoundingMode ro CHECK_EQ(src_dtype.lanes(), tgt_dtype.lanes()) << "The lanes for data type for source value must matches the target datatype."; auto is_floating_point = [](DataType dtype) { - return dtype.is_float() || dtype.is_float8() || dtype.is_bfloat16(); + return dtype.is_float() || dtype.is_float8() || dtype.is_bfloat16() || dtype.is_float4(); }; // Both source dtype and target dtype should be floating point. CHECK(is_floating_point(src_dtype) && is_floating_point(tgt_dtype)); diff --git a/src/tir/transforms/dtype_conversion.h b/src/tir/transforms/dtype_conversion.h index b509abb9cd27..8edbf1bc1ebe 100644 --- a/src/tir/transforms/dtype_conversion.h +++ b/src/tir/transforms/dtype_conversion.h @@ -99,7 +99,7 @@ class FloatConfig { * \return The FloatConfig class containing internal floating point representation. */ static FloatConfig FromDataType(DataType dtype) { - CHECK(dtype.is_float() || dtype.is_bfloat16() || dtype.is_float8()) + CHECK(dtype.is_float() || dtype.is_bfloat16() || dtype.is_float8() || dtype.is_float4()) << "FloatConfig is only applicable to floating point data types, got " << dtype << " instead."; if (dtype.is_float()) { @@ -117,7 +117,7 @@ class FloatConfig { } else if (dtype.is_bfloat16()) { // bfloat16, return FloatConfig(8, 7, 127, InftyStyle::kIEEE, NaNStyle::kIEEE); - } else { // float8 + } else if (dtype.is_float8()) { // float8 // NVIDIA/Arm/Intel's FP8 formats for Deep Learning // Reference: https://arxiv.org/abs/2209.05433 switch (dtype.code()) { @@ -128,6 +128,10 @@ class FloatConfig { // E5M2 format, consistent with IEEE-754 return FloatConfig(5, 2, 15, InftyStyle::kIEEE, NaNStyle::kIEEE); } + } else { + // float4 + // E2M1 format, not consistent with IEEE-754 + return FloatConfig(2, 1, 1, InftyStyle::kNone, NaNStyle::kNone); } } }; diff --git a/tests/python/codegen/test_target_codegen_cuda_fp4.py b/tests/python/codegen/test_target_codegen_cuda_fp4.py new file mode 100644 index 000000000000..f137e83cc942 --- /dev/null +++ b/tests/python/codegen/test_target_codegen_cuda_fp4.py @@ -0,0 +1,96 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +import tvm +from tvm.script import tir as T +import numpy as np +import tvm.testing +from tvm.script import ir as I, relax as R, tir as T + +try: + import ml_dtypes +except ImportError: + ml_dtypes = None + +native_dtype, promoted_dtype = tvm.testing.parameters( + ("e2m1_float4x2", "float32x2"), + ("e2m1_float4x2", "float16x2"), +) + + +@tvm.testing.requires_cuda_compute_version(9) +def test_e2m1_vector_conversions(native_dtype, promoted_dtype): + vector_length = 64 + + @T.prim_func + def add( + A: T.Buffer((vector_length,), native_dtype), + B: T.Buffer((vector_length,), native_dtype), + C: T.Buffer((vector_length,), native_dtype), + ): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + for i in range(vector_length): + with T.block("C"): + v_i = T.axis.spatial(vector_length, i) + T.reads(A[v_i], B[v_i]) + T.writes(C[v_i]) + C[v_i] = T.Cast( + native_dtype, T.Cast(promoted_dtype, A[v_i]) + T.Cast(promoted_dtype, B[v_i]) + ) + + sch = tvm.tir.Schedule(add) + block = sch.get_block("C") + b = sch.get_loops(block) + bx, tx = sch.split(b[0], factors=[None, 32]) + sch.bind(bx, "blockIdx.x") + sch.bind(tx, "threadIdx.x") + + target = "cuda" + fadd = tvm.build(sch.mod, target=target) + cuda_src = fadd.imported_modules[0].get_source() + dev = tvm.device(target, 0) + + numpytype = "float4_e2m1fn" + if "x" in native_dtype: + lanes = int(native_dtype.split("x")[-1]) + else: + lanes = 1 + + if "x" in promoted_dtype: + promoted_base_dtype = promoted_dtype.split("x")[0] + else: + promoted_base_dtype = promoted_dtype + + np_shape = (vector_length, lanes) if lanes > 1 else (vector_length,) + a_np = np.random.uniform(low=0, high=5, size=np_shape).astype(numpytype) + a = tvm.nd.empty(shape=(vector_length,), dtype=native_dtype, device=dev) + a.copyfrom(a_np) + b_np = np.random.uniform(low=0, high=5, size=np_shape).astype(numpytype) + b = tvm.nd.empty(shape=(vector_length,), dtype=native_dtype, device=dev) + b.copyfrom(b_np) + c = tvm.nd.empty(shape=(vector_length,), dtype=native_dtype, device=dev) + fadd(a, b, c) + + tvm.testing.assert_allclose( + c.numpy().astype(promoted_base_dtype), (a_np + b_np).astype(promoted_base_dtype) + ) + + +if __name__ == "__main__": + tvm.testing.main()