Skip to content

Commit 268cd2c

Browse files
committed
CPU check.
1 parent cd392f6 commit 268cd2c

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

python-package/xgboost/testing/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import numpy as np
66

77
from ..compat import import_cupy
8+
from ..data import _is_cupy_alike
89

910
Device: TypeAlias = Literal["cpu", "cuda"]
1011

@@ -13,7 +14,7 @@ def assert_allclose(
1314
device: Device, a: Any, b: Any, *, rtol: float = 1e-7, atol: float = 0
1415
) -> None:
1516
"""Dispatch the assert_allclose for devices."""
16-
if device == "cpu" and not hasattr(a, "get") and not hasattr(b, "get"):
17+
if device == "cpu" and not _is_cupy_alike(a) and not _is_cupy_alike(b):
1718
np.testing.assert_allclose(a, b, atol=atol, rtol=rtol)
1819
else:
1920
cp = import_cupy()

0 commit comments

Comments
 (0)