Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Polars >=1.0 fixes. #256

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ faer-ext = {version = "0.1.0", features = ["ndarray"]}
ndarray = "0.15.6"
serde = {version = "*", features=["derive"]}
hashbrown = {version = "0.14.2", features=["nightly"]}
numpy = "*"
numpy = "*"
42 changes: 42 additions & 0 deletions functime/_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from __future__ import annotations

from pathlib import Path
from typing import Any, Iterable, Sequence

import polars as pl

try:
    # The major version gates which plugin-registration API exists:
    # polars >=1.0 exposes `polars.plugins.register_plugin_function`,
    # older versions only have the deprecated `Expr.register_plugin`.
    POLARS_MAJOR_VERSION = int(pl.__version__.split(".", 1)[0])
except ValueError:
    # Unparseable (e.g. dev/local) version strings: assume the legacy API.
    POLARS_MAJOR_VERSION = 0

if POLARS_MAJOR_VERSION >= 1:
    from polars.plugins import register_plugin_function

    # polars >=1.0 renamed the struct fields produced by `.rle()`.
    rle_fields = {"value": "value", "len": "len"}

else:

    def register_plugin_function(
        *,
        plugin_path: Path | str,
        function_name: str,
        args: Sequence[Any],
        kwargs: dict[str, Any] | None = None,
        is_elementwise: bool = False,
        returns_scalar: bool = False,
        cast_to_supertype: bool = False,
    ):
        """Compatibility shim mirroring polars >=1.0 ``register_plugin_function``.

        Delegates to the legacy ``Expr.register_plugin`` API available on
        polars <1.0.

        Parameters
        ----------
        plugin_path
            Location of the compiled plugin library.
        function_name
            Symbol name of the plugin function to call.
        args
            Expressions passed to the plugin; ``args[0]`` must be the
            expression the plugin is registered on, the rest are forwarded
            as plugin arguments.
        kwargs
            Optional keyword arguments serialized to the plugin.
        is_elementwise, returns_scalar, cast_to_supertype
            Forwarded to ``Expr.register_plugin`` (note the legacy keyword
            is spelled ``cast_to_supertypes``, plural).

        Returns
        -------
        pl.Expr
            The expression produced by the plugin registration.
        """
        expr = args[0]
        remaining_args = args[1:]
        # The result MUST be returned: callers assign it, e.g.
        # `out = register_plugin_function(...)` in feature_extractors.py.
        return expr.register_plugin(
            lib=plugin_path,
            args=remaining_args,
            symbol=function_name,
            is_elementwise=is_elementwise,
            returns_scalar=returns_scalar,
            kwargs=kwargs,
            cast_to_supertypes=cast_to_supertype,
        )

    # Pre-1.0 `.rle()` used plural field names.
    rle_fields = {"value": "values", "len": "lengths"}
77 changes: 51 additions & 26 deletions functime/feature_extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,19 @@

import logging
import math
from pathlib import Path
from typing import List, Mapping, Optional, Sequence, Union

import numpy as np
import polars as pl
from polars.type_aliases import ClosedInterval
from polars.utils.udfs import _get_shared_lib_location

# from numpy.linalg import lstsq
from scipy.linalg import lstsq
from scipy.signal import find_peaks_cwt, ricker, welch
from scipy.spatial import KDTree

from functime._compat import register_plugin_function, rle_fields
from functime._functime_rust import rs_faer_lstsq1
from functime._utils import warn_is_unstable
from functime.type_aliases import DetrendMethod
Expand All @@ -34,7 +35,12 @@

# from polars.type_aliases import IntoExpr

lib = _get_shared_lib_location(__file__)
try:
from polars.utils.udfs import _get_shared_lib_location

lib = _get_shared_lib_location(__file__)
except ImportError:
lib = Path(__file__).parent


