1515from typing import Pattern
1616from typing import Tuple
1717from typing import Type
18+ from typing import TYPE_CHECKING
1819from typing import TypeVar
1920from typing import Union
2021
22+ if TYPE_CHECKING :
23+ from numpy import ndarray
24+
25+
2126import _pytest ._code
2227from _pytest .compat import final
2328from _pytest .compat import STRING_TYPES
@@ -232,10 +237,11 @@ def __repr__(self) -> str:
232237 def __eq__ (self , actual ) -> bool :
233238 """Return whether the given value is equal to the expected value
234239 within the pre-specified tolerance."""
235- if _is_numpy_array (actual ):
240+ asarray = _as_numpy_array (actual )
241+ if asarray is not None :
236242 # Call ``__eq__()`` manually to prevent infinite-recursion with
237243 # numpy<1.13. See #3748.
238- return all (self .__eq__ (a ) for a in actual .flat )
244+ return all (self .__eq__ (a ) for a in asarray .flat )
239245
240246 # Short-circuit exact equality.
241247 if actual == self .expected :
@@ -521,6 +527,7 @@ def approx(expected, rel=None, abs=None, nan_ok: bool = False) -> ApproxBase:
521527 elif isinstance (expected , Mapping ):
522528 cls = ApproxMapping
523529 elif _is_numpy_array (expected ):
530+ expected = _as_numpy_array (expected )
524531 cls = ApproxNumpy
525532 elif (
526533 isinstance (expected , Iterable )
@@ -536,16 +543,30 @@ def approx(expected, rel=None, abs=None, nan_ok: bool = False) -> ApproxBase:
536543
537544
538545def _is_numpy_array (obj : object ) -> bool :
539- """Return true if the given object is a numpy array.
546+ """
547+ Return true if the given object is implicitly convertible to ndarray,
548+ and numpy is already imported.
549+ """
550+ return _as_numpy_array (obj ) is not None
551+
540552
541- A special effort is made to avoid importing numpy unless it's really necessary.
553+ def _as_numpy_array (obj : object ) -> Optional ["ndarray" ]:
554+ """
555+ Return an ndarray if the given object is implicitly convertible to ndarray,
556+ and numpy is already imported, otherwise None.
542557 """
543558 import sys
544559
545560 np : Any = sys .modules .get ("numpy" )
546561 if np is not None :
547- return isinstance (obj , np .ndarray )
548- return False
562+ # avoid infinite recursion on numpy scalars, which have __array__
563+ if np .isscalar (obj ):
564+ return None
565+ elif isinstance (obj , np .ndarray ):
566+ return obj
567+ elif hasattr (obj , "__array__" ) or hasattr ("obj" , "__array_interface__" ):
568+ return np .asarray (obj )
569+ return None
549570
550571
551572# builtin pytest.raises helper
0 commit comments