diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h index 76e5e3833f17..65fd0c98fdb7 100644 --- a/include/tvm/runtime/data_type.h +++ b/include/tvm/runtime/data_type.h @@ -56,9 +56,9 @@ class DataType { kFloat = kDLFloat, kHandle = TVMArgTypeCode::kTVMOpaqueHandle, kBFloat = kDLBfloat, - kE4M3Float = 6U, - kE5M2Float = 7U, - kFloat4E2M1Fn = 8U, + kFloat8_e4m3fn = 6U, + kFloat8_e5m2 = 7U, + kFloat4_e2m1fn = 8U, kCustomBegin = 129 }; /*! \brief default constructor */ @@ -85,10 +85,10 @@ class DataType { if (code == kBFloat) { ICHECK_EQ(bits, 16); } - if (code == kE4M3Float || code == kE5M2Float) { + if (code == kFloat8_e4m3fn || code == kFloat8_e5m2) { ICHECK_EQ(bits, 8); } - if (code == kFloat4E2M1Fn) { + if (code == kFloat4_e2m1fn) { ICHECK_EQ(bits, 4); } } @@ -126,15 +126,15 @@ class DataType { bool is_float() const { return code() == DataType::kFloat; } /*! \return whether type is a float8 type. */ bool is_float8() const { - return (code() == DataType::kFloat || code() == DataType::kE4M3Float || - code() == DataType::kE5M2Float) && + return (code() == DataType::kFloat || code() == DataType::kFloat8_e4m3fn || + code() == DataType::kFloat8_e5m2) && bits() == 8; } /*! \return whether type is a float4 type. */ - bool is_float4() const { return code() == DataType::kFloat4E2M1Fn && bits() == 4; } - bool is_e4m3_float8() const { return (code() == DataType::kE4M3Float && bits() == 8); } - bool is_e5m2_float8() const { return (code() == DataType::kE5M2Float && bits() == 8); } - bool is_float4_e2m1fn() const { return (code() == DataType::kFloat4E2M1Fn && bits() == 4); } + bool is_float4() const { return code() == DataType::kFloat4_e2m1fn && bits() == 4; } + bool is_float8_e4m3fn() const { return (code() == DataType::kFloat8_e4m3fn && bits() == 8); } + bool is_float8_e5m2() const { return (code() == DataType::kFloat8_e5m2 && bits() == 8); } + bool is_float4_e2m1fn() const { return (code() == DataType::kFloat4_e2m1fn && bits() == 4); } /*! \return whether type is a float16 type. */ bool is_float16() const { return is_float() && bits() == 16; } /*! \return whether type is a bfloat16 type. */ @@ -252,19 +252,19 @@ class DataType { * \param lanes The number of lanes * \return The constructed data type. */ - static DataType NVFloat8E4M3(int lanes = 1) { return DataType(kE4M3Float, 8, lanes); } + static DataType NVFloat8E4M3(int lanes = 1) { return DataType(kFloat8_e4m3fn, 8, lanes); } /*! * \brief Construct NV float8 e5m2 datatype. * \param lanes The number of lanes * \return The constructed data type. */ - static DataType NVFloat8E5M2(int lanes = 1) { return DataType(kE5M2Float, 8, lanes); } + static DataType NVFloat8E5M2(int lanes = 1) { return DataType(kFloat8_e5m2, 8, lanes); } /*! * \brief Construct NV float4_e2m1fn datatype. * \param lanes The number of lanes * \return The constructed data type. */ - static DataType NVFloat4E2M1FN(int lanes = 1) { return DataType(kFloat4E2M1Fn, 4, lanes); } + static DataType NVFloat4E2M1FN(int lanes = 1) { return DataType(kFloat4_e2m1fn, 4, lanes); } /*! * \brief Construct a bool type. * \param lanes The number of lanes. 
@@ -393,11 +393,11 @@ inline const char* DLDataTypeCode2Str(DLDataTypeCode type_code) { return "handle"; case kDLBfloat: return "bfloat"; - case DataType::kE4M3Float: - return "e4m3_float"; - case DataType::kE5M2Float: - return "e5m2_float"; - case DataType::kFloat4E2M1Fn: + case DataType::kFloat8_e4m3fn: + return "float8_e4m3fn"; + case DataType::kFloat8_e5m2: + return "float8_e5m2"; + case DataType::kFloat4_e2m1fn: return "float4_e2m1fn"; default: LOG(FATAL) << "unknown type_code=" << static_cast(type_code); @@ -420,7 +420,10 @@ inline std::ostream& operator<<(std::ostream& os, DLDataType t) { // NOLINT(*) } if (t.code == kTVMOpaqueHandle) return os; int16_t lanes = static_cast(t.lanes); - os << static_cast(t.bits); + if (t.code != DataType::kFloat8_e4m3fn && t.code != DataType::kFloat8_e5m2 && + t.code != DataType::kFloat4_e2m1fn) { + os << static_cast(t.bits); + } if (lanes > 1) { os << 'x' << lanes; } else if (lanes < -1) { @@ -458,7 +461,7 @@ inline DLDataType String2DLDataType(std::string s) { scan = s.c_str() + 4; } else if (s.substr(0, 13) == "float4_e2m1fn") { // Avoid being treated as "float" - t.code = DataType::kFloat4E2M1Fn; + t.code = DataType::kFloat4_e2m1fn; t.bits = 4; scan = s.c_str() + 13; char* endpt = nullptr; @@ -468,6 +471,30 @@ inline DLDataType String2DLDataType(std::string s) { } ICHECK(scan == s.c_str() + s.length()) << "unknown type " << s; return t; + } else if (s.substr(0, 13) == "float8_e4m3fn") { + // Avoid being treated as "float" + t.code = DataType::kFloat8_e4m3fn; + t.bits = 8; + scan = s.c_str() + 13; + char* endpt = nullptr; + if (*scan == 'x') { + t.lanes = static_cast(strtoul(scan + 1, &endpt, 10)); + scan = endpt; + } + ICHECK(scan == s.c_str() + s.length()) << "unknown type " << s; + return t; + } else if (s.substr(0, 11) == "float8_e5m2") { + // Avoid being treated as "float" + t.code = DataType::kFloat8_e5m2; + t.bits = 8; + scan = s.c_str() + 11; + char* endpt = nullptr; + if (*scan == 'x') { + t.lanes = static_cast(strtoul(scan + 1, &endpt, 10)); + scan = endpt; + } + ICHECK(scan == s.c_str() + s.length()) << "unknown type " << s; + return t; } else if (s.substr(0, 5) == "float") { t.code = kDLFloat; scan = s.c_str() + 5; @@ -484,14 +511,6 @@ inline DLDataType String2DLDataType(std::string s) { t.code = DataType::kBFloat; t.bits = 16; scan = s.c_str() + 6; - } else if (s.substr(0, 10) == "e4m3_float") { - t.code = DataType::kE4M3Float; - t.bits = 8; - scan = s.c_str() + 10; - } else if (s.substr(0, 10) == "e5m2_float") { - t.code = DataType::kE5M2Float; - t.bits = 8; - scan = s.c_str() + 10; } else if (s.substr(0, 6) == "custom") { t.code = ParseCustomDatatype(s, &scan); } else { diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h index e78e0d51fdc1..e60a3859acf5 100644 --- a/include/tvm/script/ir_builder/tir/ir.h +++ b/include/tvm/script/ir_builder/tir/ir.h @@ -502,10 +502,10 @@ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Int, DataType::Int); TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x32, FDType(32)); \ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x64, FDType(64)); -TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(E4M3Float8, DataType::NVFloat8E4M3); -TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(E5M2Float8, DataType::NVFloat8E5M2); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E4M3FN, DataType::NVFloat8E4M3); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E5M2, DataType::NVFloat8E5M2); -TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float4E2M1fn, 
DataType::NVFloat4E2M1FN); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float4E2M1FN, DataType::NVFloat4E2M1FN); TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Boolean, DataType::Bool()); TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Void, DataType::Void()); diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index 3f4ceadd1d20..317bd6bead7c 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -66,9 +66,9 @@ class DataTypeCode(object): FLOAT = 2 HANDLE = 3 BFLOAT = 4 - E4M3Float = 6 - E5M2Float = 7 - FLOAT4E2M1FN = 8 + Float8E4M3FN = 6 + Float8E5M2 = 7 + Float4E2M1FN = 8 class DataType(ctypes.Structure): @@ -81,9 +81,9 @@ class DataType(ctypes.Structure): DataTypeCode.FLOAT: "float", DataTypeCode.HANDLE: "handle", DataTypeCode.BFLOAT: "bfloat", - DataTypeCode.E4M3Float: "e4m3_float", - DataTypeCode.E5M2Float: "e5m2_float", - DataTypeCode.FLOAT4E2M1FN: "float4_e2m1fn", + DataTypeCode.Float8E4M3FN: "float8_e4m3fn", + DataTypeCode.Float8E5M2: "float8_e5m2", + DataTypeCode.Float4E2M1FN: "float4_e2m1fn", } NUMPY2STR = { np.dtype(np.bool_): "bool", @@ -112,9 +112,9 @@ class DataType(ctypes.Structure): "uint16": {"type_code": DataTypeCode.UINT, "bits": 16, "lanes": 1}, "uint32": {"type_code": DataTypeCode.UINT, "bits": 32, "lanes": 1}, "uint64": {"type_code": DataTypeCode.UINT, "bits": 64, "lanes": 1}, - "e4m3_float8": {"type_code": DataTypeCode.E4M3Float, "bits": 8, "lanes": 1}, - "e5m2_float8": {"type_code": DataTypeCode.E5M2Float, "bits": 8, "lanes": 1}, - "float4_e2m1fn": {"type_code": DataTypeCode.FLOAT4E2M1FN, "bits": 4, "lanes": 1}, + "float8_e4m3fn": {"type_code": DataTypeCode.Float8E4M3FN, "bits": 8, "lanes": 1}, + "float8_e5m2": {"type_code": DataTypeCode.Float8E5M2, "bits": 8, "lanes": 1}, + "float4_e2m1fn": {"type_code": DataTypeCode.Float4E2M1FN, "bits": 4, "lanes": 1}, "float16": {"type_code": DataTypeCode.FLOAT, "bits": 16, "lanes": 1}, "float32": {"type_code": DataTypeCode.FLOAT, "bits": 32, "lanes": 1}, "float64": {"type_code": DataTypeCode.FLOAT, "bits": 64, "lanes": 1}, @@ -157,9 +157,17 @@ def __init__(self, type_str): head = head[4:] elif head.startswith("float4_e2m1fn"): # Avoid being treated as "float" - self.type_code = DataTypeCode.FLOAT4E2M1FN + self.type_code = DataTypeCode.Float4E2M1FN bits = 4 head = "" + elif head.startswith("float8_e4m3fn"): + self.type_code = DataTypeCode.Float8E4M3FN + bits = 8 + head = "" + elif head.startswith("float8_e5m2"): + self.type_code = DataTypeCode.Float8E5M2 + bits = 8 + head = "" elif head.startswith("float"): self.type_code = DataTypeCode.FLOAT head = head[5:] @@ -170,12 +178,6 @@ def __init__(self, type_str): elif head.startswith("bfloat"): self.type_code = DataTypeCode.BFLOAT head = head[6:] - elif head.startswith("e4m3_float"): - self.type_code = DataTypeCode.E4M3Float - head = head[10:] - elif head.startswith("e5m2_float"): - self.type_code = DataTypeCode.E5M2Float - head = head[10:] elif head.startswith("custom"): # pylint: disable=import-outside-toplevel import tvm.runtime._ffi_api @@ -204,7 +206,9 @@ def __repr__(self): type_name = "custom[%s]" % tvm.runtime._ffi_api._datatype_get_type_name(self.type_code) if self.type_code in [ - DataTypeCode.FLOAT4E2M1FN, + DataTypeCode.Float8E4M3FN, + DataTypeCode.Float8E5M2, + DataTypeCode.Float4E2M1FN, ]: x = type_name else: @@ -243,8 +247,8 @@ def itemsize(self): if ml_dtypes is not None: DataType.NUMPY2STR[np.dtype(ml_dtypes.bfloat16)] = "bfloat16" - DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e4m3fn)] = "e4m3_float8" - 
DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e5m2)] = "e5m2_float8" + DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e4m3fn)] = "float8_e4m3fn" + DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e5m2)] = "float8_e5m2" DataType.NUMPY2STR[np.dtype(ml_dtypes.float4_e2m1fn)] = "float4_e2m1fn" RPC_SESS_MASK = 128 diff --git a/python/tvm/contrib/tvmjs.py b/python/tvm/contrib/tvmjs.py index 9bff724df7bc..5a5a3a1c80ac 100644 --- a/python/tvm/contrib/tvmjs.py +++ b/python/tvm/contrib/tvmjs.py @@ -369,19 +369,19 @@ def load_ndarray_cache(cachepath: str, device: tvm.runtime.Device): arr = tvm.nd.empty(shape, dtype, device=device) assert offset + nbytes <= len(raw_data) buffer_source = raw_data[offset : offset + nbytes] - if dtype == "e4m3_float8": + if dtype == "float8_e4m3fn": if ml_dtypes is not None: dtype = ml_dtypes.float8_e4m3fn else: raise RuntimeError( - "ml_dtypes is not installed, cannot convert e4m3_float8 array to numpy." + "ml_dtypes is not installed, cannot convert float8_e4m3fn array to numpy." ) - if dtype == "e5m2_float8": + if dtype == "float8_e5m2": if ml_dtypes is not None: dtype = ml_dtypes.float8_e5m2 else: raise RuntimeError( - "ml_dtypes is not installed, cannot convert e5m2_float8 array to numpy." + "ml_dtypes is not installed, cannot convert float8_e5m2 array to numpy." ) if encode_format == "f32-to-bf16" and dtype == "float32": data = np.frombuffer(buffer_source, dtype="uint16").reshape(shape) diff --git a/python/tvm/relax/backend/cuda/cublas.py b/python/tvm/relax/backend/cuda/cublas.py index 287b18b4409a..6828381e68e1 100644 --- a/python/tvm/relax/backend/cuda/cublas.py +++ b/python/tvm/relax/backend/cuda/cublas.py @@ -27,18 +27,18 @@ from ..pattern_registry import get_patterns_with_prefix, register_patterns from ..patterns import ( - make_matmul_pattern, make_matmul_dequantize_pattern, make_matmul_multiply_pattern, + make_matmul_pattern, ) from ..utils import has_leaking_intermediate_variables def _is_supported_dtype(lhs_dtype, rhs_dtype, out_dtype): """Check if dtypes in the given workload are supported by cuBLAS BYOC.""" - if lhs_dtype == "e4m3_float8" and rhs_dtype == "e4m3_float8": - # The output cannot be 'e5m2_float8' if inputs are 'e4m3_float8' - return out_dtype != "e5m2_float8" + if lhs_dtype == "float8_e4m3fn" and rhs_dtype == "float8_e4m3fn": + # The output cannot be 'float8_e5m2' if inputs are 'float8_e4m3fn' + return out_dtype != "float8_e5m2" return ( (lhs_dtype == "float16" and rhs_dtype == "float16") or (lhs_dtype == "float32" and rhs_dtype == "float32") @@ -83,7 +83,7 @@ def _check_matmul(context: PatternCheckContext) -> bool: if not isinstance(rhs_shape[-1], (tvm.tir.expr.IntImm, int)) or rhs_shape[-1] % 4 != 0: # Rows number must be multiples of 4 for IGEMM return False - elif lhs_dtype == "e4m3_float8" and rhs_dtype == "e4m3_float8": + elif lhs_dtype == "float8_e4m3fn" and rhs_dtype == "float8_e4m3fn": matmul_rhs_var = matmul_call.args[1] rhs_transposed = False if matmul_rhs_var in context.matched_bindings: diff --git a/python/tvm/relax/backend/rocm/hipblas.py b/python/tvm/relax/backend/rocm/hipblas.py index c0accc1473e1..63c72b660dc6 100644 --- a/python/tvm/relax/backend/rocm/hipblas.py +++ b/python/tvm/relax/backend/rocm/hipblas.py @@ -30,9 +30,9 @@ def _is_supported_dtype(lhs_dtype, rhs_dtype, out_dtype): # pylint: disable=unused-argument """Check if dtypes in the given workload are supported by hipblas BYOC.""" - if lhs_dtype == "e4m3_float8" and rhs_dtype == "e4m3_float8": - # The output cannot be 'e5m2_float8' if inputs are 'e4m3_float8' - # return 
out_dtype != "e5m2_float8" + if lhs_dtype == "float8_e4m3fn" and rhs_dtype == "float8_e4m3fn": + # The output cannot be 'float8_e5m2' if inputs are 'float8_e4m3fn' + # return out_dtype != "float8_e5m2" return False return (lhs_dtype == "float16" and rhs_dtype == "float16") or ( lhs_dtype == "int8" and rhs_dtype == "int8" @@ -61,7 +61,7 @@ def _check_matmul(context: PatternCheckContext) -> bool: if lhs_dtype == "int8" and rhs_dtype == "int8": return False - elif lhs_dtype == "e4m3_float8" and rhs_dtype == "e4m3_float8": + elif lhs_dtype == "float8_e4m3fn" and rhs_dtype == "float8_e4m3fn": return False lhs_batches = reduce(operator.mul, lhs_shape[:-2], 1) diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py index 47fcccf52bac..d55334a1545b 100644 --- a/python/tvm/runtime/ndarray.py +++ b/python/tvm/runtime/ndarray.py @@ -249,19 +249,19 @@ def numpy(self): dtype = "int8" if dtype == "bfloat16": dtype = "uint16" - if dtype == "e4m3_float8": + if dtype == "float8_e4m3fn": if ml_dtypes is not None: dtype = ml_dtypes.float8_e4m3fn else: raise RuntimeError( - "ml_dtypes is not installed, cannot convert e4m3_float8 array to numpy." + "ml_dtypes is not installed, cannot convert float8_e4m3fn array to numpy." ) - if dtype == "e5m2_float8": + if dtype == "float8_e5m2": if ml_dtypes is not None: dtype = ml_dtypes.float8_e5m2 else: raise RuntimeError( - "ml_dtypes is not installed, cannot convert e5m2_float8 array to numpy." + "ml_dtypes is not installed, cannot convert float8_e5m2 array to numpy." ) if dtype == "float4_e2m1fn": if ml_dtypes is not None: diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index c35df7a093ef..2fce022da365 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1442,27 +1442,27 @@ def func( float32x64 = func_gen(("Float32x64")) float64x64 = func_gen(("Float64x64")) -e4m3_float8 = func_gen(("E4M3Float8")) -e4m3_float8x4 = func_gen(("E4M3Float8x4")) -e4m3_float8x8 = func_gen(("E4M3Float8x8")) -e4m3_float8x16 = func_gen(("E4M3Float8x16")) -e4m3_float8x32 = func_gen(("E4M3Float8x32")) -e4m3_float8x64 = func_gen(("E4M3Float8x64")) - -e5m2_float8 = func_gen(("E5M2Float8")) -e5m2_float8x4 = func_gen(("E5M2Float8x4")) -e5m2_float8x8 = func_gen(("E5M2Float8x8")) -e5m2_float8x16 = func_gen(("E5M2Float8x16")) -e5m2_float8x32 = func_gen(("E5M2Float8x32")) -e5m2_float8x64 = func_gen(("E5M2Float8x64")) - -float4_e2m1fn = func_gen(("Float4E2M1fn")) -float4_e2m1fnx2 = func_gen(("Float4E2M1fnx2")) -float4_e2m1fnx4 = func_gen(("Float4E2M1fnx4")) -float4_e2m1fnx8 = func_gen(("Float4E2M1fnx8")) -float4_e2m1fnx16 = func_gen(("Float4E2M1fnx16")) -float4_e2m1fnx32 = func_gen(("Float4E2M1fnx32")) -float4_e2m1fnx64 = func_gen(("Float4E2M1fnx64")) +float8_e4m3fn = func_gen(("Float8E4M3FN")) +float8_e4m3fnx4 = func_gen(("Float8E4M3FNx4")) +float8_e4m3fnx8 = func_gen(("Float8E4M3FNx8")) +float8_e4m3fnx16 = func_gen(("Float8E4M3FNx16")) +float8_e4m3fnx32 = func_gen(("Float8E4M3FNx32")) +float8_e4m3fnx64 = func_gen(("Float8E4M3FNx64")) + +float8_e5m2 = func_gen(("Float8E5M2")) +float8_e5m2x4 = func_gen(("Float8E5M2x4")) +float8_e5m2x8 = func_gen(("Float8E5M2x8")) +float8_e5m2x16 = func_gen(("Float8E5M2x16")) +float8_e5m2x32 = func_gen(("Float8E5M2x32")) +float8_e5m2x64 = func_gen(("Float8E5M2x64")) + +float4_e2m1fn = func_gen(("Float4E2M1FN")) +float4_e2m1fnx2 = func_gen(("Float4E2M1FNx2")) +float4_e2m1fnx4 = func_gen(("Float4E2M1FNx4")) +float4_e2m1fnx8 = func_gen(("Float4E2M1FNx8")) 
+float4_e2m1fnx16 = func_gen(("Float4E2M1FNx16")) +float4_e2m1fnx32 = func_gen(("Float4E2M1FNx32")) +float4_e2m1fnx64 = func_gen(("Float4E2M1FNx64")) # pylint: enable=invalid-name @@ -2011,39 +2011,39 @@ def wrapped(*args, **kwargs): "uint16x64", "uint32x64", "uint64x64", - "e4m3_float8", - "e5m2_float8", + "float8_e4m3fn", + "float8_e5m2", "float4_e2m1fn", "float16", "float32", "float64", "float4_e2m1fnx2", - "e4m3_float8x4", - "e5m2_float8x4", + "float8_e4m3fnx4", + "float8_e5m2x4", "float4_e2m1fnx4", "float16x4", "float32x4", "float64x4", - "e4m3_float8x8", - "e5m2_float8x8", + "float8_e4m3fnx8", + "float8_e5m2x8", "float4_e2m1fnx8", "float16x8", "float32x8", "float64x8", - "e4m3_float8x16", - "e5m2_float8x16", + "float8_e4m3fnx16", + "float8_e5m2x16", "float4_e2m1fnx16", "float16x16", "float32x16", "float64x16", - "e4m3_float8x32", - "e5m2_float8x32", + "float8_e4m3fnx32", + "float8_e5m2x32", "float4_e2m1fnx32", "float16x32", "float32x32", "float64x32", - "e4m3_float8x64", - "e5m2_float8x64", + "float8_e4m3fnx64", + "float8_e5m2x64", "float4_e2m1fnx64", "float16x64", "float32x64", diff --git a/python/tvm/tir/tensor_intrin/cuda.py b/python/tvm/tir/tensor_intrin/cuda.py index e1ff18bc8fb9..57b1c3b873d7 100644 --- a/python/tvm/tir/tensor_intrin/cuda.py +++ b/python/tvm/tir/tensor_intrin/cuda.py @@ -16,13 +16,13 @@ # under the License. # pylint: disable=invalid-name,missing-function-docstring,unused-variable """Intrinsics for tensorization on NVIDIA GPU.""" -from typing import Dict, Optional, Tuple, Literal +from typing import Dict, Literal, Optional, Tuple from tvm._ffi import register_func from tvm.runtime import convert from tvm.script import tir as T -from tvm.tir.function import PrimFunc from tvm.tir import Cast, IntImm, TensorIntrin +from tvm.tir.function import PrimFunc def shared_16x16_to_ldmatrix_32x8_layout(i, j): @@ -123,7 +123,7 @@ def get_ldmatrix_intrin( matrix_name == "B" or not transposed ), "Now only B matrix can be transposed for int8 matmul" assert k_dim == 32 and ( - dtype == "int8" or dtype == "e4m3_float8" or dtype == "e5m2_float8" + dtype == "int8" or dtype == "float8_e4m3fn" or dtype == "float8_e5m2" ), "Only k_dim == 16 (float16) or k_dim == 32 (int8) supported for now" if matrix_name == "B" and not transposed: @@ -261,25 +261,25 @@ def ldmatrix_impl(warp_handle: T.handle, shared_handle: T.handle) -> None: TensorIntrin.register(LDMATRIX_i8_B_TRANS_INTRIN, *get_ldmatrix_intrin(32, "int8", "B", True)) LDMATRIX_e4m3_A_INTRIN = "mma_ldmatrix_e4m3_a" -TensorIntrin.register(LDMATRIX_e4m3_A_INTRIN, *get_ldmatrix_intrin(32, "e4m3_float8", "A", False)) +TensorIntrin.register(LDMATRIX_e4m3_A_INTRIN, *get_ldmatrix_intrin(32, "float8_e4m3fn", "A", False)) LDMATRIX_e4m3_B_INTRIN = "mma_ldmatrix_e4m3_b" -TensorIntrin.register(LDMATRIX_e4m3_B_INTRIN, *get_ldmatrix_intrin(32, "e4m3_float8", "B", False)) +TensorIntrin.register(LDMATRIX_e4m3_B_INTRIN, *get_ldmatrix_intrin(32, "float8_e4m3fn", "B", False)) LDMATRIX_e4m3_B_TRANS_INTRIN = "mma_ldmatrix_e4m3_b_trans" TensorIntrin.register( - LDMATRIX_e4m3_B_TRANS_INTRIN, *get_ldmatrix_intrin(32, "e4m3_float8", "B", True) + LDMATRIX_e4m3_B_TRANS_INTRIN, *get_ldmatrix_intrin(32, "float8_e4m3fn", "B", True) ) LDMATRIX_e5m2_A_INTRIN = "mma_ldmatrix_e5m2_a" -TensorIntrin.register(LDMATRIX_e5m2_A_INTRIN, *get_ldmatrix_intrin(32, "e5m2_float8", "A", False)) +TensorIntrin.register(LDMATRIX_e5m2_A_INTRIN, *get_ldmatrix_intrin(32, "float8_e5m2", "A", False)) LDMATRIX_e5m2_B_INTRIN = "mma_ldmatrix_e5m2_b" 
-TensorIntrin.register(LDMATRIX_e5m2_B_INTRIN, *get_ldmatrix_intrin(32, "e5m2_float8", "B", False)) +TensorIntrin.register(LDMATRIX_e5m2_B_INTRIN, *get_ldmatrix_intrin(32, "float8_e5m2", "B", False)) LDMATRIX_e5m2_B_TRANS_INTRIN = "mma_ldmatrix_e5m2_b_trans" TensorIntrin.register( - LDMATRIX_e5m2_B_TRANS_INTRIN, *get_ldmatrix_intrin(32, "e5m2_float8", "B", True) + LDMATRIX_e5m2_B_TRANS_INTRIN, *get_ldmatrix_intrin(32, "float8_e5m2", "B", True) ) @@ -315,8 +315,8 @@ def get_mma_intrin( "float32": "fp32", "int8": "int8", "int32": "int32", - "e4m3_float8": "e4m3", - "e5m2_float8": "e5m2", + "float8_e4m3fn": "e4m3", + "float8_e5m2": "e5m2", } a_dtype_abbrv = dtype_abbrv[a_dtype] b_dtype_abbrv = dtype_abbrv[b_dtype] @@ -522,25 +522,25 @@ def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None: MMA_e5m2e5m2f32_INTRIN = "mma_e5m2e5m2f32" TensorIntrin.register( MMA_e5m2e5m2f32_INTRIN, - *get_mma_intrin(32, "e5m2_float8", "e5m2_float8", "float32", False, False), + *get_mma_intrin(32, "float8_e5m2", "float8_e5m2", "float32", False, False), ) MMA_e5m2e5m2f32_TRANS_B_INTRIN = "mma_e5m2e5m2f32_trans_b" TensorIntrin.register( MMA_e5m2e5m2f32_TRANS_B_INTRIN, - *get_mma_intrin(32, "e5m2_float8", "e5m2_float8", "float32", False, True), + *get_mma_intrin(32, "float8_e5m2", "float8_e5m2", "float32", False, True), ) MMA_e4m3e4m3f32_INTRIN = "mma_e4m3e4m3f32" TensorIntrin.register( MMA_e4m3e4m3f32_INTRIN, - *get_mma_intrin(32, "e4m3_float8", "e4m3_float8", "float32", False, False), + *get_mma_intrin(32, "float8_e4m3fn", "float8_e4m3fn", "float32", False, False), ) MMA_e4m3e4m3f32_TRANS_B_INTRIN = "mma_e4m3e4m3f32_trans_b" TensorIntrin.register( MMA_e4m3e4m3f32_TRANS_B_INTRIN, - *get_mma_intrin(32, "e4m3_float8", "e4m3_float8", "float32", False, True), + *get_mma_intrin(32, "float8_e4m3fn", "float8_e4m3fn", "float32", False, True), ) @@ -705,7 +705,7 @@ def mma_store_impl(a: T.handle, c: T.handle) -> None: def get_mma_intrin_group( load_scope: Literal["shared", "shared.dyn"], store_scope: Literal["global", "shared", "shared.dyn"], - in_dtype: Literal["float16", "int8", "e4m3_float8", "e5m2_float8"], + in_dtype: Literal["float16", "int8", "float8_e4m3fn", "float8_e5m2"], out_dtype: Literal["float16", "float32", "int32"], trans_a: bool, trans_b: bool, @@ -752,7 +752,7 @@ def get_mma_intrin_group( """ assert load_scope in ["shared", "shared.dyn"] assert store_scope in ["global", "shared", "shared.dyn"] - assert in_dtype in ["float16", "int8", "e4m3_float8", "e5m2_float8"] + assert in_dtype in ["float16", "int8", "float8_e4m3fn", "float8_e5m2"] assert out_dtype in ["float16", "float32", "int32"] shape = "16x16" @@ -761,8 +761,8 @@ def get_mma_intrin_group( "float16": "f16", "float32": "f32", "int8": "i8", - "e4m3_float8": "e4m3", - "e5m2_float8": "e5m2", + "float8_e4m3fn": "e4m3", + "float8_e5m2": "e5m2", "int32": "i32", } a_dtype = dtype_mapping[in_dtype] diff --git a/src/ir/expr.cc b/src/ir/expr.cc index 766abf3483c7..8f188e95f013 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -132,15 +132,16 @@ FloatImm::FloatImm(DataType dtype, double value, Span span) { ICHECK_LE(value, support::kMaxBFloat16) << "ValueError: Literal value " << value << " exceeds maximum of " << dtype; } else if (dtype.is_float8()) { - double bound = (dtype.code() == DataType::kE4M3Float) ? support::kMaxE4M3 : support::kMaxE5M2; + double bound = + (dtype.code() == DataType::kFloat8_e4m3fn) ? 
support::kMaxE4M3FN : support::kMaxE5M2; ICHECK_GE(value, -bound) << "ValueError: Literal value " << value << " exceeds minimum of " << dtype; ICHECK_LE(value, bound) << "ValueError: Literal vaule " << value << " exceeds maximum of " << dtype; } else if (dtype.is_float4()) { - ICHECK_GE(value, -support::kMaxE2M1) + ICHECK_GE(value, -support::kMaxE2M1FN) << "ValueError: Literal value " << value << " exceeds minimum of " << dtype; - ICHECK_LE(value, support::kMaxE2M1) + ICHECK_LE(value, support::kMaxE2M1FN) << "ValueError: Literal value " << value << " exceeds maximum of " << dtype; } } diff --git a/src/relax/backend/contrib/utils.h b/src/relax/backend/contrib/utils.h index aa3928ce026a..e63e99548cc8 100644 --- a/src/relax/backend/contrib/utils.h +++ b/src/relax/backend/contrib/utils.h @@ -72,10 +72,12 @@ inline std::string DType2String(const tvm::DataType dtype) { std::ostringstream os; if (dtype.is_float()) { os << "float"; - } else if (dtype.is_e4m3_float8()) { - os << "e4m3_float"; - } else if (dtype.is_e5m2_float8()) { - os << "e5m2_float"; + } else if (dtype.is_float8_e4m3fn()) { + return "float8_e4m3fn"; + } else if (dtype.is_float8_e5m2()) { + return "float8_e5m2"; + } else if (dtype.is_float4_e2m1fn()) { + return "float4_e2m1fn"; } else if (dtype.is_int()) { os << "int"; } else if (dtype.is_uint()) { diff --git a/src/runtime/contrib/cublas/cublas.cc b/src/runtime/contrib/cublas/cublas.cc index c9a01fc24e06..ba01f791d98a 100644 --- a/src/runtime/contrib/cublas/cublas.cc +++ b/src/runtime/contrib/cublas/cublas.cc @@ -164,8 +164,8 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream, ab_type = CUDA_R_16F; } else if (TypeMatch(A->dtype, kDLInt, 8)) { ab_type = CUDA_R_8I; - } else if (TypeMatch(A->dtype, DataType::TypeCode::kE4M3Float, 8)) { - ICHECK(TypeMatch(B->dtype, DataType::TypeCode::kE4M3Float, 8)); + } else if (TypeMatch(A->dtype, DataType::TypeCode::kFloat8_e4m3fn, 8)) { + ICHECK(TypeMatch(B->dtype, DataType::TypeCode::kFloat8_e4m3fn, 8)); ab_type = CUDA_R_8F_E4M3; } @@ -217,7 +217,6 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream, int N = RowCount(A, transa, batch_offset_A); int K = ColumnCount(A, transa, batch_offset_A); bool use_batched_gemm = A->ndim > 2 || B->ndim > 2; - // If A is batched but B is not, flatten all non-reduction axes of A to use the regular GEMM. // This trick is only applicable if batch axes and the other spatial axis (M or N) are // adjacent in both the input and the output matrix. 
In particular, if A is of shape (M, K) diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc index d87606532509..5a328413a148 100644 --- a/src/runtime/ndarray.cc +++ b/src/runtime/ndarray.cc @@ -54,7 +54,7 @@ inline void VerifyDataType(DLDataType dtype) { return; else if (dtype.bits == 4 && dtype.code == kDLInt) return; - else if (dtype.bits == 4 && dtype.code == DataType::kFloat4E2M1Fn) + else if (dtype.bits == 4 && dtype.code == DataType::kFloat4_e2m1fn) return; else ICHECK_EQ(dtype.bits % 8, 0); diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index a73c9cb5b4ce..a75a35781001 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -752,13 +752,13 @@ TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.Float", Float); TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.UInt", UInt); TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.Int", Int); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.E4M3Float8").set_body_typed(E4M3Float8); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.E5M2Float8").set_body_typed(E5M2Float8); -TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.E4M3Float8", E4M3Float8); -TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.E5M2Float8", E5M2Float8); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Float8E4M3FN").set_body_typed(Float8E4M3FN); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Float8E5M2").set_body_typed(Float8E5M2); +TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float8E4M3FN", Float8E4M3FN); +TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float8E5M2", Float8E5M2); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Float4E2M1fn").set_body_typed(Float4E2M1fn); -TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float4E2M1fn", Float4E2M1fn); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Float4E2M1FN").set_body_typed(Float4E2M1FN); +TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float4E2M1FN", Float4E2M1FN); TVM_REGISTER_GLOBAL("script.ir_builder.tir.Boolean").set_body_typed(Boolean); TVM_REGISTER_GLOBAL("script.ir_builder.tir.Handle").set_body_typed(Handle); diff --git a/src/support/scalars.h b/src/support/scalars.h index b229a6b3380f..adc449ffd682 100644 --- a/src/support/scalars.h +++ b/src/support/scalars.h @@ -63,14 +63,14 @@ constexpr double kMaxBFloat16 = 3.895313892515354759047080037148786688e38; // 2^8 * (1 + 6/8) // See https://arxiv.org/pdf/2209.05433.pdf -constexpr double kMaxE4M3 = 448; +constexpr double kMaxE4M3FN = 448; // 2^15 * (1 + 3/4) // See https://arxiv.org/pdf/2209.05433.pdf constexpr double kMaxE5M2 = 57344; // 2^2 * (1 + 1/2) -constexpr double kMaxE2M1 = 6.0; +constexpr double kMaxE2M1FN = 6.0; } // namespace support } // namespace tvm diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 9c2ce0bbb26b..ead0bdff3c0f 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -579,9 +579,9 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const { default: LOG(FATAL) << "do not support " << dtype; } - } else if (dtype.code() == DataType::kE4M3Float || dtype.code() == DataType::kE5M2Float) { + } else if (dtype.code() == DataType::kFloat8_e4m3fn || dtype.code() == DataType::kFloat8_e5m2) { etype = llvm::Type::getInt8Ty(*ctx); - } else if (dtype.code() == DataType::kFloat4E2M1Fn) { + } else if (dtype.code() == DataType::kFloat4_e2m1fn) { etype = llvm::Type::getIntNTy(*ctx, 4); } if (!dtype.is_scalar()) { diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 
20b29750dc1b..35973776c818 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -60,9 +60,9 @@ std::string GetFP8Type(DataType type) { } stream << "__nv_fp8"; std::string suffix; - if (type.code() == DataType::kE4M3Float) { + if (type.code() == DataType::kFloat8_e4m3fn) { suffix = "_e4m3"; - } else if (type.code() == DataType::kE5M2Float) { + } else if (type.code() == DataType::kFloat8_e5m2) { suffix = "_e5m2"; } else { LOG(FATAL) << "Unsupported FP8 type in CUDA codegen"; @@ -86,7 +86,7 @@ std::string GetFP4Type(DataType type) { } stream << "__nv_fp4"; std::string suffix; - if (type.code() == DataType::kFloat4E2M1Fn) { + if (type.code() == DataType::kFloat4_e2m1fn) { suffix = "_e2m1"; } else { LOG(FATAL) << "Unsupported FP8 type in CUDA codegen"; } @@ -159,9 +159,7 @@ std::string CodeGenCUDA::Finish() { decl_stream << "#endif\n\n"; decl_stream << "#include \n"; - decl_stream << "#if (CUDA_VERSION <12080)\n"; decl_stream << _cuda_half_util; - decl_stream << "#endif\n"; } if (enable_bf16_) { @@ -734,9 +732,9 @@ void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) { // Emit simple C-style type conversion. if (from_ty.is_scalar()) return CodeGenC::VisitExpr_(op, os); - if (target_ty.code() == DataType::kE4M3Float || target_ty.code() == DataType::kE5M2Float || - target_ty.code() == DataType::kFloat4E2M1Fn || from_ty.code() == DataType::kE4M3Float || - from_ty.code() == DataType::kE5M2Float || from_ty.code() == DataType::kFloat4E2M1Fn) { + if (target_ty.code() == DataType::kFloat8_e4m3fn || target_ty.code() == DataType::kFloat8_e5m2 || + target_ty.code() == DataType::kFloat4_e2m1fn || from_ty.code() == DataType::kFloat8_e4m3fn || + from_ty.code() == DataType::kFloat8_e5m2 || from_ty.code() == DataType::kFloat4_e2m1fn) { std::ostringstream val; val << "("; PrintType(target_ty, val); @@ -1508,7 +1506,7 @@ inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p) os << '(' << std::scientific << op->value << 'f' << ')'; return; } - // Type code is kE5M2Float or kE4M4Float + // Type code is kFloat8_e5m2 or kFloat8_e4m3fn if (op->dtype.is_float8() || op->dtype.is_float4()) { p->PrintType(op->dtype, os); os << '(' << std::scientific << op->value << 'f' << ')'; diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index 3dab634f162e..63c82d1d6c11 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -273,9 +273,9 @@ PrimExpr max_value(const DataType& dtype, Span span) { return FloatImm(dtype, std::numeric_limits::max(), span); } else if (dtype.is_float8()) { // according to https://arxiv.org/pdf/2209.05433.pdf - if (dtype.code() == DataType::TypeCode::kE5M2Float) { + if (dtype.code() == DataType::TypeCode::kFloat8_e5m2) { return FloatImm(dtype, 57344.0, span); - } else if (dtype.code() == DataType::TypeCode::kE4M3Float) { + } else if (dtype.code() == DataType::TypeCode::kFloat8_e4m3fn) { return FloatImm(dtype, 448.0, span); } } else if (dtype.is_float4()) { @@ -316,9 +316,9 @@ PrimExpr min_value(const DataType& dtype, Span span) { return FloatImm(dtype, std::numeric_limits::lowest(), span); } else if (dtype.is_float8()) { // according to https://arxiv.org/pdf/2209.05433.pdf - if (dtype.code() == DataType::TypeCode::kE5M2Float) { + if (dtype.code() == DataType::TypeCode::kFloat8_e5m2) { return FloatImm(dtype, -57344.0, span); - } else if (dtype.code() == DataType::TypeCode::kE4M3Float) { + } else if (dtype.code() == DataType::TypeCode::kFloat8_e4m3fn) { return FloatImm(dtype, -448.0, span); } } else if (dtype.is_float4()) { diff --git 
a/src/tir/transforms/dtype_conversion.h b/src/tir/transforms/dtype_conversion.h index 8edbf1bc1ebe..a0ed6b5f6d86 100644 --- a/src/tir/transforms/dtype_conversion.h +++ b/src/tir/transforms/dtype_conversion.h @@ -121,7 +121,7 @@ class FloatConfig { // NVIDIA/Arm/Intel's FP8 formats for Deep Learning // Reference: https://arxiv.org/abs/2209.05433 switch (dtype.code()) { - case DataType::kE4M3Float: + case DataType::kFloat8_e4m3fn: // E4M3 format, not consistent with IEEE-754 return FloatConfig(4, 3, 7, InftyStyle::kNone, NaNStyle::kAllOnes); default: diff --git a/tests/python/codegen/test_target_codegen_cuda_fp8.py b/tests/python/codegen/test_target_codegen_cuda_fp8.py index d94153003c6a..b91efd61922f 100644 --- a/tests/python/codegen/test_target_codegen_cuda_fp8.py +++ b/tests/python/codegen/test_target_codegen_cuda_fp8.py @@ -16,25 +16,24 @@ # under the License. import sys +from typing import List, Tuple + +import numpy as np import pytest import tvm -from tvm.script import tir as T -import numpy as np import tvm.testing - - -from typing import List, Tuple from tvm import DataType, DataTypeCode, IRModule from tvm import dlight as dl from tvm import relax, te, tir, topi from tvm.relax.frontend import nn from tvm.runtime import NDArray +from tvm.script import ir as I +from tvm.script import relax as R +from tvm.script import tir as T from tvm.target import Target from tvm.topi.utils import get_const_tuple -from tvm.script import ir as I, relax as R, tir as T - try: import ml_dtypes except ImportError: @@ -43,7 +42,7 @@ @tvm.testing.requires_cuda_compute_version(8, 9) def test_e4m3_conversions(): - dtype = "e4m3_float8" + dtype = "float8_e4m3fn" @T.prim_func def add( @@ -90,7 +89,7 @@ def add( def test_e4m3_packing(): length = 64 vector_length = 4 - native_dtype, packed_dtype = ("e4m3_float8x4", "uint32") + native_dtype, packed_dtype = ("float8_e4m3fnx4", "uint32") @T.prim_func def add( @@ -141,13 +140,13 @@ def add( native_dtype, promoted_dtype = tvm.testing.parameters( - ("e4m3_float8", "float32"), - ("e4m3_float8", "float16"), - ("e4m3_float8x2", "float32x2"), - ("e4m3_float8x2", "float16x2"), - ("e4m3_float8x4", "float32x4"), + ("float8_e4m3fn", "float32"), + ("float8_e4m3fn", "float16"), + ("float8_e4m3fnx2", "float32x2"), + ("float8_e4m3fnx2", "float16x2"), + ("float8_e4m3fnx4", "float32x4"), # Supported via half4 vector type extension in codegen - ("e4m3_float8x4", "float16x4"), + ("float8_e4m3fnx4", "float16x4"), ) @@ -343,7 +342,7 @@ def create_quantize_func( axis, output_transpose, ) -> IRModule: - if DataType(quantize_dtype).type_code == DataTypeCode.E4M3Float: + if DataType(quantize_dtype).type_code == DataTypeCode.Float8E4M3FN: quantize_func = cls.quantize_fp8x4_e4m3 else: assert NotImplementedError() @@ -387,7 +386,7 @@ def create_dequantize_func( num_elem_per_storage, axis, ) -> IRModule: - if DataType(quantize_dtype).type_code == DataTypeCode.E4M3Float: + if DataType(quantize_dtype).type_code == DataTypeCode.Float8E4M3FN: dequantize_func = cls.dequantize_fp8x4_e4m3 else: assert NotImplementedError() @@ -732,7 +731,7 @@ def storage_dtype(self): @tvm.testing.fixture def quantize_dtype(self): - return "e4m3_float8" + return "float8_e4m3fn" @tvm.testing.fixture def num_el_per_storage(self): @@ -807,7 +806,7 @@ def test_main(self, weight_shape, model_dtype, target_str, compiled_functions): @tvm.testing.requires_cuda_compute_version(8, 9) -@pytest.mark.parametrize("dtype", ["e5m2_float8", "e4m3_float8"]) +@pytest.mark.parametrize("dtype", ["float8_e5m2", "float8_e4m3fn"]) def 
test_const(dtype): @T.prim_func def func(A: T.Buffer((4,), dtype)) -> None: @@ -822,7 +821,7 @@ def func(A: T.Buffer((4,), dtype)) -> None: @tvm.testing.requires_cuda_compute_version(8, 9) -@pytest.mark.parametrize("dtype", ["e5m2_float8", "e4m3_float8"]) +@pytest.mark.parametrize("dtype", ["float8_e5m2", "float8_e4m3fn"]) @pytest.mark.parametrize("vec_len", [2, 4, 8, 16]) def test_copy(dtype, vec_len): @T.prim_func @@ -867,7 +866,7 @@ class SingleBatchMoE_float8_e4m3: @T.prim_func(private=True) def moe_dequantize_gemv( x_handle: T.handle, - w: T.Buffer((num_experts, spatial_size, reduce_size), "e4m3_float8"), + w: T.Buffer((num_experts, spatial_size, reduce_size), "float8_e4m3fn"), scale: T.Buffer((1,), "float16"), indptr: T.Buffer((1, 2), "int32"), o: T.Buffer((2, spatial_size), "float16"), @@ -905,7 +904,7 @@ def moe_dequantize_gemv( def main( x: R.Tensor(("num_seq", reduce_size), dtype="float16"), indptr: R.Tensor((1, 2), dtype="int32"), - weight: R.Tensor((num_experts, spatial_size, reduce_size), dtype="e4m3_float8"), + weight: R.Tensor((num_experts, spatial_size, reduce_size), dtype="float8_e4m3fn"), scale: R.Tensor((1,), dtype="float32"), ) -> R.Tensor((2, spatial_size), dtype="float16"): num_seq = T.int64() @@ -965,4 +964,5 @@ def _pipeline(mod: tvm.ir.IRModule) -> tvm.ir.IRModule: if __name__ == "__main__": + # test_half_broadcast(6) tvm.testing.main() diff --git a/tests/python/ir/test_datatype_nv_fp8.py b/tests/python/ir/test_datatype_nv_fp8.py index 8313a97ee138..b812c70ab687 100644 --- a/tests/python/ir/test_datatype_nv_fp8.py +++ b/tests/python/ir/test_datatype_nv_fp8.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. import numpy as np + import tvm import tvm.testing import tvm.tir as tir @@ -22,9 +23,10 @@ from tvm.script import tir as T try: - from ml_dtypes import float8_e4m3fn as e4m3_float8, float8_e5m2 as e5m2_float8 + from ml_dtypes import float8_e4m3fn as float8_e4m3fn + from ml_dtypes import float8_e5m2 as float8_e5m2 except ImportError: - e4m3_float8, e5m2_float8 = None, None + float8_e4m3fn, float8_e5m2 = None, None def fp8_unary(dtype: str): @@ -58,7 +60,7 @@ def func( np_dtype, dtype_str = tvm.testing.parameters( - (e4m3_float8, "e4m3_float8"), (e5m2_float8, "e5m2_float8") + (float8_e4m3fn, "float8_e4m3fn"), (float8_e5m2, "float8_e5m2") ) diff --git a/tests/python/ir/test_dtype.py b/tests/python/ir/test_dtype.py index 77cd1d7e4b5f..988e360748a6 100644 --- a/tests/python/ir/test_dtype.py +++ b/tests/python/ir/test_dtype.py @@ -15,22 +15,23 @@ # specific language governing permissions and limitations # under the License. 
"""Test data type related API""" +import pytest + import tvm -from tvm import DataType import tvm.testing -import pytest +from tvm import DataType @pytest.mark.parametrize( "dtype_str, expected_size", - [("float32", 4), ("float32x4", 16), ("e5m2_float8x4", 4), ("uint8", 1)], + [("float32", 4), ("float32x4", 16), ("float8_e5m2x4", 4), ("uint8", 1)], ) def test_dtype_itemsize(dtype_str, expected_size): dtype = DataType(dtype_str) assert dtype.itemsize() == expected_size -@pytest.mark.parametrize("dtype_str", [("int32xvscalex4")]) +@pytest.mark.parametrize("dtype_str", ["int32xvscalex4"]) def test_dtype_itemmize_error(dtype_str): with pytest.raises(ValueError): size = DataType(dtype_str).itemsize() diff --git a/tests/python/relax/test_codegen_cublas.py b/tests/python/relax/test_codegen_cublas.py index 2fbff8433bf7..c5514e272709 100644 --- a/tests/python/relax/test_codegen_cublas.py +++ b/tests/python/relax/test_codegen_cublas.py @@ -315,7 +315,7 @@ def test_matmul_fp8_offload( transpose_y, out_dtype, ): - in_dtype = "e4m3_float8" + in_dtype = "float8_e4m3fn" mod = get_relax_matmul_module( x_shape, y_shape, @@ -342,7 +342,7 @@ def test_matmul_fp8_offload( def test_matmul_fp8_dequantize_offload(): x_shape = (10, 32) y_shape = (64, 32) - in_dtype = "e4m3_float8" + in_dtype = "float8_e4m3fn" mod = get_relax_matmul_dequantize_module( x_shape, y_shape, @@ -369,7 +369,7 @@ def test_matmul_fp8_multiply_offload(): x_shape = (10, 32) y_shape = (64, 32) z_shape = (1,) - in_dtype, acc_dtype = ("e4m3_float8", "float32") + in_dtype, acc_dtype = ("float8_e4m3fn", "float32") mod = get_relax_matmul_multiply_module( x_shape, @@ -397,8 +397,8 @@ def test_matmul_fp8_multiply_offload(): "M, N, K, out_dtype, transposed_y, partition_done", [ (15, 64, 32, "float32", True, True), - (15, 64, 32, "e4m3_float8", True, True), - (15, 64, 32, "e5m2_float8", True, False), + (15, 64, 32, "float8_e4m3fn", True, True), + (15, 64, 32, "float8_e5m2", True, False), (16, 32, 60, "float32", True, False), (16, 30, 64, "float32", True, False), (16, 8, 16, "float16", True, True), @@ -407,7 +407,7 @@ def test_matmul_fp8_multiply_offload(): ) def test_cublas_partition_fp8_matmul(M, N, K, out_dtype, transposed_y, partition_done): mod = get_relax_matmul_module( - (M, K), (N, K), "e4m3_float8", out_dtype, transposed_y=transposed_y + (M, K), (N, K), "float8_e4m3fn", out_dtype, transposed_y=transposed_y ) mod = partition_for_cublas(mod) func_name = "relax_matmul_cublas" if partition_done else "R.matmul" @@ -426,7 +426,7 @@ def test_cublas_partition_fp8_matmul_dequantize(M, N, K, scale, zp, num_bindings mod = get_relax_matmul_dequantize_module( (M, K), (N, K), - "e4m3_float8", + "float8_e4m3fn", "float16", transposed_y=True, scale_const=scale, @@ -443,7 +443,7 @@ def test_cublas_partition_fp8_matmul_multiply(): (M, K), (N, K), (1,), - "e4m3_float8", + "float8_e4m3fn", "float32", "float16", transposed_y=True, diff --git a/tests/python/relax/test_op_inspect.py b/tests/python/relax/test_op_inspect.py index 18d7a88f051a..ca4b0fc440be 100644 --- a/tests/python/relax/test_op_inspect.py +++ b/tests/python/relax/test_op_inspect.py @@ -21,10 +21,10 @@ import pytest import tvm.testing - from tvm import relax from tvm.ir import Op -from tvm.script import ir as I, relax as R +from tvm.script import ir as I +from tvm.script import relax as R # Parameterization for reading dtype of DLTensor. Chosen to have # multiple distinct type codes, number of lanes, and widths. 
@@ -34,7 +34,7 @@ "float32", "float32x4", "bfloat", - "e4m3_float8", + "float8_e4m3fn", ) shape = tvm.testing.parameter( [], diff --git a/tests/python/relax/test_op_qdq.py b/tests/python/relax/test_op_qdq.py index 8b2d49904166..d773a6c7d28a 100644 --- a/tests/python/relax/test_op_qdq.py +++ b/tests/python/relax/test_op_qdq.py @@ -68,17 +68,17 @@ def test_qdq_op_infer_struct_info_symbolic(): ) -def test_qdq_e4m3_float8_op_infer_struct_info_symbolic(): +def test_qdq_float8_e4m3fn_op_infer_struct_info_symbolic(): bb = relax.BlockBuilder() n = tir.Var("n", "int64") x = relax.Var("x", R.Tensor((n, 3), "float32")) - dx = relax.Var("dx", R.Tensor((n, 3), "e4m3_float8")) + dx = relax.Var("dx", R.Tensor((n, 3), "float8_e4m3fn")) s = relax.Var("s", R.Tensor([3], "float32")) zp = relax.Var("zp", R.Tensor([3], "float16")) _check_inference( bb, - relax.op.quantize(x, s, zp, 1, "e4m3_float8"), - relax.TensorStructInfo((n, 3), "e4m3_float8"), + relax.op.quantize(x, s, zp, 1, "float8_e4m3fn"), + relax.TensorStructInfo((n, 3), "float8_e4m3fn"), ) _check_inference( bb, @@ -87,8 +87,8 @@ def test_qdq_e4m3_float8_op_infer_struct_info_symbolic(): ) -def test_qdq_e5m2_float8_op_infer_struct_info_symbolic(): - dtype = "e5m2_float8" +def test_qdq_float8_e5m2_op_infer_struct_info_symbolic(): + dtype = "float8_e5m2" bb = relax.BlockBuilder() n = tir.Var("n", "int64") x = relax.Var("x", R.Tensor((n, 3), "float32")) diff --git a/tests/python/tir-schedule/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py b/tests/python/tir-schedule/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py index fe9998bc798e..5a80e3e4f6c4 100644 --- a/tests/python/tir-schedule/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py +++ b/tests/python/tir-schedule/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py @@ -17,21 +17,24 @@ # pylint: disable=missing-docstring import numpy as np import pytest + import tvm import tvm.testing from tvm import te from tvm.testing.tir import mma_schedule from tvm.tir.tensor_intrin.cuda import ( + LDMATRIX_e4m3_A_INTRIN, + LDMATRIX_e4m3_B_TRANS_INTRIN, + LDMATRIX_e5m2_A_INTRIN, + LDMATRIX_e5m2_B_TRANS_INTRIN, LDMATRIX_f16_A_INTRIN, LDMATRIX_f16_B_INTRIN, LDMATRIX_f16_B_TRANS_INTRIN, LDMATRIX_i8_A_INTRIN, - LDMATRIX_i8_B_TRANS_INTRIN, LDMATRIX_i8_B_INTRIN, - LDMATRIX_e4m3_A_INTRIN, - LDMATRIX_e4m3_B_TRANS_INTRIN, - LDMATRIX_e5m2_A_INTRIN, - LDMATRIX_e5m2_B_TRANS_INTRIN, + LDMATRIX_i8_B_TRANS_INTRIN, + MMA_e4m3e4m3f32_TRANS_B_INTRIN, + MMA_e5m2e5m2f32_TRANS_B_INTRIN, MMA_f16f16f16_INTRIN, MMA_f16f16f16_TRANS_B_INTRIN, MMA_f16f16f32_INTRIN, @@ -41,8 +44,6 @@ MMA_fill_16x16_i32_INTRIN, MMA_i8i8i32_INTRIN, MMA_i8i8i32_TRANS_B_INTRIN, - MMA_e5m2e5m2f32_TRANS_B_INTRIN, - MMA_e4m3e4m3f32_TRANS_B_INTRIN, MMA_store_16x16_f16_global_INTRIN, MMA_store_16x16_f32_global_INTRIN, MMA_store_16x16_i32_global_INTRIN, @@ -132,10 +133,10 @@ def run_test( else: b_np = np.random.normal(size=(K, N)).astype("float16") c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")).astype(out_dtype) - elif in_dtype in ["e4m3_float8", "e5m2_float8"]: + elif in_dtype in ["float8_e4m3fn", "float8_e5m2"]: typemap = { - "e4m3_float8": "float8_e4m3fn", - "e5m2_float8": "float8_e5m2", + "float8_e4m3fn": "float8_e4m3fn", + "float8_e5m2": "float8_e5m2", } a_np = ( np.random.uniform(low=-5, high=5, size=(M * K)) @@ -174,7 +175,7 @@ def run_test( f(a, b, c) - if out_dtype != "float16" and in_dtype not in ["e4m3_float8", "e5m2_float8"]: + if out_dtype != "float16" and in_dtype not in ["float8_e4m3fn", "float8_e5m2"]: # The numpy reference is 
computed with fp32 precision (otherwise too slow). # So there is non-trivial accuracy difference if TVM result is computed with fp16 accumulation. tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-2, atol=1e-2) @@ -384,7 +385,7 @@ def index_map_C(i, j): ) k_inner = 32 - in_dtype = "e4m3_float8" + in_dtype = "float8_e4m3fn" out_dtype = "float32" i_factors, j_factors, k_factors = [1, 32, 1, 4, 2], [8, 4, 4, 2, 1], [32, 2, 2] @@ -427,7 +428,7 @@ def index_map_C(i, j): ) k_inner = 32 - in_dtype = "e5m2_float8" + in_dtype = "float8_e5m2" out_dtype = "float32" i_factors, j_factors, k_factors = [1, 32, 1, 4, 2], [8, 4, 4, 2, 1], [32, 2, 2] diff --git a/tests/python/tir-transform/test_tir_transform_fp8_legalize.py b/tests/python/tir-transform/test_tir_transform_fp8_legalize.py index e1f487c572df..0b10fe5c2199 100644 --- a/tests/python/tir-transform/test_tir_transform_fp8_legalize.py +++ b/tests/python/tir-transform/test_tir_transform_fp8_legalize.py @@ -17,8 +17,8 @@ import tvm import tvm.script import tvm.testing -from tvm.target import Target from tvm.script import tir as T +from tvm.target import Target from tvm.tir.transform.transform import BindTarget # pylint: disable=no-member,invalid-name,unused-variable @@ -69,7 +69,7 @@ def main(Aptr: T.handle(dtype), Bptr: T.handle(dtype), Dptr: T.handle(dtype)): def promote_uint8(f8_dtype: str, promote_dtype: str, v): - if f8_dtype == "e4m3_float8": + if f8_dtype == "float8_e4m3fn": if promote_dtype == "float16": mantissa = T.bitwise_and( T.shift_left(T.Cast("uint16", v), T.uint16(7)), T.uint16(0x3FF) @@ -96,7 +96,7 @@ def promote_uint8(f8_dtype: str, promote_dtype: str, v): ) sign = T.shift_left(T.Cast("uint32", T.shift_right(v, T.uint8(7))), T.uint32(31)) return T.reinterpret("float32", T.bitwise_or(T.bitwise_or(mantissa, exponent), sign)) - else: # f8_dtype == "e5m2_float8" + else: # f8_dtype == "float8_e5m2" if promote_dtype == "float16": return T.reinterpret("float16", T.shift_left(T.Cast("uint16", v), T.uint16(8))) else: # promote_dtype == "float32" @@ -115,7 +115,7 @@ def promote_uint8(f8_dtype: str, promote_dtype: str, v): def cast_to_uint8(f8_dtype: str, promote_dtype: str, v): - if f8_dtype == "e4m3_float8": + if f8_dtype == "float8_e4m3fn": if promote_dtype == "float16": uint16_v = T.reinterpret("uint16", v) rounding_bias = T.bitwise_and( @@ -154,7 +154,7 @@ def cast_to_uint8(f8_dtype: str, promote_dtype: str, v): return T.if_then_else( round_to_zero, T.uint8(0), T.bitwise_or(T.bitwise_or(mantissa, exponent), sign) ) - else: # f8_dtype == "e5m2_float8" + else: # f8_dtype == "float8_e5m2" if promote_dtype == "float16": uint16_v = T.reinterpret("uint16", v) rounding_bias = T.bitwise_and( @@ -201,12 +201,12 @@ def main(Aptr: T.handle("uint8"), Bptr: T.handle("uint8"), Dptr: T.handle("uint8 return After -dtype = tvm.testing.parameter("e4m3_float8", "e5m2_float8") +dtype = tvm.testing.parameter("float8_e4m3fn", "float8_e5m2") promote_dtype = tvm.testing.parameter("float16", "float32") def test_fp8_compute_legalize(dtype, promote_dtype): - target = Target("cuda") + target = Target("nvidia/nvidia-a100") before = BindTarget(target)(get_before(dtype)) expected = BindTarget(target)(get_after_compute_legalize(dtype, promote_dtype)) # run the transform twice to ensure we can afford to deal @@ -217,7 +217,7 @@ def test_fp8_compute_legalize(dtype, promote_dtype): def test_fp8_storage_legalize(dtype, promote_dtype): - target = Target("cuda") + target = Target("nvidia/nvidia-a100") before = BindTarget(target)(get_after_compute_legalize(dtype, 
promote_dtype)) after = tvm.tir.transform.FP8StorageLegalize()(before) expected = BindTarget(target)(get_after_storage_legalize(dtype, promote_dtype)) diff --git a/tests/python/tvmscript/test_tvmscript_printer_tir.py b/tests/python/tvmscript/test_tvmscript_printer_tir.py index b7ba57fa9387..943ba54060e6 100644 --- a/tests/python/tvmscript/test_tvmscript_printer_tir.py +++ b/tests/python/tvmscript/test_tvmscript_printer_tir.py @@ -17,6 +17,7 @@ # pylint: disable=missing-docstring import re + import pytest import tvm.testing @@ -917,23 +918,23 @@ def func(): _assert_print(func, expected_output) -@pytest.mark.parametrize("dtype", ["e4m3_float8", "e5m2_float8"]) +@pytest.mark.parametrize("dtype", ["float8_e4m3fn", "float8_e5m2"]) def test_float8(dtype): from tvm.script import tir as T def get_func(dtype): - if dtype == "e4m3_float8": + if dtype == "float8_e4m3fn": @T.prim_func def func(): - T.evaluate(T.e4m3_float8(0.0)) + T.evaluate(T.float8_e4m3fn(0.0)) return func - elif dtype == "e5m2_float8": + elif dtype == "float8_e5m2": @T.prim_func def func(): - T.evaluate(T.e5m2_float8(0.0)) + T.evaluate(T.float8_e5m2(0.0)) return func
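Usage sketch (illustrative, not part of the patch): how the renamed dtype strings are expected to behave once this change lands, following the Python tests updated above (DataType parsing/itemsize and NDArray dtype reporting). The shapes, device, and printed output below are assumptions for illustration only.

    import tvm
    from tvm import DataType

    # The ml_dtypes-style names parse in scalar and vector forms.
    assert DataType("float8_e4m3fn").bits == 8
    assert DataType("float8_e5m2x4").itemsize() == 4  # mirrors test_dtype_itemsize above
    assert DataType("float4_e2m1fn").bits == 4

    # NDArrays report the new names; converting to numpy still requires ml_dtypes.
    x = tvm.nd.empty((4, 4), "float8_e4m3fn", tvm.cpu())
    print(x.dtype)  # expected: float8_e4m3fn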