@@ -195,6 +195,8 @@ class COO(SparseArray, NDArrayOperatorsMixin): # lgtm [py/missing-equals]
195195
196196 __array_priority__ = 12
197197
198+ __array_members__ = ("data" , "coords" )
199+
198200 def __init__ (
199201 self ,
200202 coords ,
@@ -207,6 +209,8 @@ def __init__(
207209 fill_value = None ,
208210 idx_dtype = None ,
209211 ):
212+ from .._common import _coerce_to_supported_dense
213+
210214 if isinstance (coords , COO ):
211215 self ._make_shallow_copy_of (coords )
212216 if data is not None or shape is not None :
@@ -226,8 +230,8 @@ def __init__(
226230 self .enable_caching ()
227231 return
228232
229- self .data = np . asarray (data )
230- self .coords = np . asarray (coords )
233+ self .data = _coerce_to_supported_dense (data )
234+ self .coords = _coerce_to_supported_dense (coords )
231235
232236 if self .coords .ndim == 1 :
233237 if self .coords .size == 0 and shape is not None :
@@ -236,7 +240,7 @@ def __init__(
236240 self .coords = self .coords [None , :]
237241
238242 if self .data .ndim == 0 :
239- self .data = np .broadcast_to (self .data , self .coords .shape [1 ])
243+ self .data = self . _component_namespace .broadcast_to (self .data , self .coords .shape [1 ])
240244
241245 if self .data .ndim != 1 :
242246 raise ValueError ("`data` must be a scalar or 1-dimensional." )
@@ -251,7 +255,9 @@ def __init__(
251255 shape = tuple (shape )
252256
253257 if shape and not self .coords .size :
254- self .coords = np .zeros ((len (shape ) if isinstance (shape , Iterable ) else 1 , 0 ), dtype = np .intp )
258+ self .coords = self ._component_namespace .zeros (
259+ (len (shape ) if isinstance (shape , Iterable ) else 1 , 0 ), dtype = np .intp
260+ )
255261 super ().__init__ (shape , fill_value = fill_value )
256262 if idx_dtype :
257263 if not can_store (idx_dtype , max (shape )):
@@ -1307,10 +1313,10 @@ def _sort_indices(self):
13071313 """
13081314 linear = self .linear_loc ()
13091315
1310- if (np .diff (linear ) >= 0 ).all (): # already sorted
1316+ if (self . _component_namespace .diff (linear ) >= 0 ).all (): # already sorted
13111317 return
13121318
1313- order = np .argsort (linear , kind = "mergesort" )
1319+ order = self . _component_namespace .argsort (linear , kind = "mergesort" )
13141320 self .coords = self .coords [:, order ]
13151321 self .data = self .data [order ]
13161322
@@ -1336,16 +1342,16 @@ def _sum_duplicates(self):
13361342 # Inspired by scipy/sparse/coo.py::sum_duplicates
13371343 # See https://github.com/scipy/scipy/blob/main/LICENSE.txt
13381344 linear = self .linear_loc ()
1339- unique_mask = np .diff (linear ) != 0
1345+ unique_mask = self . _component_namespace .diff (linear ) != 0
13401346
13411347 if unique_mask .sum () == len (unique_mask ): # already unique
13421348 return
13431349
1344- unique_mask = np .append (True , unique_mask )
1350+ unique_mask = self . _component_namespace .append (True , unique_mask )
13451351
13461352 coords = self .coords [:, unique_mask ]
1347- (unique_inds ,) = np .nonzero (unique_mask )
1348- data = np .add .reduceat (self .data , unique_inds , dtype = self .data .dtype )
1353+ (unique_inds ,) = self . _component_namespace .nonzero (unique_mask )
1354+ data = self . _component_namespace .add .reduceat (self .data , unique_inds , dtype = self .data .dtype )
13491355
13501356 self .data = data
13511357 self .coords = coords
0 commit comments