Skip to content

Commit b9a2b92

Browse files
authored
fix: More fix of ml_dtypes<0.5 (#222)
This is a followup of #198 of a missing fix.
1 parent c897e4c commit b9a2b92

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

python/tvm_ffi/_dtype.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,8 @@ def lanes(self) -> int:
314314
dtype._NUMPY_DTYPE_TO_STR[np.dtype(ml_dtypes.bfloat16)] = "bfloat16"
315315
dtype._NUMPY_DTYPE_TO_STR[np.dtype(ml_dtypes.float8_e4m3fn)] = "float8_e4m3fn"
316316
dtype._NUMPY_DTYPE_TO_STR[np.dtype(ml_dtypes.float8_e5m2)] = "float8_e5m2"
317-
dtype._NUMPY_DTYPE_TO_STR[np.dtype(ml_dtypes.float4_e2m1fn)] = "float4_e2m1fn"
317+
if hasattr(ml_dtypes, "float4_e2m1fn"): # ml_dtypes >= 0.5.0
318+
dtype._NUMPY_DTYPE_TO_STR[np.dtype(ml_dtypes.float4_e2m1fn)] = "float4_e2m1fn"
318319
except ImportError:
319320
pass
320321

0 commit comments

Comments
 (0)