@@ -73,7 +73,11 @@ def assert_gcxs_slicing(s, x):
7373
7474
7575def assert_nnz (s , x ):
76- fill_value = s .fill_value if hasattr (s , "fill_value" ) else _zero_of_dtype (s .dtype , s .device )
76+ from ._settings import NUMPY_DEVICE
77+
78+ fill_value = (
79+ s .fill_value if hasattr (s , "fill_value" ) else _zero_of_dtype (s .dtype , getattr (s , "device" , NUMPY_DEVICE ))
80+ )
7781 assert np .sum (~ equivalent (x , fill_value )) == s .nnz
7882
7983
@@ -442,7 +446,7 @@ def equivalent(x, y, /, loose=False):
442446
443447 from ._common import _coerce_to_supported_dense
444448
445- namespace = array_api_compat .array_namespace (x , y )
449+ xp = array_api_compat .array_namespace (x , y )
446450 x = _coerce_to_supported_dense (x )
447451 y = _coerce_to_supported_dense (y )
448452 # Can't contain NaNs
@@ -458,9 +462,9 @@ def equivalent(x, y, /, loose=False):
458462 return (x == y ) | ((x != x ) & (y != y ))
459463
460464 if x .size == 0 or y .size == 0 :
461- shape = namespace .broadcast_shapes (x .shape , y .shape )
462- return namespace .empty (shape , dtype = np .bool_ )
463- x , y = namespace .broadcast_arrays (x [..., None ], y [..., None ])
465+ shape = xp .broadcast_shapes (x .shape , y .shape )
466+ return xp .empty (shape , dtype = np .bool_ )
467+ x , y = xp .broadcast_arrays (x [..., None ], y [..., None ])
464468 return (x .astype (dt ).view (np .uint8 ) == y .astype (dt ).view (np .uint8 )).all (axis = - 1 )
465469
466470
0 commit comments