Skip to content

Commit

Permalink
[MLIR][test] Check for ml_dtypes before running tests (#123061)
Browse files Browse the repository at this point in the history
We noticed that `mlir/python/requirements.txt` lists `ml_dtypes` as a requirement but when looking at the code in `mlir/python`, the only `import` is guarded:

```python
try:
    import ml_dtypes
except ModuleNotFoundError:
    # The third-party ml_dtypes provides some optional low precision data-types for NumPy.
    ml_dtypes = None
```

This makes `ml_dtypes` an optional dependency.

Some python tests however partially depend on `ml_dtypes` and should not run if that module is unavailable. That is what this change does.

This is a replacement for #123051 which was excluding tests too broadly.
  • Loading branch information
kwk authored Jan 15, 2025
1 parent d0a3642 commit 34d5072
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions mlir/test/python/execution_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,14 @@
from mlir.passmanager import *
from mlir.execution_engine import *
from mlir.runtime import *
from ml_dtypes import bfloat16, float8_e5m2

try:
from ml_dtypes import bfloat16, float8_e5m2

HAS_ML_DTYPES = True
except ModuleNotFoundError:
HAS_ML_DTYPES = False


MLIR_RUNNER_UTILS = os.getenv(
"MLIR_RUNNER_UTILS", "../../../../lib/libmlir_runner_utils.so"
Expand Down Expand Up @@ -564,7 +571,8 @@ def testBF16Memref():
log(npout)


run(testBF16Memref)
if HAS_ML_DTYPES:
run(testBF16Memref)


# Test f8E5M2 memrefs
Expand Down Expand Up @@ -603,7 +611,8 @@ def testF8E5M2Memref():
log(npout)


run(testF8E5M2Memref)
if HAS_ML_DTYPES:
run(testF8E5M2Memref)


# Test addition of two 2d_memref
Expand Down

0 comments on commit 34d5072

Please sign in to comment.