def absolute_energy(x: TIME_SERIES_T) -> FLOAT_INT_EXPR:
Expand Down Expand Up @@ -995,12 +1001,16 @@ def longest_streak_above_mean(x: TIME_SERIES_T) -> INT_EXPR:
"""
y = (x > x.mean()).rle()
if isinstance(x, pl.Series):
result = y.filter(y.struct.field("values")).struct.field("lengths").max()
result = (
y.filter(y.struct.field(rle_fields["value"]))
.struct.field(rle_fields["len"])
.max()
)
return 0 if result is None else result
else:
return (
y.filter(y.struct.field("values"))
.struct.field("lengths")
y.filter(y.struct.field(rle_fields["value"]))
.struct.field(rle_fields["len"])
.max()
.fill_null(0)
)
Expand All @@ -1024,12 +1034,16 @@ def longest_streak_below_mean(x: TIME_SERIES_T) -> INT_EXPR:
"""
y = (x < x.mean()).rle()
if isinstance(x, pl.Series):
result = y.filter(y.struct.field("values")).struct.field("lengths").max()
result = (
y.filter(y.struct.field(rle_fields["value"]))
.struct.field(rle_fields["len"])
.max()
)
return 0 if result is None else result
else:
return (
y.filter(y.struct.field("values"))
.struct.field("lengths")
y.filter(y.struct.field(rle_fields["value"]))
.struct.field(rle_fields["len"])
.max()
.fill_null(0)
)
Expand Down Expand Up @@ -1752,7 +1766,7 @@ def streak_length_stats(x: TIME_SERIES_T, above: bool, threshold: float) -> MAP_
else:
y = (x.diff() <= threshold).rle()

