-
Notifications
You must be signed in to change notification settings - Fork 12.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR][test] Check for ml_dtypes before running tests #123061
Conversation
@llvm/pr-subscribers-mlir Author: Konrad Kleine (kwk) ChangesWe noticed that try:
import ml_dtypes
except ModuleNotFoundError:
# The third-party ml_dtypes provides some optional low precision data-types for NumPy.
ml_dtypes = None This makes Some python tests however partially depend on This is a replacement for #123051 which was excluding tests too broadly. Full diff: https://github.com/llvm/llvm-project/pull/123061.diff 1 Files Affected:
diff --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py
index 0d12c35d96bee7..9bfec3e7d0c538 100644
--- a/mlir/test/python/execution_engine.py
+++ b/mlir/test/python/execution_engine.py
@@ -5,7 +5,12 @@
from mlir.passmanager import *
from mlir.execution_engine import *
from mlir.runtime import *
-from ml_dtypes import bfloat16, float8_e5m2
+try:
+ from ml_dtypes import bfloat16, float8_e5m2
+ HAS_ML_DTYPES=True
+except ModuleNotFoundError:
+ HAS_ML_DTYPES=False
+
MLIR_RUNNER_UTILS = os.getenv(
"MLIR_RUNNER_UTILS", "../../../../lib/libmlir_runner_utils.so"
@@ -564,7 +569,8 @@ def testBF16Memref():
log(npout)
-run(testBF16Memref)
+if HAS_ML_DTYPES:
+ run(testBF16Memref)
# Test f8E5M2 memrefs
@@ -603,7 +609,8 @@ def testF8E5M2Memref():
log(npout)
-run(testF8E5M2Memref)
+if HAS_ML_DTYPES:
+ run(testF8E5M2Memref)
# Test addition of two 2d_memref
|
✅ With the latest revision this PR passed the Python code formatter. |
195fa3f
to
8db4f46
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
@kwk let me know if you need me to merge |
We noticed that `mlir/python/requirements.txt` lists `ml_dtypes` as a requirement but when looking at the code in `mlir/python`, the only `import` is guarded: ```python try: import ml_dtypes except ModuleNotFoundError: # The third-party ml_dtypes provides some optional low precision data-types for NumPy. ml_dtypes = None ``` This makes `ml_dtypes` an optional dependency. Some python tests however partially depend on `ml_dtypes` and should not run if that module is unavailable. That is what this change does. This is a replacement for llvm#123051 which was excluding tests too broadly.
@makslevental, unfortunately the "solution" doesn't work as I expected. The problem are the # Test bf16 memrefs
# CHECK-LABEL: TEST: testBF16Memref
def testBF16Memref():
with Context():
module = Module.parse(
"""
module {
func.func @main(%arg0: memref<1xbf16>,
%arg1: memref<1xbf16>) attributes { llvm.emit_c_interface } {
%0 = arith.constant 0 : index
%1 = memref.load %arg0[%0] : memref<1xbf16>
memref.store %1, %arg1[%0] : memref<1xbf16>
return
}
} """
)
arg1 = np.array([0.5]).astype(bfloat16)
arg2 = np.array([0.0]).astype(bfloat16)
arg1_memref_ptr = ctypes.pointer(
ctypes.pointer(get_ranked_memref_descriptor(arg1))
)
arg2_memref_ptr = ctypes.pointer(
ctypes.pointer(get_ranked_memref_descriptor(arg2))
)
execution_engine = ExecutionEngine(lowerToLLVM(module))
execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr)
# test to-numpy utility
# CHECK: [0.5]
npout = ranked_memref_to_numpy(arg2_memref_ptr[0])
log(npout)
if HAS_ML_DTYPES:
run(testBF16Memref) The expected What options do we have here? |
whoops should've thought of that - in this case the solution is simple: just change those checks for printed values into x = ranked_memref_to_numpy(arg2_memref_ptr[0])
assert len(x) == 1 and assert(x[0]) == 0.5 EDIT: and to satisfy the if HAS_ML_DTYPES:
run(testBF16Memref)
else:
print("TEST: testBF16Memref") |
In order to optionally run some checks that depend on the `ml_dtypes` python module we have to remove the `CHECK` lines for those tests or they will be required and missed in the test output. I've changed to use asserts as recommended in [1]. [1]: llvm#123061 (comment)
@makslevental sorry for the delay, here's the follow-up PR: #123240 |
In order to optionally run some checks that depend on the `ml_dtypes` python module we have to remove the `CHECK` lines for those tests or they will be required and missed in the test output. I've changed to use asserts as recommended in [1]. [1]: llvm#123061 (comment)
In order to optionally run some checks that depend on the `ml_dtypes` python module we have to remove the `CHECK` lines for those tests or they will be required and missed in the test output. I've changed to use asserts as recommended in [1]. [1]: llvm#123061 (comment)
In order to optionally run some checks that depend on the `ml_dtypes` python module we have to remove the `CHECK` lines for those tests or they will be required and missed in the test output. I've changed to use asserts as recommended in [1]. [1]: #123061 (comment)
In order to optionally run some checks that depend on the `ml_dtypes` python module we have to remove the `CHECK` lines for those tests or they will be required and missed in the test output. I've changed to use asserts as recommended in [1]. [1]: llvm/llvm-project#123061 (comment)
We noticed that `mlir/python/requirements.txt` lists `ml_dtypes` as a requirement but when looking at the code in `mlir/python`, the only `import` is guarded: ```python try: import ml_dtypes except ModuleNotFoundError: # The third-party ml_dtypes provides some optional low precision data-types for NumPy. ml_dtypes = None ``` This makes `ml_dtypes` an optional dependency. Some python tests however partially depend on `ml_dtypes` and should not run if that module is unavailable. That is what this change does. This is a replacement for llvm#123051 which was excluding tests too broadly.
We noticed that
mlir/python/requirements.txt
listsml_dtypes
as a requirement but when looking at the code inmlir/python
, the onlyimport
is guarded:This makes
ml_dtypes
an optional dependency.Some python tests however partially depend on
ml_dtypes
and should not run if that module is unavailable. That is what this change does.This is a replacement for #123051 which was excluding tests too broadly.