Skip to content

Commit

Permalink
Fix upcast on in-place. (#598)
Browse files Browse the repository at this point in the history
  • Loading branch information
hameerabbasi authored Jul 1, 2023
1 parent fb4da47 commit 096f9b0
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 0 deletions.
11 changes: 11 additions & 0 deletions sparse/_sparse_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,17 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
)

if out is not None:
test_args = [
np.empty(1, dtype=a.dtype) if hasattr(a, "dtype") else [a]
for a in inputs
]
test_kwargs = kwargs.copy()
if method == "reduce":
test_kwargs["axis"] = None
test_out = tuple(np.empty(1, dtype=a.dtype) for a in out)
if len(test_out) == 1:
test_out = test_out[0]
getattr(ufunc, method)(*test_args, out=test_out, **test_kwargs)
kwargs["dtype"] = out[0].dtype

if method == "outer":
Expand Down
7 changes: 7 additions & 0 deletions sparse/tests/test_elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,3 +720,10 @@ def test_no_deprecation_warning():
a = np.array([1, 2])
s = sparse.COO(a, a, shape=(3,))
s == s


# Regression test for gh-587
def test_no_out_upcast():
a = sparse.COO([[0, 1], [0, 1]], [1, 1], shape=(2, 2))
with pytest.raises(TypeError):
a *= 0.5

0 comments on commit 096f9b0

Please sign in to comment.