y = y.filter(y.struct.field("values")).struct.field("lengths")
y = y.filter(y.struct.field(rle_fields["value"])).struct.field(rle_fields["len"])
if isinstance(x, pl.Series):
return {
"min": y.min() or 0,
Expand Down Expand Up @@ -1797,12 +1811,16 @@ def longest_streak_above(x: TIME_SERIES_T, threshold: float) -> TIME_SERIES_T:

y = (x.diff() >= threshold).rle()
if isinstance(x, pl.Series):
streak_max = y.filter(y.struct.field("values")).struct.field("lengths").max()
streak_max = (
y.filter(y.struct.field(rle_fields["value"]))
.struct.field(rle_fields["len"])
.max()
)
return 0 if streak_max is None else streak_max
else:
return (
y.filter(y.struct.field("values"))
.struct.field("lengths")
y.filter(y.struct.field(rle_fields["value"]))
.struct.field(rle_fields["len"])
.max()
.fill_null(0)
)
Expand All @@ -1827,12 +1845,16 @@ def longest_streak_below(x: TIME_SERIES_T, threshold: float) -> TIME_SERIES_T:
"""
y = (x.diff() <= threshold).rle()
if isinstance(x, pl.Series):
streak_max = y.filter(y.struct.field("values")).struct.field("lengths").max()
streak_max = (
y.filter(y.struct.field(rle_fields["value"]))
.struct.field(rle_fields["len"])
.max()
)
return 0 if streak_max is None else streak_max
else:
return (
y.filter(y.struct.field("values"))
.struct.field("lengths")
y.filter(y.struct.field(rle_fields["value"]))
.struct.field(rle_fields["len"])
.max()
.fill_null(0)
)
Expand Down Expand Up @@ -2255,9 +2277,10 @@ def lempel_ziv_complexity(
https://github.com/Naereen/Lempel-Ziv_Complexity/tree/master
https://en.wikipedia.org/wiki/Lempel%E2%80%93Ziv_complexity
"""
out = (self._expr > threshold).register_plugin(
lib=lib,
symbol="pl_lempel_ziv_complexity",
out = register_plugin_function(
args=[self._expr > threshold],
plugin_path=lib,
function_name="pl_lempel_ziv_complexity",
is_elementwise=False,
returns_scalar=True,
)
Expand Down Expand Up @@ -2766,16 +2789,17 @@ def cusum(
-------
An expression of the output
"""
return self._expr.register_plugin(
lib=lib,
symbol="cusum",
return register_plugin_function(
args=[self._expr],
plugin_path=lib,
function_name="cusum",
kwargs={
"threshold": threshold,
"drift": drift,
"warmup_period": warmup_period,
},
is_elementwise=False,
cast_to_supertypes=True,
cast_to_supertype=True,
)

def frac_diff(
Expand Down Expand Up @@ -2815,14 +2839,15 @@ def frac_diff(
if min_weight is None and window_size is None:
raise ValueError("Either min_weight or window_size must be specified.")

return self._expr.register_plugin(
lib=lib,
symbol="frac_diff",
return register_plugin_function(
args=[self._expr],
plugin_path=lib,
function_name="frac_diff",
kwargs={
"d": d,
"min_weight": min_weight,
"window_size": window_size,
},
is_elementwise=False,
cast_to_supertypes=True,
cast_to_supertype=True,
)
7 changes: 1 addition & 6 deletions functime/forecasting/snaive.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,7 @@ def _fit(self, y: pl.LazyFrame, X: Optional[pl.LazyFrame] = None):
sp = self.sp
# BUG: Cannot run the following in lazy streaming mode?
# Causes internal error: entered unreachable code
y_pred = (
y.sort(idx_cols)
.set_sorted(idx_cols)
.group_by(entity_col)
.agg(pl.col(target_col).tail(sp))
)
y_pred = y.sort(idx_cols).group_by(entity_col).agg(pl.col(target_col).tail(sp))
artifacts = {"y_pred": y_pred}
return artifacts

Expand Down
6 changes: 4 additions & 2 deletions functime/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,9 +602,10 @@ def optimizer(fun):
lmbds = gb.agg(
PL_NUMERIC_COLS(entity_col, time_col)
.map_elements(
lambda x: boxcox_normmax(x, method=method, optimizer=optimizer)
lambda x: boxcox_normmax(x, method=method, optimizer=optimizer),
returns_scalar=True,
return_dtype=pl.Float64,
)
.cast(pl.Float64())
.name.suffix("__lmbd")
)
# Step 2. Transform
Expand Down Expand Up @@ -667,6 +668,7 @@ def transform(X: pl.LazyFrame) -> pl.LazyFrame:
PL_NUMERIC_COLS(entity_col, time_col)
.map_elements(
lambda x: yeojohnson_normmax(x.to_numpy(), brack),
returns_scalar=True,
return_dtype=pl.Float64,
)
.name.suffix("__lmbd")
Expand Down
14 changes: 4 additions & 10 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,18 +87,12 @@ def m4_dataset(request):
def load_panel_data(path: str) -> pl.LazyFrame:
return (
pl.read_parquet(path)
.pipe(
lambda df: df.select(
[
pl.col("series").cast(pl.Categorical),
pl.col("time").cast(pl.Int16),
pl.col(df.columns[2]).cast(pl.Float32),
]
)
.with_columns(
pl.col("series").str.replace(" ", "").cast(pl.Categorical),
pl.col("time").cast(pl.Int16),
pl.all().exclude(["series", "time"]).cast(pl.Float32),
)
.with_columns(pl.col("series").str.replace(" ", ""))
.sort(["series", "time"])
.set_sorted(["series", "time"])
)

def update_test_time_ranges(y_train, y_test):
Expand Down
24 changes: 19 additions & 5 deletions tests/test_fourier.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,25 @@
from functime.seasonality import add_fourier_terms


@pytest.mark.parametrize("freq,sp", [("1h", 24), ("1d", 365), ("1w", 52)])
def test_fourier_with_dates(freq: str, sp: int):
timestamps = pl.date_range(
date(2020, 1, 1), date(2021, 1, 1), interval=freq, eager=True
)
@pytest.mark.parametrize(
"freq,sp, use_date",
[
("1h", 24, False),
("1d", 365, False),
("1w", 52, False),
("1d", 365, True),
("1w", 52, True),
],
)
def test_fourier_with_timestamps(freq: str, sp: int, use_date: bool):
if use_date:
timestamps = pl.date_range(
date(2020, 1, 1), date(2021, 1, 1), interval=freq, eager=True
)
else:
timestamps = pl.datetime_range(
date(2020, 1, 1), date(2021, 1, 1), interval=freq, eager=True
)
n_timestamps = len(timestamps)
idx_timestamps = timestamps.arg_sort() + 1
entities = pl.concat(
Expand Down
Loading