def register_plugin_function(
    *,
    plugin_path: Path | str,
    function_name: str,
    args: IntoExpr | Iterable[IntoExpr],
    kwargs: dict[str, Any] | None = None,
    is_elementwise: bool = False,
    returns_scalar: bool = False,
    cast_to_supertype: bool = False,
):
    """Shim with the ``polars.plugins.register_plugin_function`` signature
    for polars versions that only provide ``Expr.register_plugin``.

    The keyword names follow the new-style API and are translated to the
    legacy ones (``plugin_path`` -> ``lib``, ``function_name`` -> ``symbol``,
    ``cast_to_supertype`` -> ``cast_to_supertypes``).

    Parameters
    ----------
    plugin_path
        Location of the compiled plugin library.
    function_name
        Name of the plugin symbol to call.
    args
        A single expression, or an iterable of inputs. The first item must be
        an expression exposing ``register_plugin``; the plugin is registered
        on it and the remaining items are forwarded as extra inputs.
    kwargs
        Keyword arguments forwarded to the plugin function.
    is_elementwise, returns_scalar, cast_to_supertype
        Flags forwarded to the legacy API.

    Returns
    -------
    The expression produced by ``Expr.register_plugin``.

    Raises
    ------
    IndexError
        If ``args`` is empty.
    """
    # A lone expression is accepted as well as any iterable (list, tuple,
    # generator, ...); materialise it so it can be split safely.
    if hasattr(args, "register_plugin"):
        inputs = [args]
    else:
        inputs = list(args)
    expr = inputs[0]
    extra_inputs = inputs[1:]
    # The result MUST be returned: callers use it as the output expression.
    return expr.register_plugin(
        # Legacy ``register_plugin`` declares ``lib`` as ``str``; callers may
        # hand over a ``pathlib.Path``.
        lib=str(plugin_path),
        args=extra_inputs,
        symbol=function_name,
        is_elementwise=is_elementwise,
        returns_scalar=returns_scalar,
        kwargs=kwargs,
        cast_to_supertypes=cast_to_supertype,
    )
