Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 48 additions & 29 deletions include/tvm/runtime/data_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ class DataType {
kFloat = kDLFloat,
kHandle = TVMArgTypeCode::kTVMOpaqueHandle,
kBFloat = kDLBfloat,
kE4M3Float = 6U,
kE5M2Float = 7U,
kFloat4E2M1Fn = 8U,
kFloat8_e4m3fn = 6U,
kFloat8_e5m2 = 7U,
kFloat4_e2m1fn = 8U,
kCustomBegin = 129
};
/*! \brief default constructor */
Expand All @@ -85,10 +85,10 @@ class DataType {
if (code == kBFloat) {
ICHECK_EQ(bits, 16);
}
if (code == kE4M3Float || code == kE5M2Float) {
if (code == kFloat8_e4m3fn || code == kFloat8_e5m2) {
ICHECK_EQ(bits, 8);
}
if (code == kFloat4E2M1Fn) {
if (code == kFloat4_e2m1fn) {
ICHECK_EQ(bits, 4);
}
}
Expand Down Expand Up @@ -126,15 +126,15 @@ class DataType {
bool is_float() const { return code() == DataType::kFloat; }
/*! \return whether type is a float8 type. */
bool is_float8() const {
return (code() == DataType::kFloat || code() == DataType::kE4M3Float ||
code() == DataType::kE5M2Float) &&
return (code() == DataType::kFloat || code() == DataType::kFloat8_e4m3fn ||
code() == DataType::kFloat8_e5m2) &&
bits() == 8;
}
/*! \return whether type is a float4 type. */
bool is_float4() const { return code() == DataType::kFloat4E2M1Fn && bits() == 4; }
bool is_e4m3_float8() const { return (code() == DataType::kE4M3Float && bits() == 8); }
bool is_e5m2_float8() const { return (code() == DataType::kE5M2Float && bits() == 8); }
bool is_float4_e2m1fn() const { return (code() == DataType::kFloat4E2M1Fn && bits() == 4); }
bool is_float4() const { return code() == DataType::kFloat4_e2m1fn && bits() == 4; }
bool is_float8_e4m3fn() const { return (code() == DataType::kFloat8_e4m3fn && bits() == 8); }
bool is_float8_e5m2() const { return (code() == DataType::kFloat8_e5m2 && bits() == 8); }
bool is_float4_e2m1fn() const { return (code() == DataType::kFloat4_e2m1fn && bits() == 4); }
/*! \return whether type is a float16 type. */
bool is_float16() const { return is_float() && bits() == 16; }
/*! \return whether type is a bfloat16 type. */
Expand Down Expand Up @@ -252,19 +252,19 @@ class DataType {
* \param lanes The number of lanes
* \return The constructed data type.
*/
static DataType NVFloat8E4M3(int lanes = 1) { return DataType(kE4M3Float, 8, lanes); }
static DataType NVFloat8E4M3(int lanes = 1) { return DataType(kFloat8_e4m3fn, 8, lanes); }
/*!
* \brief Construct NV float8 e5m2 datatype.
* \param lanes The number of lanes
* \return The constructed data type.
*/
static DataType NVFloat8E5M2(int lanes = 1) { return DataType(kE5M2Float, 8, lanes); }
static DataType NVFloat8E5M2(int lanes = 1) { return DataType(kFloat8_e5m2, 8, lanes); }
/*!
* \brief Construct NV float4_e2m1fn datatype.
* \param lanes The number of lanes
* \return The constructed data type.
*/
static DataType NVFloat4E2M1FN(int lanes = 1) { return DataType(kFloat4E2M1Fn, 4, lanes); }
static DataType NVFloat4E2M1FN(int lanes = 1) { return DataType(kFloat4_e2m1fn, 4, lanes); }
/*!
* \brief Construct a bool type.
* \param lanes The number of lanes.
Expand Down Expand Up @@ -393,11 +393,11 @@ inline const char* DLDataTypeCode2Str(DLDataTypeCode type_code) {
return "handle";
case kDLBfloat:
return "bfloat";
case DataType::kE4M3Float:
return "e4m3_float";
case DataType::kE5M2Float:
return "e5m2_float";
case DataType::kFloat4E2M1Fn:
case DataType::kFloat8_e4m3fn:
return "float8_e4m3fn";
case DataType::kFloat8_e5m2:
return "float8_e5m2";
case DataType::kFloat4_e2m1fn:
return "float4_e2m1fn";
default:
LOG(FATAL) << "unknown type_code=" << static_cast<int>(type_code);
Expand All @@ -420,7 +420,10 @@ inline std::ostream& operator<<(std::ostream& os, DLDataType t) { // NOLINT(*)
}
if (t.code == kTVMOpaqueHandle) return os;
int16_t lanes = static_cast<int16_t>(t.lanes);
os << static_cast<int>(t.bits);
if (t.code != DataType::kFloat8_e4m3fn && t.code != DataType::kFloat8_e5m2 &&
t.code != DataType::kFloat4_e2m1fn) {
os << static_cast<int>(t.bits);
}
if (lanes > 1) {
os << 'x' << lanes;
} else if (lanes < -1) {
Expand Down Expand Up @@ -458,7 +461,7 @@ inline DLDataType String2DLDataType(std::string s) {
scan = s.c_str() + 4;
} else if (s.substr(0, 13) == "float4_e2m1fn") {
// Avoid being treated as "float"
t.code = DataType::kFloat4E2M1Fn;
t.code = DataType::kFloat4_e2m1fn;
t.bits = 4;
scan = s.c_str() + 13;
char* endpt = nullptr;
Expand All @@ -468,6 +471,30 @@ inline DLDataType String2DLDataType(std::string s) {
}
ICHECK(scan == s.c_str() + s.length()) << "unknown type " << s;
return t;
} else if (s.substr(0, 13) == "float8_e4m3fn") {
// Avoid being treated as "float"
t.code = DataType::kFloat8_e4m3fn;
t.bits = 8;
scan = s.c_str() + 13;
char* endpt = nullptr;
if (*scan == 'x') {
t.lanes = static_cast<uint16_t>(strtoul(scan + 1, &endpt, 10));
scan = endpt;
}
ICHECK(scan == s.c_str() + s.length()) << "unknown type " << s;
return t;
} else if (s.substr(0, 11) == "float8_e5m2") {
// Avoid being treated as "float"
t.code = DataType::kFloat8_e5m2;
t.bits = 8;
scan = s.c_str() + 11;
char* endpt = nullptr;
if (*scan == 'x') {
t.lanes = static_cast<uint16_t>(strtoul(scan + 1, &endpt, 10));
scan = endpt;
}
ICHECK(scan == s.c_str() + s.length()) << "unknown type " << s;
return t;
} else if (s.substr(0, 5) == "float") {
t.code = kDLFloat;
scan = s.c_str() + 5;
Expand All @@ -484,14 +511,6 @@ inline DLDataType String2DLDataType(std::string s) {
t.code = DataType::kBFloat;
t.bits = 16;
scan = s.c_str() + 6;
} else if (s.substr(0, 10) == "e4m3_float") {
t.code = DataType::kE4M3Float;
t.bits = 8;
scan = s.c_str() + 10;
} else if (s.substr(0, 10) == "e5m2_float") {
t.code = DataType::kE5M2Float;
t.bits = 8;
scan = s.c_str() + 10;
} else if (s.substr(0, 6) == "custom") {
t.code = ParseCustomDatatype(s, &scan);
} else {
Expand Down
6 changes: 3 additions & 3 deletions include/tvm/script/ir_builder/tir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -502,10 +502,10 @@ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Int, DataType::Int);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x32, FDType(32)); \
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x64, FDType(64));

TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(E4M3Float8, DataType::NVFloat8E4M3);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(E5M2Float8, DataType::NVFloat8E5M2);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E4M3FN, DataType::NVFloat8E4M3);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E5M2, DataType::NVFloat8E5M2);

TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float4E2M1fn, DataType::NVFloat4E2M1FN);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float4E2M1FN, DataType::NVFloat4E2M1FN);

TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Boolean, DataType::Bool());
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Void, DataType::Void());
Expand Down
42 changes: 23 additions & 19 deletions python/tvm/_ffi/runtime_ctypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,9 @@ class DataTypeCode(object):
FLOAT = 2
HANDLE = 3
BFLOAT = 4
E4M3Float = 6
E5M2Float = 7
FLOAT4E2M1FN = 8
Float8E4M3FN = 6
Float8E5M2 = 7
Float4E2M1FN = 8


class DataType(ctypes.Structure):
Expand All @@ -81,9 +81,9 @@ class DataType(ctypes.Structure):
DataTypeCode.FLOAT: "float",
DataTypeCode.HANDLE: "handle",
DataTypeCode.BFLOAT: "bfloat",
DataTypeCode.E4M3Float: "e4m3_float",
DataTypeCode.E5M2Float: "e5m2_float",
DataTypeCode.FLOAT4E2M1FN: "float4_e2m1fn",
DataTypeCode.Float8E4M3FN: "float8_e4m3fn",
DataTypeCode.Float8E5M2: "float8_e5m2",
DataTypeCode.Float4E2M1FN: "float4_e2m1fn",
}
NUMPY2STR = {
np.dtype(np.bool_): "bool",
Expand Down Expand Up @@ -112,9 +112,9 @@ class DataType(ctypes.Structure):
"uint16": {"type_code": DataTypeCode.UINT, "bits": 16, "lanes": 1},
"uint32": {"type_code": DataTypeCode.UINT, "bits": 32, "lanes": 1},
"uint64": {"type_code": DataTypeCode.UINT, "bits": 64, "lanes": 1},
"e4m3_float8": {"type_code": DataTypeCode.E4M3Float, "bits": 8, "lanes": 1},
"e5m2_float8": {"type_code": DataTypeCode.E5M2Float, "bits": 8, "lanes": 1},
"float4_e2m1fn": {"type_code": DataTypeCode.FLOAT4E2M1FN, "bits": 4, "lanes": 1},
"float8_e4m3fn": {"type_code": DataTypeCode.Float8E4M3FN, "bits": 8, "lanes": 1},
"float8_e5m2": {"type_code": DataTypeCode.Float8E5M2, "bits": 8, "lanes": 1},
"float4_e2m1fn": {"type_code": DataTypeCode.Float4E2M1FN, "bits": 4, "lanes": 1},
"float16": {"type_code": DataTypeCode.FLOAT, "bits": 16, "lanes": 1},
"float32": {"type_code": DataTypeCode.FLOAT, "bits": 32, "lanes": 1},
"float64": {"type_code": DataTypeCode.FLOAT, "bits": 64, "lanes": 1},
Expand Down Expand Up @@ -157,9 +157,17 @@ def __init__(self, type_str):
head = head[4:]
elif head.startswith("float4_e2m1fn"):
# Avoid being treated as "float"
self.type_code = DataTypeCode.FLOAT4E2M1FN
self.type_code = DataTypeCode.Float4E2M1FN
bits = 4
head = ""
elif head.startswith("float8_e4m3fn"):
self.type_code = DataTypeCode.Float8E4M3FN
bits = 8
head = ""
elif head.startswith("float8_e5m2"):
self.type_code = DataTypeCode.Float8E5M2
bits = 8
head = ""
elif head.startswith("float"):
self.type_code = DataTypeCode.FLOAT
head = head[5:]
Expand All @@ -170,12 +178,6 @@ def __init__(self, type_str):
elif head.startswith("bfloat"):
self.type_code = DataTypeCode.BFLOAT
head = head[6:]
elif head.startswith("e4m3_float"):
self.type_code = DataTypeCode.E4M3Float
head = head[10:]
elif head.startswith("e5m2_float"):
self.type_code = DataTypeCode.E5M2Float
head = head[10:]
elif head.startswith("custom"):
# pylint: disable=import-outside-toplevel
import tvm.runtime._ffi_api
Expand Down Expand Up @@ -204,7 +206,9 @@ def __repr__(self):

type_name = "custom[%s]" % tvm.runtime._ffi_api._datatype_get_type_name(self.type_code)
if self.type_code in [
DataTypeCode.FLOAT4E2M1FN,
DataTypeCode.Float8E4M3FN,
DataTypeCode.Float8E5M2,
DataTypeCode.Float4E2M1FN,
]:
x = type_name
else:
Expand Down Expand Up @@ -243,8 +247,8 @@ def itemsize(self):

if ml_dtypes is not None:
DataType.NUMPY2STR[np.dtype(ml_dtypes.bfloat16)] = "bfloat16"
DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e4m3fn)] = "e4m3_float8"
DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e5m2)] = "e5m2_float8"
DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e4m3fn)] = "float8_e4m3fn"
DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e5m2)] = "float8_e5m2"
DataType.NUMPY2STR[np.dtype(ml_dtypes.float4_e2m1fn)] = "float4_e2m1fn"

RPC_SESS_MASK = 128
Expand Down
8 changes: 4 additions & 4 deletions python/tvm/contrib/tvmjs.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,19 +369,19 @@ def load_ndarray_cache(cachepath: str, device: tvm.runtime.Device):
arr = tvm.nd.empty(shape, dtype, device=device)
assert offset + nbytes <= len(raw_data)
buffer_source = raw_data[offset : offset + nbytes]
if dtype == "e4m3_float8":
if dtype == "float8_e4m3fn":
if ml_dtypes is not None:
dtype = ml_dtypes.float8_e4m3fn
else:
raise RuntimeError(
"ml_dtypes is not installed, cannot convert e4m3_float8 array to numpy."
"ml_dtypes is not installed, cannot convert float8_e4m3fn array to numpy."
)
if dtype == "e5m2_float8":
if dtype == "float8_e5m2":
if ml_dtypes is not None:
dtype = ml_dtypes.float8_e5m2
else:
raise RuntimeError(
"ml_dtypes is not installed, cannot convert e5m2_float8 array to numpy."
"ml_dtypes is not installed, cannot convert float8_e5m2 array to numpy."
)
if encode_format == "f32-to-bf16" and dtype == "float32":
data = np.frombuffer(buffer_source, dtype="uint16").reshape(shape)
Expand Down
10 changes: 5 additions & 5 deletions python/tvm/relax/backend/cuda/cublas.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,18 @@

from ..pattern_registry import get_patterns_with_prefix, register_patterns
from ..patterns import (
make_matmul_pattern,
make_matmul_dequantize_pattern,
make_matmul_multiply_pattern,
make_matmul_pattern,
)
from ..utils import has_leaking_intermediate_variables


def _is_supported_dtype(lhs_dtype, rhs_dtype, out_dtype):
"""Check if dtypes in the given workload are supported by cuBLAS BYOC."""
if lhs_dtype == "e4m3_float8" and rhs_dtype == "e4m3_float8":
# The output cannot be 'e5m2_float8' if inputs are 'e4m3_float8'
return out_dtype != "e5m2_float8"
if lhs_dtype == "float8_e4m3fn" and rhs_dtype == "float8_e4m3fn":
# The output cannot be 'float8_e5m2' if inputs are 'float8_e4m3fn'
return out_dtype != "float8_e5m2"
return (
(lhs_dtype == "float16" and rhs_dtype == "float16")
or (lhs_dtype == "float32" and rhs_dtype == "float32")
Expand Down Expand Up @@ -83,7 +83,7 @@ def _check_matmul(context: PatternCheckContext) -> bool:
if not isinstance(rhs_shape[-1], (tvm.tir.expr.IntImm, int)) or rhs_shape[-1] % 4 != 0:
# Rows number must be multiples of 4 for IGEMM
return False
elif lhs_dtype == "e4m3_float8" and rhs_dtype == "e4m3_float8":
elif lhs_dtype == "float8_e4m3fn" and rhs_dtype == "float8_e4m3fn":
matmul_rhs_var = matmul_call.args[1]
rhs_transposed = False
if matmul_rhs_var in context.matched_bindings:
Expand Down
8 changes: 4 additions & 4 deletions python/tvm/relax/backend/rocm/hipblas.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@

def _is_supported_dtype(lhs_dtype, rhs_dtype, out_dtype): # pylint: disable=unused-argument
"""Check if dtypes in the given workload are supported by hipblas BYOC."""
if lhs_dtype == "e4m3_float8" and rhs_dtype == "e4m3_float8":
# The output cannot be 'e5m2_float8' if inputs are 'e4m3_float8'
# return out_dtype != "e5m2_float8"
if lhs_dtype == "float8_e4m3fn" and rhs_dtype == "float8_e4m3fn":
# The output cannot be 'float8_e5m2' if inputs are 'float8_e4m3fn'
# return out_dtype != "float8_e5m2"
return False
return (lhs_dtype == "float16" and rhs_dtype == "float16") or (
lhs_dtype == "int8" and rhs_dtype == "int8"
Expand Down Expand Up @@ -61,7 +61,7 @@ def _check_matmul(context: PatternCheckContext) -> bool:

if lhs_dtype == "int8" and rhs_dtype == "int8":
return False
elif lhs_dtype == "e4m3_float8" and rhs_dtype == "e4m3_float8":
elif lhs_dtype == "float8_e4m3fn" and rhs_dtype == "float8_e4m3fn":
return False

lhs_batches = reduce(operator.mul, lhs_shape[:-2], 1)
Expand Down
8 changes: 4 additions & 4 deletions python/tvm/runtime/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,19 +249,19 @@ def numpy(self):
dtype = "int8"
if dtype == "bfloat16":
dtype = "uint16"
if dtype == "e4m3_float8":
if dtype == "float8_e4m3fn":
if ml_dtypes is not None:
dtype = ml_dtypes.float8_e4m3fn
else:
raise RuntimeError(
"ml_dtypes is not installed, cannot convert e4m3_float8 array to numpy."
"ml_dtypes is not installed, cannot convert float8_e4m3fn array to numpy."
)
if dtype == "e5m2_float8":
if dtype == "float8_e5m2":
if ml_dtypes is not None:
dtype = ml_dtypes.float8_e5m2
else:
raise RuntimeError(
"ml_dtypes is not installed, cannot convert e5m2_float8 array to numpy."
"ml_dtypes is not installed, cannot convert float8_e5m2 array to numpy."
)
if dtype == "float4_e2m1fn":
if ml_dtypes is not None:
Expand Down
Loading
Loading