We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
ml_dtypes<0.5
1 parent c897e4c commit b9a2b92Copy full SHA for b9a2b92
python/tvm_ffi/_dtype.py
@@ -314,7 +314,8 @@ def lanes(self) -> int:
314
dtype._NUMPY_DTYPE_TO_STR[np.dtype(ml_dtypes.bfloat16)] = "bfloat16"
315
dtype._NUMPY_DTYPE_TO_STR[np.dtype(ml_dtypes.float8_e4m3fn)] = "float8_e4m3fn"
316
dtype._NUMPY_DTYPE_TO_STR[np.dtype(ml_dtypes.float8_e5m2)] = "float8_e5m2"
317
- dtype._NUMPY_DTYPE_TO_STR[np.dtype(ml_dtypes.float4_e2m1fn)] = "float4_e2m1fn"
+ if hasattr(ml_dtypes, "float4_e2m1fn"): # ml_dtypes >= 0.5.0
318
+ dtype._NUMPY_DTYPE_TO_STR[np.dtype(ml_dtypes.float4_e2m1fn)] = "float4_e2m1fn"
319
except ImportError:
320
pass
321
0 commit comments