@@ -73,7 +73,11 @@ def assert_gcxs_slicing(s, x):
7373
7474
7575def assert_nnz (s , x ):
76- fill_value = s .fill_value if hasattr (s , "fill_value" ) else _zero_of_dtype (s .dtype , s .device )
76+ from ._settings import NUMPY_DEVICE
77+
78+ fill_value = (
79+ s .fill_value if hasattr (s , "fill_value" ) else _zero_of_dtype (s .dtype , getattr (s , "device" , NUMPY_DEVICE ))
80+ )
7781 assert np .sum (~ equivalent (x , fill_value )) == s .nnz
7882
7983
@@ -442,7 +446,12 @@ def equivalent(x, y, /, loose=False):
442446
443447 from ._common import _coerce_to_supported_dense
444448
445- namespace = array_api_compat .array_namespace (x , y )
449+ try :
450+ xp = array_api_compat .array_namespace (x , y )
451+ except TypeError as e :
452+ if "multiple" in str (e ):
453+ raise e
454+ xp = np
446455 x = _coerce_to_supported_dense (x )
447456 y = _coerce_to_supported_dense (y )
448457 # Can't contain NaNs
@@ -458,9 +467,9 @@ def equivalent(x, y, /, loose=False):
458467 return (x == y ) | ((x != x ) & (y != y ))
459468
460469 if x .size == 0 or y .size == 0 :
461- shape = namespace .broadcast_shapes (x .shape , y .shape )
462- return namespace .empty (shape , dtype = np .bool_ )
463- x , y = namespace .broadcast_arrays (x [..., None ], y [..., None ])
470+ shape = xp .broadcast_shapes (x .shape , y .shape )
471+ return xp .empty (shape , dtype = np .bool_ )
472+ x , y = xp .broadcast_arrays (x [..., None ], y [..., None ])
464473 return (x .astype (dt ).view (np .uint8 ) == y .astype (dt ).view (np .uint8 )).all (axis = - 1 )
465474
466475
0 commit comments