Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR][test] Check for ml_dtypes before running tests #123061

Merged
merged 2 commits into from
Jan 15, 2025

Conversation

kwk
Copy link
Contributor

@kwk kwk commented Jan 15, 2025

We noticed that mlir/python/requirements.txt lists ml_dtypes as a requirement but when looking at the code in mlir/python, the only import is guarded:

try:
    import ml_dtypes
except ModuleNotFoundError:
    # The third-party ml_dtypes provides some optional low precision data-types for NumPy.
    ml_dtypes = None

This makes ml_dtypes an optional dependency.

Some python tests however partially depend on ml_dtypes and should not run if that module is unavailable. That is what this change does.

This is a replacement for #123051 which was excluding tests too broadly.

@llvmbot
Copy link
Member

llvmbot commented Jan 15, 2025

@llvm/pr-subscribers-mlir

Author: Konrad Kleine (kwk)

Changes

We noticed that mlir/python/requirements.txt lists ml_dtypes as a requirement but when looking at the code in mlir/python, the only import is guarded:

try:
    import ml_dtypes
except ModuleNotFoundError:
    # The third-party ml_dtypes provides some optional low precision data-types for NumPy.
    ml_dtypes = None

This makes ml_dtypes an optional dependency.

Some python tests however partially depend on ml_dtypes and should not run if that module is unavailable. That is what this change does.

This is a replacement for #123051 which was excluding tests too broadly.


Full diff: https://github.com/llvm/llvm-project/pull/123061.diff

1 Files Affected:

  • (modified) mlir/test/python/execution_engine.py (+10-3)
diff --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py
index 0d12c35d96bee7..9bfec3e7d0c538 100644
--- a/mlir/test/python/execution_engine.py
+++ b/mlir/test/python/execution_engine.py
@@ -5,7 +5,12 @@
 from mlir.passmanager import *
 from mlir.execution_engine import *
 from mlir.runtime import *
-from ml_dtypes import bfloat16, float8_e5m2
+try:
+    from ml_dtypes import bfloat16, float8_e5m2
+    HAS_ML_DTYPES=True
+except ModuleNotFoundError:
+    HAS_ML_DTYPES=False
+
 
 MLIR_RUNNER_UTILS = os.getenv(
     "MLIR_RUNNER_UTILS", "../../../../lib/libmlir_runner_utils.so"
@@ -564,7 +569,8 @@ def testBF16Memref():
         log(npout)
 
 
-run(testBF16Memref)
+if HAS_ML_DTYPES:
+    run(testBF16Memref)
 
 
 # Test f8E5M2 memrefs
@@ -603,7 +609,8 @@ def testF8E5M2Memref():
         log(npout)
 
 
-run(testF8E5M2Memref)
+if HAS_ML_DTYPES:
+    run(testF8E5M2Memref)
 
 
 #  Test addition of two 2d_memref

Copy link

github-actions bot commented Jan 15, 2025

✅ With the latest revision this PR passed the Python code formatter.

@kwk kwk force-pushed the check_for_ml_dtypes_before_running_tests branch from 195fa3f to 8db4f46 Compare January 15, 2025 14:26
Copy link
Contributor

@makslevental makslevental left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@makslevental
Copy link
Contributor

@kwk let me know if you need me to merge

@kwk kwk merged commit 34d5072 into llvm:main Jan 15, 2025
4 of 5 checks passed
@kwk kwk deleted the check_for_ml_dtypes_before_running_tests branch January 15, 2025 16:33
paulhuggett pushed a commit to paulhuggett/llvm-project that referenced this pull request Jan 16, 2025
We noticed that `mlir/python/requirements.txt` lists `ml_dtypes` as a requirement but when looking at the code in `mlir/python`, the only `import` is guarded:

```python
try:
    import ml_dtypes
except ModuleNotFoundError:
    # The third-party ml_dtypes provides some optional low precision data-types for NumPy.
    ml_dtypes = None
```

This makes `ml_dtypes` an optional dependency.

Some python tests however partially depend on `ml_dtypes` and should not run if that module is unavailable. That is what this change does.

This is a replacement for llvm#123051 which was excluding tests too broadly.
@kwk
Copy link
Contributor Author

kwk commented Jan 16, 2025

@makslevental, unfortunately the "solution" doesn't work as I expected. The problem are the CHECK-lines which are of course not conditionalized. Take this for example:

# Test bf16 memrefs
# CHECK-LABEL: TEST: testBF16Memref
def testBF16Memref():
    with Context():
        module = Module.parse(
            """
    module  {
      func.func @main(%arg0: memref<1xbf16>,
                      %arg1: memref<1xbf16>) attributes { llvm.emit_c_interface } {
        %0 = arith.constant 0 : index
        %1 = memref.load %arg0[%0] : memref<1xbf16>
        memref.store %1, %arg1[%0] : memref<1xbf16>
        return
      }
    } """
        )

        arg1 = np.array([0.5]).astype(bfloat16)
        arg2 = np.array([0.0]).astype(bfloat16)

        arg1_memref_ptr = ctypes.pointer(
            ctypes.pointer(get_ranked_memref_descriptor(arg1))
        )
        arg2_memref_ptr = ctypes.pointer(
            ctypes.pointer(get_ranked_memref_descriptor(arg2))
        )

        execution_engine = ExecutionEngine(lowerToLLVM(module))
        execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr)

        # test to-numpy utility
        # CHECK: [0.5]
        npout = ranked_memref_to_numpy(arg2_memref_ptr[0])
        log(npout)


