Skip to content

Commit

Permalink
fix: pickle compatible implementation of downsamplers (#44)
Browse files Browse the repository at this point in the history
* 🔥 first pickle compatible implementation of downsamplers

* 🖊️ review code

* 🙈 fix linting issue

---------

Co-authored-by: jvdd <boebievdd@gmail.com>
  • Loading branch information
jonasvdd and jvdd authored Apr 14, 2023
1 parent ae60a28 commit 9297e50
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 18 deletions.
15 changes: 14 additions & 1 deletion tests/test_tsdownsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def generate_all_downsamplers() -> Iterable[AbstractDownsampler]:


@pytest.mark.parametrize("downsampler", generate_all_downsamplers())
def test_serialization(downsampler: AbstractDownsampler):
def test_serialization_copy(downsampler: AbstractDownsampler):
"""Test serialization."""
from copy import copy, deepcopy

Expand All @@ -53,6 +53,19 @@ def test_serialization(downsampler: AbstractDownsampler):
assert np.all(orig_downsampled == ddc_downsampled)


@pytest.mark.parametrize("downsampler", generate_all_downsamplers())
def test_serialization_pickle(downsampler: AbstractDownsampler):
"""Test serialization."""
import pickle

dc = pickle.loads(pickle.dumps(downsampler))

arr = np.arange(10_000)
orig_downsampled = downsampler.downsample(arr, n_out=100)
dc_downsampled = dc.downsample(arr, n_out=100)
assert np.all(orig_downsampled == dc_downsampled)


@pytest.mark.parametrize("downsampler", generate_rust_downsamplers())
def test_rust_downsampler(downsampler: AbstractDownsampler):
"""Test the Rust downsamplers."""
Expand Down
20 changes: 12 additions & 8 deletions tsdownsample/downsamplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@


class MinMaxDownsampler(AbstractRustDownsampler):
def __init__(self) -> None:
super().__init__(_tsdownsample_rs.minmax)
@property
def rust_mod(self):
return _tsdownsample_rs.minmax

@staticmethod
def _check_valid_n_out(n_out: int):
Expand All @@ -21,8 +22,9 @@ def _check_valid_n_out(n_out: int):


class M4Downsampler(AbstractRustDownsampler):
def __init__(self):
super().__init__(_tsdownsample_rs.m4)
@property
def rust_mod(self):
return _tsdownsample_rs.m4

@staticmethod
def _check_valid_n_out(n_out: int):
Expand All @@ -32,13 +34,15 @@ def _check_valid_n_out(n_out: int):


class LTTBDownsampler(AbstractRustDownsampler):
def __init__(self):
super().__init__(_tsdownsample_rs.lttb)
@property
def rust_mod(self):
return _tsdownsample_rs.lttb


class MinMaxLTTBDownsampler(AbstractRustDownsampler):
def __init__(self):
super().__init__(_tsdownsample_rs.minmaxlttb)
@property
def rust_mod(self):
return _tsdownsample_rs.minmaxlttb

def downsample(
self, *args, n_out: int, minmax_ratio: int = 30, parallel: bool = False, **_
Expand Down
41 changes: 32 additions & 9 deletions tsdownsample/downsampling_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,24 +143,47 @@ def downsample(self, *args, n_out: int, **kwargs): # x and y are optional
class AbstractRustDownsampler(AbstractDownsampler, ABC):
"""RustDownsampler interface-class, subclassed by concrete downsamplers."""

def __init__(self, resampling_mod: ModuleType):
def __init__(self):
super().__init__(_rust_dtypes, _y_rust_dtypes) # same for x and y
self.rust_mod = resampling_mod

# Store the single core sub module
self.mod_single_core = self.rust_mod.scalar
@property
def rust_mod(self) -> ModuleType:
"""The compiled Rust module for the current downsampler."""
raise NotImplementedError

@property
def mod_single_core(self) -> ModuleType:
"""Get the single-core Rust module.
Returns
-------
ModuleType
If SIMD compiled module is available, that one is returned. Otherwise, the
scalar compiled module is returned.
"""
if hasattr(self.rust_mod, "simd"):
# use SIMD implementation if available
self.mod_single_core = self.rust_mod.simd
return self.rust_mod.simd
return self.rust_mod.scalar

@property
def mod_multi_core(self) -> Union[ModuleType, None]:
"""Get the multi-core Rust module.
# Store the multi-core sub module (if present)
self.mod_multi_core = None # no multi-core implementation (default)
Returns
-------
ModuleType or None
If SIMD parallel compiled module is available, that one is returned.
Otherwise, the scalar parallel compiled module is returned.
If no parallel compiled module is available, None is returned.
"""
if hasattr(self.rust_mod, "simd_parallel"):
# use SIMD implementation if available
self.mod_multi_core = self.rust_mod.simd_parallel
return self.rust_mod.simd_parallel
elif hasattr(self.rust_mod, "scalar_parallel"):
# use scalar implementation if available (when no SIMD available)
self.mod_multi_core = self.rust_mod.scalar_parallel
return self.rust_mod.scalar_parallel
return None # no parallel compiled module available

@staticmethod
def _switch_mod_with_y(
Expand Down

0 comments on commit 9297e50

Please sign in to comment.