Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: replacing lambda with named function to make model pickable #1594

Merged
merged 6 commits into from
Feb 28, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 22 additions & 9 deletions darts/models/forecasting/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
----------------------
"""

from typing import Optional
from typing import Callable, Optional

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -238,7 +238,7 @@ def __init__(
pd.Timestamp attributes that are relevant for the seasonality automatically.
trend
If set, indicates what kind of detrending will be applied before performing DFT.
Possible values: 'poly' or 'exp', for polynomial trend, or exponential trend, respectively.
Possible values: 'poly', 'exp' or None, for polynomial trend, exponential trend or no trend, respectively.
trend_poly_degree
The degree of the polynomial that will be used for detrending, if `trend='poly'`.

Expand Down Expand Up @@ -269,6 +269,20 @@ def __str__(self):
+ ")"
)

def _exp_trend(self, x) -> Callable:
"""Helper function, used to make FFT model pickable."""
return np.exp(self.trend_coefficients[1]) * np.exp(
self.trend_coefficients[0] * x
)

def _poly_trend(self, trend_coefficients) -> Callable:
"""Helper function, for consistency with the other trends"""
return np.poly1d(trend_coefficients)

def _null_trend(self, x) -> Callable:
"""Helper function, used to make FFT model pickable."""
return 0

def fit(self, series: TimeSeries):
series = fill_missing_values(series)
super().fit(series)
Expand All @@ -277,19 +291,18 @@ def fit(self, series: TimeSeries):

# determine trend
if self.trend == "poly":
trend_coefficients = np.polyfit(
self.trend_coefficients = np.polyfit(
range(len(series)), series.univariate_values(), self.trend_poly_degree
)
self.trend_function = np.poly1d(trend_coefficients)
self.trend_function = self._poly_trend(self.trend_coefficients)
elif self.trend == "exp":
trend_coefficients = np.polyfit(
self.trend_coefficients = np.polyfit(
range(len(series)), np.log(series.univariate_values()), 1
)
self.trend_function = lambda x: np.exp(trend_coefficients[1]) * np.exp(
trend_coefficients[0] * x
)
self.trend_function = self._exp_trend
else:
self.trend_function = lambda x: 0
self.trend_coefficients = None
self.trend_function = self._null_trend

# subtract trend
detrended_values = series.univariate_values() - self.trend_function(
Expand Down