if HAS_ML_DTYPES:
    run(testBF16Memref)

The expected # CHECK-LABEL: TEST: testBF16Memref and # CHECK: [0.5] will not be found.

What options do we have here?

@makslevental
Copy link
Contributor

makslevental commented Jan 16, 2025

The expected # CHECK-LABEL: TEST: testBF16Memref and # CHECK: [0.5] will not be found.

What options do we have here?

whoops should've thought of that - in this case the solution is simple: just change those checks for printed values into asserts for those same values:

x = ranked_memref_to_numpy(arg2_memref_ptr[0])
assert len(x) == 1 and assert(x[0]) == 0.5

EDIT:

and to satisfy the CHECK-LABEL do

if HAS_ML_DTYPES:
    run(testBF16Memref)
else:
    print("TEST: testBF16Memref")

kwk added a commit to kwk/llvm-project that referenced this pull request Jan 16, 2025
In order to optionally run some checks that depend on the `ml_dtypes`
python module we have to remove the `CHECK` lines for those tests or
they will be required and missed in the test output.

I've changed to use asserts as recommended in [1].

[1]: llvm#123061 (comment)
@kwk
Copy link
Contributor Author

kwk commented Jan 16, 2025

@makslevental sorry for the delay, here's the follow-up PR: #123240

kwk added a commit to kwk/llvm-project that referenced this pull request Jan 17, 2025
In order to optionally run some checks that depend on the `ml_dtypes`
python module we have to remove the `CHECK` lines for those tests or
they will be required and missed in the test output.

I've changed to use asserts as recommended in [1].

[1]: llvm#123061 (comment)
kwk added a commit to kwk/llvm-project that referenced this pull request Jan 17, 2025
In order to optionally run some checks that depend on the `ml_dtypes`
python module we have to remove the `CHECK` lines for those tests or
they will be required and missed in the test output.

I've changed to use asserts as recommended in [1].

[1]: llvm#123061 (comment)
kwk added a commit that referenced this pull request Jan 17, 2025
In order to optionally run some checks that depend on the `ml_dtypes`
python module we have to remove the `CHECK` lines for those tests or
they will be required and missed in the test output.

I've changed to use asserts as recommended in [1].

[1]:
#123061 (comment)
github-actions bot pushed a commit to arm/arm-toolchain that referenced this pull request Jan 17, 2025
In order to optionally run some checks that depend on the `ml_dtypes`
python module we have to remove the `CHECK` lines for those tests or
they will be required and missed in the test output.

I've changed to use asserts as recommended in [1].

[1]:
llvm/llvm-project#123061 (comment)
DKLoehr pushed a commit to DKLoehr/llvm-project that referenced this pull request Jan 17, 2025
We noticed that `mlir/python/requirements.txt` lists `ml_dtypes` as a requirement but when looking at the code in `mlir/python`, the only `import` is guarded:

```python
try:
    import ml_dtypes
except ModuleNotFoundError:
    # The third-party ml_dtypes provides some optional low precision data-types for NumPy.
    ml_dtypes = None
```

This makes `ml_dtypes` an optional dependency.

Some python tests however partially depend on `ml_dtypes` and should not run if that module is unavailable. That is what this change does.

This is a replacement for llvm#123051 which was excluding tests too broadly.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants