From 790f6d5aac9156b6a8278a1f6e2b8594196be0da Mon Sep 17 00:00:00 2001 From: James Webber Date: Fri, 5 Jan 2024 10:23:40 -0500 Subject: [PATCH] Preserve gcxs compression (#601) --- sparse/_umath.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/sparse/_umath.py b/sparse/_umath.py index 15c6d3e9..6675d309 100644 --- a/sparse/_umath.py +++ b/sparse/_umath.py @@ -412,6 +412,7 @@ def __init__(self, func, *args, **kwargs): processed_args = [] out_type = GCXS + out_kwargs = {} sparse_args = [arg for arg in args if isinstance(arg, SparseArray)] @@ -421,6 +422,8 @@ def __init__(self, func, *args, **kwargs): out_type = DOK elif all(isinstance(arg, GCXS) for arg in sparse_args): out_type = GCXS + if len({arg.compressed_axes for arg in sparse_args}) == 1: + out_kwargs["compressed_axes"] = sparse_args[0].compressed_axes else: out_type = COO @@ -441,6 +444,7 @@ def __init__(self, func, *args, **kwargs): return self.out_type = out_type + self.out_kwargs = out_kwargs self.args = tuple(processed_args) self.func = func self.dtype = kwargs.pop("dtype", None) @@ -497,7 +501,7 @@ def get_result(self): shape=self.shape, has_duplicates=False, fill_value=self.fill_value, - ).asformat(self.out_type) + ).asformat(self.out_type, **self.out_kwargs) def _get_fill_value(self): """