diff --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py index 0d12c35d96bee71..e3f41815800d58d 100644 --- a/mlir/test/python/execution_engine.py +++ b/mlir/test/python/execution_engine.py @@ -5,7 +5,14 @@ from mlir.passmanager import * from mlir.execution_engine import * from mlir.runtime import * -from ml_dtypes import bfloat16, float8_e5m2 + +try: + from ml_dtypes import bfloat16, float8_e5m2 + + HAS_ML_DTYPES = True +except ModuleNotFoundError: + HAS_ML_DTYPES = False + MLIR_RUNNER_UTILS = os.getenv( "MLIR_RUNNER_UTILS", "../../../../lib/libmlir_runner_utils.so" @@ -564,7 +571,8 @@ def testBF16Memref(): log(npout) -run(testBF16Memref) +if HAS_ML_DTYPES: + run(testBF16Memref) # Test f8E5M2 memrefs @@ -603,7 +611,8 @@ def testF8E5M2Memref(): log(npout) -run(testF8E5M2Memref) +if HAS_ML_DTYPES: + run(testF8E5M2Memref) # Test addition of two 2d_memref