@@ -195,6 +195,8 @@ class COO(SparseArray, NDArrayOperatorsMixin): # lgtm [py/missing-equals]
195195
196196 __array_priority__ = 12
197197
198+ __array_members__ = ("data" , "coords" , "fill_value" )
199+
198200 def __init__ (
199201 self ,
200202 coords ,
@@ -207,6 +209,8 @@ def __init__(
207209 fill_value = None ,
208210 idx_dtype = None ,
209211 ):
212+ from .._common import _coerce_to_supported_dense
213+
210214 if isinstance (coords , COO ):
211215 self ._make_shallow_copy_of (coords )
212216 if data is not None or shape is not None :
@@ -226,8 +230,8 @@ def __init__(
226230 self .enable_caching ()
227231 return
228232
229- self .data = np . asarray (data )
230- self .coords = np . asarray (coords )
233+ self .data = _coerce_to_supported_dense (data )
234+ self .coords = _coerce_to_supported_dense (coords )
231235
232236 if self .coords .ndim == 1 :
233237 if self .coords .size == 0 and shape is not None :
@@ -236,7 +240,7 @@ def __init__(
236240 self .coords = self .coords [None , :]
237241
238242 if self .data .ndim == 0 :
239- self .data = np .broadcast_to (self .data , self .coords .shape [1 ])
243+ self .data = self . _component_namespace .broadcast_to (self .data , self .coords .shape [1 ])
240244
241245 if self .data .ndim != 1 :
242246 raise ValueError ("`data` must be a scalar or 1-dimensional." )
@@ -251,7 +255,9 @@ def __init__(
251255 shape = tuple (shape )
252256
253257 if shape and not self .coords .size :
254- self .coords = np .zeros ((len (shape ) if isinstance (shape , Iterable ) else 1 , 0 ), dtype = np .intp )
258+ self .coords = self ._component_namespace .zeros (
259+ (len (shape ) if isinstance (shape , Iterable ) else 1 , 0 ), dtype = np .intp
260+ )
255261 super ().__init__ (shape , fill_value = fill_value )
256262 if idx_dtype :
257263 if not can_store (idx_dtype , max (shape )):
@@ -369,7 +375,7 @@ def from_numpy(cls, x, fill_value=None, idx_dtype=None):
369375 x = np .asanyarray (x ).view (type = np .ndarray )
370376
371377 if fill_value is None :
372- fill_value = _zero_of_dtype (x .dtype ) if x .shape else x
378+ fill_value = _zero_of_dtype (x .dtype , x . device ) if x .shape else x
373379
374380 coords = np .atleast_2d (np .flatnonzero (~ equivalent (x , fill_value )))
375381 data = x .ravel ()[tuple (coords )]
@@ -407,7 +413,9 @@ def todense(self):
407413 >>> np.array_equal(x, x2)
408414 True
409415 """
410- x = np .full (self .shape , self .fill_value , self .dtype )
416+ x = self ._component_namespace .full (
417+ self .shape , fill_value = self .fill_value , dtype = self .dtype , device = self .data .device
418+ )
411419
412420 coords = tuple ([self .coords [i , :] for i in range (self .ndim )])
413421 data = self .data
@@ -446,14 +454,16 @@ def from_scipy_sparse(cls, x, /, *, fill_value=None):
446454 >>> np.array_equal(x.todense(), s.todense())
447455 True
448456 """
457+ import array_api_compat
458+
449459 x = x .asformat ("coo" )
450460 if not x .has_canonical_format :
451461 x .eliminate_zeros ()
452462 x .sum_duplicates ()
453463
454- coords = np . empty (( 2 , x . nnz ), dtype = x . row . dtype )
455- coords [ 0 , :] = x . row
456- coords [ 1 , :] = x . col
464+ xp = array_api_compat . array_namespace ( x . data )
465+
466+ coords = xp . stack (( x . row , x . col ))
457467 return COO (
458468 coords ,
459469 x .data ,
@@ -1184,14 +1194,19 @@ def to_scipy_sparse(self, /, *, accept_fv=None):
11841194 - [`sparse.COO.tocsr`][] : Convert to a [`scipy.sparse.csr_matrix`][].
11851195 - [`sparse.COO.tocsc`][] : Convert to a [`scipy.sparse.csc_matrix`][].
11861196 """
1187- import scipy .sparse
1197+ from .._settings import NUMPY_DEVICE
1198+
1199+ if self .device == NUMPY_DEVICE :
1200+ import scipy .sparse as sps
1201+ else :
1202+ import cupyx .scipy .sparse as sps
11881203
11891204 check_fill_value (self , accept_fv = accept_fv )
11901205
11911206 if self .ndim != 2 :
11921207 raise ValueError ("Can only convert a 2-dimensional array to a Scipy sparse matrix." )
11931208
1194- result = scipy . sparse .coo_matrix ((self .data , (self .coords [0 ], self .coords [1 ])), shape = self .shape )
1209+ result = sps .coo_matrix ((self .data , (self .coords [0 ], self .coords [1 ])), shape = self .shape )
11951210 result .has_canonical_format = True
11961211 return result
11971212
@@ -1307,10 +1322,10 @@ def _sort_indices(self):
13071322 """
13081323 linear = self .linear_loc ()
13091324
1310- if (np .diff (linear ) >= 0 ).all (): # already sorted
1325+ if (self . _component_namespace .diff (linear ) >= 0 ).all (): # already sorted
13111326 return
13121327
1313- order = np .argsort (linear , kind = "mergesort" )
1328+ order = self . _component_namespace .argsort (linear , kind = "mergesort" )
13141329 self .coords = self .coords [:, order ]
13151330 self .data = self .data [order ]
13161331
@@ -1336,16 +1351,16 @@ def _sum_duplicates(self):
13361351 # Inspired by scipy/sparse/coo.py::sum_duplicates
13371352 # See https://github.com/scipy/scipy/blob/main/LICENSE.txt
13381353 linear = self .linear_loc ()
1339- unique_mask = np .diff (linear ) != 0
1354+ unique_mask = self . _component_namespace .diff (linear ) != 0
13401355
13411356 if unique_mask .sum () == len (unique_mask ): # already unique
13421357 return
13431358
1344- unique_mask = np .append (True , unique_mask )
1359+ unique_mask = self . _component_namespace .append (True , unique_mask )
13451360
13461361 coords = self .coords [:, unique_mask ]
1347- (unique_inds ,) = np .nonzero (unique_mask )
1348- data = np .add .reduceat (self .data , unique_inds , dtype = self .data .dtype )
1362+ (unique_inds ,) = self . _component_namespace .nonzero (unique_mask )
1363+ data = self . _component_namespace .add .reduceat (self .data , unique_inds , dtype = self .data .dtype )
13491364
13501365 self .data = data
13511366 self .coords = coords
0 commit comments