import find_peaks_cwt, ricker, welch from scipy.spatial import KDTree +from functime._compat import register_plugin_function, rle_fields from functime._functime_rust import rs_faer_lstsq1 from functime._utils import warn_is_unstable from functime.type_aliases import DetrendMethod @@ -34,7 +35,12 @@ # from polars.type_aliases import IntoExpr -lib = _get_shared_lib_location(__file__) +try: + from polars.utils.udfs import _get_shared_lib_location + + lib = _get_shared_lib_location(__file__) +except ImportError: + lib = Path(__file__).parent def absolute_energy(x: TIME_SERIES_T) -> FLOAT_INT_EXPR: @@ -995,12 +1001,16 @@ def longest_streak_above_mean(x: TIME_SERIES_T) -> INT_EXPR: """ y = (x > x.mean()).rle() if isinstance(x, pl.Series): - result = y.filter(y.struct.field("values")).struct.field("lengths").max() + result = ( + y.filter(y.struct.field(rle_fields["value"])) + .struct.field(rle_fields["len"]) + .max() + ) return 0 if result is None else result else: return ( - y.filter(y.struct.field("values")) - .struct.field("lengths") + y.filter(y.struct.field(rle_fields["value"])) + .struct.field(rle_fields["len"]) .max() .fill_null(0) ) @@ -1024,12 +1034,16 @@ def longest_streak_below_mean(x: TIME_SERIES_T) -> INT_EXPR: """ y = (x < x.mean()).rle() if isinstance(x, pl.Series): - result = y.filter(y.struct.field("values")).struct.field("lengths").max() + result = ( + y.filter(y.struct.field(rle_fields["value"])) + .struct.field(rle_fields["len"]) + .max() + ) return 0 if result is None else result else: return ( - y.filter(y.struct.field("values")) - .struct.field("lengths") + y.filter(y.struct.field(rle_fields["value"])) + .struct.field(rle_fields["len"]) .max() .fill_null(0) ) @@ -1752,7 +1766,7 @@ def streak_length_stats(x: TIME_SERIES_T, above: bool, threshold: float) -> MAP_ else: y = (x.diff() <= threshold).rle() - y = y.filter(y.struct.field("values")).struct.field("lengths") + y = y.filter(y.struct.field(rle_fields["value"])).struct.field(rle_fields["len"]) if 
isinstance(x, pl.Series): return { "min": y.min() or 0, @@ -1797,12 +1811,16 @@ def longest_streak_above(x: TIME_SERIES_T, threshold: float) -> TIME_SERIES_T: y = (x.diff() >= threshold).rle() if isinstance(x, pl.Series): - streak_max = y.filter(y.struct.field("values")).struct.field("lengths").max() + streak_max = ( + y.filter(y.struct.field(rle_fields["value"])) + .struct.field(rle_fields["len"]) + .max() + ) return 0 if streak_max is None else streak_max else: return ( - y.filter(y.struct.field("values")) - .struct.field("lengths") + y.filter(y.struct.field(rle_fields["value"])) + .struct.field(rle_fields["len"]) .max() .fill_null(0) ) @@ -1827,12 +1845,16 @@ def longest_streak_below(x: TIME_SERIES_T, threshold: float) -> TIME_SERIES_T: """ y = (x.diff() <= threshold).rle() if isinstance(x, pl.Series): - streak_max = y.filter(y.struct.field("values")).struct.field("lengths").max() + streak_max = ( + y.filter(y.struct.field(rle_fields["value"])) + .struct.field(rle_fields["len"]) + .max() + ) return 0 if streak_max is None else streak_max else: return ( - y.filter(y.struct.field("values")) - .struct.field("lengths") + y.filter(y.struct.field(rle_fields["value"])) + .struct.field(rle_fields["len"]) .max() .fill_null(0) ) @@ -2255,9 +2277,10 @@ def lempel_ziv_complexity( https://github.com/Naereen/Lempel-Ziv_Complexity/tree/master https://en.wikipedia.org/wiki/Lempel%E2%80%93Ziv_complexity """ - out = (self._expr > threshold).register_plugin( - lib=lib, - symbol="pl_lempel_ziv_complexity", + out = register_plugin_function( + args=[self._expr > threshold], + plugin_path=lib, + function_name="pl_lempel_ziv_complexity", is_elementwise=False, returns_scalar=True, ) @@ -2766,16 +2789,17 @@ def cusum( ------- An expression of the output """ - return self._expr.register_plugin( - lib=lib, - symbol="cusum", + return register_plugin_function( + args=[self._expr], + plugin_path=lib, + function_name="cusum", kwargs={ "threshold": threshold, "drift": drift, "warmup_period": 
warmup_period, }, is_elementwise=False, - cast_to_supertypes=True, + cast_to_supertype=True, ) def frac_diff( @@ -2815,14 +2839,15 @@ def frac_diff( if min_weight is None and window_size is None: raise ValueError("Either min_weight or window_size must be specified.") - return self._expr.register_plugin( - lib=lib, - symbol="frac_diff", + return register_plugin_function( + args=[self._expr], + plugin_path=lib, + function_name="frac_diff", kwargs={ "d": d, "min_weight": min_weight, "window_size": window_size, }, is_elementwise=False, - cast_to_supertypes=True, + cast_to_supertype=True, ) diff --git a/functime/forecasting/snaive.py b/functime/forecasting/snaive.py index 7053b271..9f8175d8 100644 --- a/functime/forecasting/snaive.py +++ b/functime/forecasting/snaive.py @@ -30,12 +30,7 @@ def _fit(self, y: pl.LazyFrame, X: Optional[pl.LazyFrame] = None): sp = self.sp # BUG: Cannot run the following in lazy streaming mode? # Causes internal error: entered unreachable code - y_pred = ( - y.sort(idx_cols) - .set_sorted(idx_cols) - .group_by(entity_col) - .agg(pl.col(target_col).tail(sp)) - ) + y_pred = y.sort(idx_cols).group_by(entity_col).agg(pl.col(target_col).tail(sp)) artifacts = {"y_pred": y_pred} return artifacts diff --git a/functime/preprocessing.py b/functime/preprocessing.py index dbf032dc..90f9d34f 100644 --- a/functime/preprocessing.py +++ b/functime/preprocessing.py @@ -602,9 +602,10 @@ def optimizer(fun): lmbds = gb.agg( PL_NUMERIC_COLS(entity_col, time_col) .map_elements( - lambda x: boxcox_normmax(x, method=method, optimizer=optimizer) + lambda x: boxcox_normmax(x, method=method, optimizer=optimizer), + returns_scalar=True, + return_dtype=pl.Float64, ) - .cast(pl.Float64()) .name.suffix("__lmbd") ) # Step 2. 
def load_panel_data(path: str) -> pl.LazyFrame:
    """Read a parquet panel file and normalise its dtypes.

    ``series`` has its first space removed and is cast to categorical,
    ``time`` is cast to ``Int16`` and every other column to ``Float32``;
    the result is sorted by ``["series", "time"]``.
    """
    # NOTE(review): ``pl.read_parquet`` is the eager reader, so this looks
    # like it returns a DataFrame despite the annotation - confirm with callers.
    index_cols = ["series", "time"]
    casts = [
        pl.col("series").str.replace(" ", "").cast(pl.Categorical),
        pl.col("time").cast(pl.Int16),
        pl.all().exclude(index_cols).cast(pl.Float32),
    ]
    panel = pl.read_parquet(path)
    return panel.with_columns(casts).sort(index_cols)
pl.concat( diff --git a/tests/test_tsfresh.py b/tests/test_tsfresh.py index 924604ef..30693229 100644 --- a/tests/test_tsfresh.py +++ b/tests/test_tsfresh.py @@ -74,19 +74,19 @@ ( [0, 0, 0], [True, 0], - [2, 2, 2.0, None, 2.0, 2.0, 2.0, 2], + [2.0, 2, 2.0, None, 2.0, 2.0, 2.0, 2], {"check_dtype": False}, ), ( [0, 0, 0], [False, 0], - [2, 2, 2.0, None, 2.0, 2.0, 2.0, 2], + [2.0, 2, 2.0, None, 2.0, 2.0, 2.0, 2], {"check_dtype": False}, ), ( [0, 0, 0], [False, 1], - [2, 2, 2.0, None, 2.0, 2.0, 2.0, 2], + [2.0, 2, 2.0, None, 2.0, 2.0, 2.0, 2], {"check_dtype": False}, ), # won't work with no matches - error @@ -94,27 +94,27 @@ ( [0, 0, 0], [True, 1], - [0, None, None, None, None, None, None, None], + [0.0, None, None, None, None, None, None, None], {"check_dtype": False}, ), ( [0, 1, 1, 0, 2, 2, 2], [True, 0], - [2, 3, 2.5, 0.707107, 2.0, 2.5, 3.0, 2], + [2.0, 3, 2.5, 0.707107, 2.0, 2.5, 3.0, 2], {"check_dtype": False}, ), # # floats ( - [0, 1.5, 1.5, 0, 2.5, 2.5, 2.5], + [0.0, 1.5, 1.5, 0, 2.5, 2.5, 2.5], [True, 0], - [2, 3, 2.5, 0.707107, 2.0, 2.5, 3.0, 2], + [2.0, 3, 2.5, 0.707107, 2.0, 2.5, 3.0, 2], {"check_dtype": False}, ), # # negative floats ( - [0, -1.5, -1.5, 0, -2.5, -2.5, -2.5], + [0.0, -1.5, -1.5, 0, -2.5, -2.5, -2.5], [False, 0], - [2, 3, 2.5, 0.707107, 2.0, 2.5, 3.0, 2], + [2.0, 3, 2.5, 0.707107, 2.0, 2.5, 3.0, 2], {"check_dtype": False}, ), ], @@ -167,12 +167,12 @@ def test_streak_length_stats(S, params, res, k): ([0, 0, 0], [0], {}), ([0, 1, 2], [1], {}), ([2, 1, 0], [1], {}), - ([0, 1.5, 2, 2.5], [5 / 6], {}), + ([0.0, 1.5, 2, 2.5], [5 / 6], {}), ([2.5, 2, 1.5, 0], [5 / 6], {}), ([-1, 2, 3, 4], [5 / 3], {}), # # # this is a tough call, can potentially ask about this. 
- ([-1, 1, 2, float("inf")], [float("inf")], {}), - ([-1, 1, 2, -float("inf")], [float("inf")], {}), + ([-1.0, 1, 2, float("inf")], [float("inf")], {}), + ([-1.0, 1, 2, -float("inf")], [float("inf")], {}), ([float("inf"), -1, 1, 2], [float("inf")], {}), ], ) @@ -199,12 +199,12 @@ def test_mean_abs_change(S, res, k): [ ([0, 0, 0], [0], {}), ([0, 1, 2], [1], {}), - ([0, 1.5, 2, 2.5], [5 / 6], {}), + ([0.0, 1.5, 2, 2.5], [5 / 6], {}), ([2.5, 2, 1.5, 0], [-5 / 6], {}), ([-1, 2, 3, 4], [5 / 3], {}), - ([-1, 1.3, 5.3, 4.5], [11 / 6], {}), - ([-1, 1, 2, float("inf")], [float("inf")], {}), - ([-1, 1, 2, -float("inf")], [-float("inf")], {}), + ([-1.0, 1.3, 5.3, 4.5], [11 / 6], {}), + ([-1.0, 1, 2, float("inf")], [float("inf")], {}), + ([-1.0, 1, 2, -float("inf")], [-float("inf")], {}), ([1], [0], {"check_dtype": False}), ([], [0], {"check_dtype": False}), ], @@ -231,9 +231,9 @@ def test_mean_change(S, res, k): [ ([0, 0, 0], [False]), ([0, 1, 2], [True]), - ([0, 1.5, 2, 2.5, 50], [True]), + ([0.0, 1.5, 2, 2.5, 50], [True]), ([-1, 2, 3, 4], [True]), - ([-1, 1.3, 5.3, 4.5], [True]), + ([-1.0, 1.3, 5.3, 4.5], [True]), ], ) def test_var_gt_std(S, res): @@ -255,9 +255,9 @@ def test_var_gt_std(S, res): [ ([0, 0, 0], [False]), ([0, 1, 2], [True]), - ([0, 1.5, 2, 2.5, 50], [True]), + ([0.0, 1.5, 2, 2.5, 50], [True]), ([-1, 2, 3, 4], [True]), - ([-1, 1.3, 5.3, 4.5], [True]), + ([-1.0, 1.3, 5.3, 4.5], [True]), ], ) def test_large_standard_deviation(S, res): @@ -280,7 +280,7 @@ def test_large_standard_deviation(S, res): ([0, 1, 2], [0.816497]), ([9, 7, 10000], [1.410825]), ([-1, 2, 3, 4], [0.93541434]), - ([-1, 1.3, 5.3, 4.5], [1.00049]), + ([-1.0, 1.3, 5.3, 4.5], [1.00049]), ], ) def test_variation_coefficient(S, res): @@ -303,8 +303,8 @@ def test_variation_coefficient(S, res): ([-5, 0, 1], [2]), ([0], [1]), ([-1, 2, 3, 4], [3]), - ([-1, 1.3], [1]), - ([1, float("inf")], [1]), + ([-1.0, 1.3], [1]), + ([1.0, float("inf")], [1]), ([1, None], [1]), ], ) @@ -388,7 +388,7 @@ def 
def test_longest_streak_above_mean(S, res):
    """Check ``longest_streak_above_mean`` on eager and lazy frames.

    The feature's *output* is renamed to ``"len"`` and cast to ``UInt32`` so
    the comparison does not depend on the field name / dtype that the
    installed polars version uses for ``rle`` lengths.
    """
    assert_frame_equal(
        pl.DataFrame({"a": S}).select(
            # alias/cast belong on the result of the feature, not on the
            # input column (casting the input would alter the data under test).
            longest_streak_above_mean(pl.col("a")).alias("len").cast(pl.UInt32)
        ),
        pl.DataFrame(pl.Series("len", res, dtype=pl.UInt32)),
    )
    assert_frame_equal(
        pl.LazyFrame({"a": S})
        .select(longest_streak_above_mean(pl.col("a")).alias("len").cast(pl.UInt32))
        .collect(),
        pl.DataFrame(pl.Series("len", res, dtype=pl.UInt32)),
    )
[0]), ([1.111, -2.45, 1.111, 2.45], [1.0 / 3.0]), ], @@ -1145,7 +1145,7 @@ def test_percent_reoccuring_values(S, res): "S, res", [ ([1, 1, 2, 3, 4, 4], [10]), - ([1, 1.5, 2, 3], [0.0]), + ([1.0, 1.5, 2, 3], [0.0]), ([1], [0]), ([1.111, -2.45, 1.111, 2.45], [2.222]), ], @@ -1165,7 +1165,7 @@ def test_sum_reoccurring_points(S, res): "S, res", [ ([1, 1, 2, 3, 4, 4], [5]), - ([1, 1.5, 2, 3], [0.0]), + ([1.0, 1.5, 2, 3], [0.0]), ([1], [0]), ([1.111, -2.45, 1.111, 2.45], [1.111]), ], @@ -1185,7 +1185,7 @@ def test_sum_reoccurring_values(S, res): "S, res", [ ([1, 1, 2, 3, 4], [0.4]), - ([1, 1.5, 2, 3], [0]), + ([1.0, 1.5, 2, 3], [0]), ([1], [0]), ([1.111, -2.45, 1.111, 2.45], [0.5]), ([], [np.nan]), @@ -1574,7 +1574,7 @@ def test_range_over_mean_and_range(S, res): ([10, 20, 20, 30], [0], 0), ([10, 20, 20, 30], [1], 15), ([10, -10, 10, -10], [3], 0), - ([-10, 10.1, -10, 10.1, -10], [4], 10), + ([-10.0, 10.1, -10, 10.1, -10], [4], 10), ([10, 11, 12, 10, 11], [3], 10.5), ], )