diff --git a/.gitignore b/.gitignore index 625945d..c0ba214 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ examples/personal notebooks/* .vscode/* poetry.lock +x.py # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/docs/deprecation.md b/docs/deprecation.md new file mode 100644 index 0000000..fccd4f4 --- /dev/null +++ b/docs/deprecation.md @@ -0,0 +1 @@ +# Deprecation policy \ No newline at end of file diff --git a/extension_templates/effect.py b/extension_templates/effect.py index 0e21a8a..f08f8e9 100644 --- a/extension_templates/effect.py +++ b/extension_templates/effect.py @@ -5,7 +5,7 @@ import jax.numpy as jnp import pandas as pd -from prophetverse.effects.base import BaseEffect, Stage +from prophetverse.effects.base import BaseEffect from prophetverse.utils.frame_to_array import series_to_tensor_or_array @@ -19,13 +19,16 @@ class MyEffectName(BaseEffect): # If no columns are found, should # _predict be skipped? "skip_predict_if_no_match": True, + # Should only the indexes related to the forecasting horizon be passed to + # _transform? + "filter_indexes_with_forecating_horizon_at_transform": True, } def __init__(self, param1: Any, param2: Any): self.param1 = param1 self.param2 = param2 - def _fit(self, X: pd.DataFrame, scale: float = 1.0): + def _fit(self, y: pd.DataFrame, X: pd.DataFrame, scale: float = 1.0): """Customize the initialization of the effect. This method is called by the `fit()` method and can be overridden by @@ -33,19 +36,25 @@ def _fit(self, X: pd.DataFrame, scale: float = 1.0): Parameters ---------- + y : pd.DataFrame + The timeseries dataframe + X : pd.DataFrame The DataFrame to initialize the effect. + + scale : float, optional + The scale of the timeseries. For multivariate timeseries, this is + a dataframe. For univariate, it is a simple float. """ # Do something with X, scale, and other parameters pass - def _transform( - self, X: pd.DataFrame, stage: Stage = Stage.TRAIN - ) -> Dict[str, jnp.ndarray]: - """Prepare the input data in a dict of jax arrays. + def _transform(self, X: pd.DataFrame, fh: pd.Index) -> Any: + """Prepare input data to be passed to numpyro model. - This method is called by the `fit()` method and can be overridden - by subclasses to provide additional data preparation logic. + This method receives the Exogenous variables DataFrame and should return a + the data needed for the effect. Those data will be passed to the `predict` + method as `data` argument. Parameters ---------- @@ -54,43 +63,46 @@ def _transform( time indexes, if passed during fit, or for the forecasting time indexes, if passed during predict. - stage : Stage, optional - The stage of the effect, by default Stage.TRAIN. This can be used to - differentiate between training and prediction stages and apply different - transformations accordingly. + fh : pd.Index + The forecasting horizon as a pandas Index. Returns ------- - Dict[str, jnp.ndarray] - A dictionary containing the data needed for the effect. The keys of the - dictionary should be the names of the arguments of the `apply` method, and - the values should be the corresponding data as jnp.ndarray. + Any + Any object containing the data needed for the effect. The object will be + passed to `predict` method as `data` argument. """ # Do something with X - if stage == "train": - array = series_to_tensor_or_array(X) - else: - # something else - pass - return {"data": array} - - def _predict(self, trend: jnp.ndarray, **kwargs) -> jnp.ndarray: - """Apply the effect. + array = series_to_tensor_or_array(X) + return array - This method is called by the `apply()` method and must be overridden by - subclasses to provide the actual effect computation logic. + def predict( + self, + data: Dict, + predicted_effects: Dict[str, jnp.ndarray], + ) -> jnp.ndarray: + """Apply and return the effect values. Parameters ---------- - trend : jnp.ndarray - An array containing the trend values. + data : Any + Data obtained from the transformed method. - kwargs: dict - Additional keyword arguments that may be needed to compute the effect. + predicted_effects : Dict[str, jnp.ndarray], optional + A dictionary containing the predicted effects, by default None. Returns ------- jnp.ndarray - The effect values. + An array with shape (T,1) for univariate timeseries, or (N, T, 1) for + multivariate timeseries, where T is the number of timepoints and N is the + number of series. """ + # Get the trend + # (T, 1) shaped array for univariate timeseries + # (N, T, 1) shaped array for multivariate timeseries, where N is the number of + # series + # trend: jnp.ndarray = predicted_effects["trend"] + # Or user predicted_effects.get("trend") to return None if the trend is + # not found raise NotImplementedError("Subclasses must implement _predict()") diff --git a/src/prophetverse/effects/base.py b/src/prophetverse/effects/base.py index a366802..a064b50 100644 --- a/src/prophetverse/effects/base.py +++ b/src/prophetverse/effects/base.py @@ -1,7 +1,6 @@ """Module that stores abstract class of effects.""" -from enum import Enum -from typing import Dict, List, Literal +from typing import Any, Dict, List, Literal, Optional import jax.numpy as jnp import pandas as pd @@ -15,18 +14,6 @@ EFFECT_APPLICATION_TYPE = Literal["additive", "multiplicative"] -class Stage(str, Enum): - """ - Enum class for stages of the forecasting model. - - Used to indicate the stage of the model, either "train" or "predict", for the - effect preparation steps. - """ - - TRAIN: str = "train" - PREDICT: str = "predict" - - class BaseEffect(BaseObject): """Base class for effects. @@ -88,11 +75,14 @@ class BaseEffect(BaseObject): # If no columns are found, should # _predict be skipped? "skip_predict_if_no_match": True, + # Should only the indexes related to the forecasting horizon be passed to + # _transform? + "filter_indexes_with_forecating_horizon_at_transform": True, } def __init__(self): self._input_feature_column_names: List[str] = [] - self._is_fitted = False + self._is_fitted: bool = False @property def input_feature_column_names(self) -> List[str]: @@ -114,7 +104,7 @@ def should_skip_predict(self) -> bool: return True return False - def fit(self, X: pd.DataFrame, scale: float = 1.0): + def fit(self, y: pd.DataFrame, X: pd.DataFrame, scale: float = 1.0): """Initialize the effect. This method is called during `fit()` of the forecasting model. @@ -126,6 +116,9 @@ def fit(self, X: pd.DataFrame, scale: float = 1.0): Parameters ---------- + y : pd.DataFrame + The timeseries dataframe + X : pd.DataFrame The DataFrame to initialize the effect. @@ -144,9 +137,9 @@ def fit(self, X: pd.DataFrame, scale: float = 1.0): than one level of index. """ if not self.get_tag("supports_multivariate", False): - if X.index.nlevels > 1: + if X is not None and X.index.nlevels > 1: raise ValueError( - f"The effect of if {self.id} does not " + f"The effect {self.__class__.__name__} does not " + "support multivariate data" ) @@ -155,10 +148,10 @@ def fit(self, X: pd.DataFrame, scale: float = 1.0): else: self._input_feature_column_names = X.columns.tolist() - self._fit(X, scale=scale) + self._fit(y=y, X=X, scale=scale) self._is_fitted = True - def _fit(self, X: pd.DataFrame, scale: float = 1.0): + def _fit(self, y: pd.DataFrame, X: pd.DataFrame, scale: float = 1.0): """Customize the initialization of the effect. This method is called by the `fit()` method and can be overridden by @@ -166,20 +159,28 @@ def _fit(self, X: pd.DataFrame, scale: float = 1.0): Parameters ---------- + y : pd.DataFrame + The timeseries dataframe + X : pd.DataFrame The DataFrame to initialize the effect. + + scale : float, optional + The scale of the timeseries. For multivariate timeseries, this is + a dataframe. For univariate, it is a simple float. """ pass def transform( - self, X: pd.DataFrame, stage: Stage = Stage.TRAIN - ) -> Dict[str, jnp.ndarray]: + self, + X: pd.DataFrame, + fh: pd.Index, + ) -> Any: """Prepare input data to be passed to numpyro model. - This method is called during `fit()` and `predict()` of the forecasting model. - It receives the Exogenous variables DataFrame and should return a dictionary - containing the data needed for the effect. Those data will be passed to the - `predict` method as named arguments. + This method receives the Exogenous variables DataFrame and should return a + the data needed for the effect. Those data will be passed to the `predict` + method as `data` argument. Parameters ---------- @@ -188,14 +189,14 @@ def transform( time indexes, if passed during fit, or for the forecasting time indexes, if passed during predict. - + fh : pd.Index + The forecasting horizon as a pandas Index. Returns ------- - Dict[str, jnp.ndarray] - A dictionary containing the data needed for the effect. The keys of the - dictionary should be the names of the arguments of the `predict` method, and - the values should be the corresponding data as jnp.ndarray. + Any + Any object containing the data needed for the effect. The object will be + passed to `predict` method as `data` argument. Raises ------ @@ -209,16 +210,24 @@ def transform( if self.should_skip_predict: return {} + if self.get_tag("filter_indexes_with_forecating_horizon_at_transform", True): + # Filter when index level -1 is in fh + if X is not None: + X = X.loc[X.index.get_level_values(-1).isin(fh)] + X = X[self.input_feature_column_names] - return self._transform(X, stage=stage) + return self._transform(X, fh) def _transform( - self, X: pd.DataFrame, stage: Stage = Stage.TRAIN - ) -> Dict[str, jnp.ndarray]: - """Prepare the input data in a dict of jax arrays. + self, + X: pd.DataFrame, + fh: pd.Index, + ) -> Any: + """Prepare input data to be passed to numpyro model. - This method is called by the `fit()` method and can be overridden - by subclasses to provide additional data preparation logic. + This method receives the Exogenous variables DataFrame and should return a + the data needed for the effect. Those data will be passed to the `predict` + method as `data` argument. Parameters ---------- @@ -227,57 +236,76 @@ def _transform( time indexes, if passed during fit, or for the forecasting time indexes, if passed during predict. + fh : pd.Index + The forecasting horizon as a pandas Index. + Returns ------- - Dict[str, jnp.ndarray] - A dictionary containing the data needed for the effect. The keys of the - dictionary should be the names of the arguments of the `predict` method, and - the values should be the corresponding data as jnp.ndarray. + Any + Any object containing the data needed for the effect. The object will be + passed to `predict` method as `data` argument. """ array = series_to_tensor_or_array(X) - return {"data": array} + return array - def predict(self, trend: jnp.ndarray, **kwargs) -> jnp.ndarray: + def predict( + self, + data: Dict, + predicted_effects: Optional[Dict[str, jnp.ndarray]] = None, + ) -> jnp.ndarray: """Apply and return the effect values. Parameters ---------- - trend : jnp.ndarray - An array containing the trend values. + data : Any + Data obtained from the transformed method. + + predicted_effects : Dict[str, jnp.ndarray], optional + A dictionary containing the predicted effects, by default None. Returns ------- jnp.ndarray - The effect values. + An array with shape (T,1) for univariate timeseries, or (N, T, 1) for + multivariate timeseries, where T is the number of timepoints and N is the + number of series. """ - x = self._predict(trend, **kwargs) + if predicted_effects is None: + predicted_effects = {} - return x + x = self._predict(data, predicted_effects) - def _predict(self, trend: jnp.ndarray, **kwargs) -> jnp.ndarray: - """Apply the effect. + return x - This method is called by the `predict()` method and must be overridden by - subclasses to provide the actual effect computation logic. + def _predict( + self, + data: Dict, + predicted_effects: Dict[str, jnp.ndarray], + ) -> jnp.ndarray: + """Apply and return the effect values. Parameters ---------- - trend : jnp.ndarray - An array containing the trend values. + data : Any + Data obtained from the transformed method. - kwargs: dict - Additional keyword arguments that may be needed to compute the effect. + predicted_effects : Dict[str, jnp.ndarray], optional + A dictionary containing the predicted effects, by default None. Returns ------- jnp.ndarray - The effect values. + An array with shape (T,1) for univariate timeseries, or (N, T, 1) for + multivariate timeseries, where T is the number of timepoints and N is the + number of series. """ raise NotImplementedError("Subclasses must implement _predict()") - def __call__(self, trend: jnp.ndarray, **kwargs) -> jnp.ndarray: + def __call__( + self, data: Dict, predicted_effects: Dict[str, jnp.ndarray] + ) -> jnp.ndarray: """Run the processes to calculate effect as a function.""" - return self.predict(trend, **kwargs) + return self.predict(data=data, predicted_effects=predicted_effects) class BaseAdditiveOrMultiplicativeEffect(BaseEffect): @@ -314,20 +342,42 @@ def __init__(self, effect_mode="additive"): super().__init__() - def predict(self, trend: jnp.ndarray, **kwargs) -> jnp.ndarray: - """Apply the effect. + def predict( + self, + data: Any, + predicted_effects: Optional[Dict[str, jnp.ndarray]] = None, + ) -> jnp.ndarray: + """Apply and return the effect values. Parameters ---------- - trend : jnp.ndarray - The trend of the model. + data : Any + Data obtained from the transformed method. + + predicted_effects : Dict[str, jnp.ndarray], optional + A dictionary containing the predicted effects, by default None. Returns ------- jnp.ndarray - The computed effect. + An array with shape (T,1) for univariate timeseries, or (N, T, 1) for + multivariate timeseries, where T is the number of timepoints and N is the + number of series. """ - x = super().predict(trend, **kwargs) + if predicted_effects is None: + raise ValueError( + "BaseAdditiveOrMultiplicativeEffect requires trend in" + + " predicted_effects" + ) + + trend = predicted_effects["trend"] + if trend.ndim == 1: + trend = trend.reshape((-1, 1)) + + x = super().predict(data=data, predicted_effects=predicted_effects) + x = x.reshape(trend.shape) + if self.effect_mode == "additive": return x + return trend * x diff --git a/src/prophetverse/effects/fourier.py b/src/prophetverse/effects/fourier.py index 2782075..71fd37b 100644 --- a/src/prophetverse/effects/fourier.py +++ b/src/prophetverse/effects/fourier.py @@ -7,7 +7,7 @@ import pandas as pd from sktime.transformations.series.fourier import FourierFeatures -from prophetverse.effects.base import EFFECT_APPLICATION_TYPE, BaseEffect, Stage +from prophetverse.effects.base import EFFECT_APPLICATION_TYPE, BaseEffect from prophetverse.effects.linear import LinearEffect from prophetverse.sktime._expand_column_per_level import ExpandColumnPerLevel @@ -57,15 +57,19 @@ def __init__( self.effect_mode = effect_mode self.expand_column_per_level_ = None # type: Union[None,ExpandColumnPerLevel] - def _fit(self, X: pd.DataFrame, scale: float = 1.0): + def _fit(self, y: pd.DataFrame, X: pd.DataFrame, scale: float = 1.0): """Customize the initialization of the effect. Fit the fourier feature transformer and the linear effect. Parameters ---------- + y : pd.DataFrame + The timeseries dataframe + X : pd.DataFrame The DataFrame to initialize the effect. + scale: float, optional The scale of the timeseries, by default 1.0. """ @@ -73,7 +77,7 @@ def _fit(self, X: pd.DataFrame, scale: float = 1.0): sp_list=self.sp_list, fourier_terms_list=self.fourier_terms_list, freq=self.freq, - keep_original_columns=True, + keep_original_columns=False, ) self.fourier_features_.fit(X=X) @@ -87,14 +91,13 @@ def _fit(self, X: pd.DataFrame, scale: float = 1.0): prior=dist.Normal(0, self.prior_scale), effect_mode=self.effect_mode ) - self.linear_effect_.fit(X=X, scale=scale) + self.linear_effect_.fit(X=X, y=y, scale=scale) - def _transform( - self, X: pd.DataFrame, stage: Stage = Stage.TRAIN - ) -> Dict[str, jnp.ndarray]: - """Prepare the input data in a dict of jax arrays. + def _transform(self, X: pd.DataFrame, fh: pd.Index) -> jnp.ndarray: + """Prepare input data to be passed to numpyro model. - Creates the fourier terms and the linear effect. + This method return a jnp.ndarray of sines and cosines of the given + frequencies. Parameters ---------- @@ -103,41 +106,46 @@ def _transform( time indexes, if passed during fit, or for the forecasting time indexes, if passed during predict. - stage : Stage, optional - The stage of the effect, by default Stage.TRAIN. This can be used to - differentiate between training and prediction stages and apply different - transformations accordingly. + fh : pd.Index + The forecasting horizon as a pandas Index. Returns ------- - Dict[str, jnp.ndarray] - A dictionary containing the data needed for the effect. + jnp.ndarray + Any object containing the data needed for the effect. The object will be + passed to `predict` method as `data` argument. """ X = self.fourier_features_.transform(X) if self.expand_column_per_level_ is not None: X = self.expand_column_per_level_.transform(X) - array = self.linear_effect_.transform(X, stage) + array = self.linear_effect_.transform(X, fh) return array - def _predict(self, trend: jnp.ndarray, **kwargs) -> jnp.ndarray: - """Apply the effect. - - Apply linear seasonality. + def _predict( + self, + data: Dict, + predicted_effects: Dict[str, jnp.ndarray], + ) -> jnp.ndarray: + """Apply and return the effect values. Parameters ---------- - trend : jnp.ndarray - An array containing the trend values. + data : Any + Data obtained from the transformed method. - kwargs: dict - Additional keyword arguments that may be needed to compute the effect. + predicted_effects : Dict[str, jnp.ndarray], optional + A dictionary containing the predicted effects, by default None. Returns ------- jnp.ndarray - The effect values. + An array with shape (T,1) for univariate timeseries, or (N, T, 1) for + multivariate timeseries, where T is the number of timepoints and N is the + number of series. """ - return self.linear_effect_.predict(trend, **kwargs) + return self.linear_effect_.predict( + data=data, predicted_effects=predicted_effects + ) diff --git a/src/prophetverse/effects/hill.py b/src/prophetverse/effects/hill.py index a2bb565..227cc46 100644 --- a/src/prophetverse/effects/hill.py +++ b/src/prophetverse/effects/hill.py @@ -1,6 +1,6 @@ """Definition of Hill Effect class.""" -from typing import Optional +from typing import Dict, Optional import jax.numpy as jnp import numpyro @@ -44,23 +44,26 @@ def __init__( super().__init__(effect_mode=effect_mode) - def _predict(self, trend: jnp.ndarray, **kwargs) -> jnp.ndarray: - """Compute the effect using the log transformation. + def _predict( + self, + data: Dict[str, jnp.ndarray], + predicted_effects: Dict[str, jnp.ndarray], + ) -> jnp.ndarray: + """Apply and return the effect values. Parameters ---------- - trend : jnp.ndarray - The trend component of the hierarchical prophet model. - data : jnp.ndarray - The data used to compute the effect. + data : Any + Data obtained from the transformed method. + + predicted_effects : Dict[str, jnp.ndarray] + A dictionary containing the predicted effects Returns ------- jnp.ndarray - The computed effect based on the given trend and data. + An array with shape (T,1) for univariate timeseries. """ - data: jnp.ndarray = kwargs.pop("data") - half_max = numpyro.sample("half_max", self.half_max_prior) slope = numpyro.sample("slope", self.slope_prior) max_effect = numpyro.sample("max_effect", self.max_effect_prior) diff --git a/src/prophetverse/effects/lift_experiment.py b/src/prophetverse/effects/lift_experiment.py index 8cd32a6..b2e8df5 100644 --- a/src/prophetverse/effects/lift_experiment.py +++ b/src/prophetverse/effects/lift_experiment.py @@ -9,7 +9,7 @@ from prophetverse.utils.frame_to_array import series_to_tensor_or_array -from .base import BaseEffect, Stage +from .base import BaseEffect __all__ = ["LiftExperimentLikelihood"] @@ -48,44 +48,60 @@ def __init__( super().__init__() - def fit(self, X: pd.DataFrame, scale: float = 1): - """Initialize this effect and its wrapped effect. + def fit(self, y: pd.DataFrame, X: pd.DataFrame, scale: float = 1): + """Initialize the effect. + + This method is called during `fit()` of the forecasting model. + It receives the Exogenous variables DataFrame and should be used to initialize + any necessary parameters or data structures, such as detecting the columns that + match the regex pattern. + + This method MUST set _input_feature_columns_names to a list of column names Parameters ---------- - X : DataFrame - Dataframe of exogenous data. - scale : float - The scale of the timeseries. This is used to normalize the lift effect. + y : pd.DataFrame + The timeseries dataframe + + X : pd.DataFrame + The DataFrame to initialize the effect. + + scale : float, optional + The scale of the timeseries. For multivariate timeseries, this is + a dataframe. For univariate, it is a simple float. + + Returns + ------- + None """ - self.effect.fit(X) + self.effect.fit(X=X, y=y, scale=scale) self.timeseries_scale = scale - super().fit(X) + super().fit(X=X, y=y, scale=scale) + + def _transform(self, X: pd.DataFrame, fh: pd.Index) -> Dict[str, Any]: + """Prepare input data to be passed to numpyro model. - def _transform(self, X: pd.DataFrame, stage: Stage = Stage.TRAIN) -> Dict[str, Any]: - """Prepare the input data for the effect, and the custom likelihood. + Returns a dictionary with the data for the lift and for the inner effect. Parameters ---------- X : pd.DataFrame - The input data with exogenous variables. - stage : Stage, optional - which stage is being executed, by default Stage.TRAIN. - Used to determine if the likelihood should be applied. + The input DataFrame containing the exogenous variables for the training + time indexes, if passed during fit, or for the forecasting time indexes, if + passed during predict. + + fh : pd.Index + The forecasting horizon as a pandas Index. Returns ------- Dict[str, Any] - The dictionary of data passed to _predict and the likelihood. + Dictionary with data for the lift and for the inner effect """ - data_dict = self.effect._transform(X, stage) - - if stage == Stage.PREDICT: - data_dict["observed_lift"] = None - data_dict["obs_mask"] = None - return data_dict + data_dict = {} + data_dict["inner_effect_data"] = self.effect._transform(X, fh=fh) - X_lift = self.lift_test_results.loc[X.index] + X_lift = self.lift_test_results.reindex(fh, fill_value=jnp.nan) lift_array = series_to_tensor_or_array(X_lift) data_dict["observed_lift"] = lift_array / self.timeseries_scale data_dict["obs_mask"] = ~jnp.isnan(data_dict["observed_lift"]) @@ -93,28 +109,29 @@ def _transform(self, X: pd.DataFrame, stage: Stage = Stage.TRAIN) -> Dict[str, A return data_dict def _predict( - self, - trend: jnp.ndarray, - **kwargs, + self, data: Dict, predicted_effects: Dict[str, jnp.ndarray] ) -> jnp.ndarray: - """Apply the effect and the custom likelihood. + """Apply and return the effect values. Parameters ---------- - trend : jnp.ndarray - The trend component. - observed_lift : jnp.ndarray - The observed lift to apply the likelihood to. + data : Any + Data obtained from the transformed method. + + predicted_effects : Dict[str, jnp.ndarray], optional + A dictionary containing the predicted effects, by default None. Returns ------- jnp.ndarray - The effect applied to the input data. + An array with shape (T,1) for univariate timeseries. """ - observed_lift = kwargs.pop("observed_lift") - obs_mask = kwargs.pop("obs_mask") + observed_lift = data["observed_lift"] + obs_mask = data["obs_mask"] - x = self.effect.predict(trend, **kwargs) + x = self.effect.predict( + data=data["inner_effect_data"], predicted_effects=predicted_effects + ) numpyro.sample( "lift_experiment", diff --git a/src/prophetverse/effects/linear.py b/src/prophetverse/effects/linear.py index d2f77e3..f7077f1 100644 --- a/src/prophetverse/effects/linear.py +++ b/src/prophetverse/effects/linear.py @@ -1,6 +1,6 @@ """Definition of Linear Effect class.""" -from typing import Optional +from typing import Any, Dict, Optional import jax.numpy as jnp import numpyro @@ -40,23 +40,28 @@ def __init__( super().__init__(effect_mode=effect_mode) - def _predict(self, trend: jnp.ndarray, **kwargs) -> jnp.ndarray: - """Compute the Linear effect. + def _predict( + self, + data: Any, + predicted_effects: Optional[Dict[str, jnp.ndarray]] = None, + ) -> jnp.ndarray: + """Apply and return the effect values. Parameters ---------- - trend : jnp.ndarray - The trend component of the hierarchical prophet model. - data : jnp.ndarray - The data used to compute the effect. + data : Any + Data obtained from the transformed method. + + predicted_effects : Dict[str, jnp.ndarray], optional + A dictionary containing the predicted effects, by default None. Returns ------- jnp.ndarray - The computed effect based on the given trend and data. + An array with shape (T,1) for univariate timeseries, or (N, T, 1) for + multivariate timeseries, where T is the number of timepoints and N is the + number of series. """ - data = kwargs.pop("data") - n_features = data.shape[-1] with numpyro.plate("features_plate", n_features, dim=-1): diff --git a/src/prophetverse/effects/log.py b/src/prophetverse/effects/log.py index f957f75..dee6fba 100644 --- a/src/prophetverse/effects/log.py +++ b/src/prophetverse/effects/log.py @@ -1,6 +1,6 @@ """Definition of Log Effect class.""" -from typing import Optional +from typing import Dict, Optional import jax.numpy as jnp import numpyro @@ -39,24 +39,27 @@ def __init__( super().__init__(effect_mode=effect_mode) def _predict( # type: ignore[override] - self, trend: jnp.ndarray, **kwargs + self, + data: jnp.ndarray, + predicted_effects: Optional[Dict[str, jnp.ndarray]] = None, ) -> jnp.ndarray: - """Compute the effect using the log transformation. + """Apply and return the effect values. Parameters ---------- - trend : jnp.ndarray - The trend component. - data : jnp.ndarray - The input data. + data : Any + Data obtained from the transformed method. + + predicted_effects : Dict[str, jnp.ndarray], optional + A dictionary containing the predicted effects, by default None. Returns ------- jnp.ndarray - The computed effect based on the given trend and data. + An array with shape (T,1) for univariate timeseries, or (N, T, 1) for + multivariate timeseries, where T is the number of timepoints and N is the + number of series. """ - data: jnp.ndarray = kwargs.pop("data") - scale = numpyro.sample("log_scale", self.scale_prior) rate = numpyro.sample("log_rate", self.rate_prior) effect = scale * jnp.log(jnp.clip(rate * data + 1, 1e-8, None)) diff --git a/src/prophetverse/effects/trend/__init__.py b/src/prophetverse/effects/trend/__init__.py new file mode 100644 index 0000000..4f1d591 --- /dev/null +++ b/src/prophetverse/effects/trend/__init__.py @@ -0,0 +1,12 @@ +"""Module for trend models in prophetverse.""" + +from .base import TrendEffectMixin +from .flat import FlatTrend +from .piecewise import PiecewiseLinearTrend, PiecewiseLogisticTrend + +__all__ = [ + "TrendEffectMixin", + "FlatTrend", + "PiecewiseLinearTrend", + "PiecewiseLogisticTrend", +] diff --git a/src/prophetverse/effects/trend/base.py b/src/prophetverse/effects/trend/base.py new file mode 100644 index 0000000..db843ea --- /dev/null +++ b/src/prophetverse/effects/trend/base.py @@ -0,0 +1,68 @@ +"""Module containing the base class for trend models.""" + +import pandas as pd + +from prophetverse.utils.frame_to_array import convert_index_to_days_since_epoch + + +class TrendEffectMixin: + """ + Mixin class for trend models. + + Trend models are effects applied to the trend component of a time series. + + Attributes + ---------- + t_scale: float + The time scale of the trend model. + t_start: float + The starting time of the trend model. + n_series: int + The number of series in the time series data. + """ + + _tags = {"skip_predict_if_no_match": False, "supports_multivariate": True} + + def _fit(self, y: pd.DataFrame, X: pd.DataFrame, scale: float = 1) -> None: + """Initialize the effect. + + Set the time scale, starting time, and number of series attributes. + + Parameters + ---------- + y : pd.DataFrame + The timeseries dataframe + + X : pd.DataFrame + The DataFrame to initialize the effect. + """ + # Set time scale + t_days = convert_index_to_days_since_epoch( + y.index.get_level_values(-1).unique() + ) + self.t_scale = (t_days[1:] - t_days[:-1]).mean() + self.t_start = t_days.min() / self.t_scale + if y.index.nlevels > 1: + self.n_series = y.index.droplevel(-1).nunique() + else: + self.n_series = 1 + + def _index_to_scaled_timearray(self, idx): + """ + Convert the index to a scaled time array. + + Parameters + ---------- + idx: int + The index to be converted. + + Returns + ------- + float + The scaled time array value. + """ + if idx.nlevels > 1: + idx = idx.get_level_values(-1).unique() + + t_days = convert_index_to_days_since_epoch(idx) + return (t_days) / self.t_scale - self.t_start diff --git a/src/prophetverse/trend/flat.py b/src/prophetverse/effects/trend/flat.py similarity index 71% rename from src/prophetverse/trend/flat.py rename to src/prophetverse/effects/trend/flat.py index e3bea09..ce07c16 100644 --- a/src/prophetverse/trend/flat.py +++ b/src/prophetverse/effects/trend/flat.py @@ -5,10 +5,12 @@ import numpyro.distributions as dist import pandas as pd -from .base import TrendModel +from prophetverse.effects.base import BaseEffect +from .base import TrendEffectMixin -class FlatTrend(TrendModel): + +class FlatTrend(TrendEffectMixin, BaseEffect): """Flat trend model. The mean of the target variable is used as the prior location for the trend. @@ -23,17 +25,22 @@ def __init__(self, changepoint_prior_scale: float = 0.1) -> None: self.changepoint_prior_scale = changepoint_prior_scale super().__init__() - def initialize(self, y: pd.DataFrame): - """Set the prior location for the trend. + def _fit(self, y: pd.DataFrame, X: pd.DataFrame, scale: float = 1): + """Initialize the effect. + + Set the prior location for the trend. Parameters ---------- y : pd.DataFrame - The target variable. + The timeseries dataframe + + X : pd.DataFrame + The DataFrame to initialize the effect. """ self.changepoint_prior_loc = y.mean().values - def fit(self, idx: pd.PeriodIndex) -> dict: + def _transform(self, X: pd.DataFrame, fh: pd.Index) -> dict: """Prepare input data (a constant factor in this case). Parameters @@ -46,12 +53,11 @@ def fit(self, idx: pd.PeriodIndex) -> dict: dict dictionary containing the input data for the trend model """ - return { - "constant_vector": jnp.ones((len(idx), 1)), - } + idx = X.index + return jnp.ones((len(idx), 1)) - def compute_trend( # type: ignore[override] - self, constant_vector: jnp.ndarray, **kwargs + def _predict( # type: ignore[override] + self, data: jnp.ndarray, predicted_effects=None ) -> jnp.ndarray: """Apply the trend. @@ -65,6 +71,9 @@ def compute_trend( # type: ignore[override] jnp.ndarray The forecasted trend """ + # Alias for clarity + constant_vector = data + mean = self.changepoint_prior_loc var = self.changepoint_prior_scale**2 diff --git a/src/prophetverse/trend/piecewise.py b/src/prophetverse/effects/trend/piecewise.py similarity index 88% rename from src/prophetverse/trend/piecewise.py rename to src/prophetverse/effects/trend/piecewise.py index 31f7469..adc2f6f 100644 --- a/src/prophetverse/trend/piecewise.py +++ b/src/prophetverse/effects/trend/piecewise.py @@ -4,7 +4,8 @@ This module contains the implementation of piecewise trend models (logistic and linear). """ -from typing import Tuple, Union +import itertools +from typing import Dict, Tuple, Union import jax.numpy as jnp import numpy as np @@ -13,14 +14,15 @@ import pandas as pd from sktime.transformations.series.detrend import Detrender +from prophetverse.effects.base import BaseEffect from prophetverse.utils.frame_to_array import series_to_tensor -from .base import TrendModel +from .base import TrendEffectMixin __all__ = ["PiecewiseLinearTrend", "PiecewiseLogisticTrend"] -class PiecewiseLinearTrend(TrendModel): +class PiecewiseLinearTrend(TrendEffectMixin, BaseEffect): """Piecewise Linear Trend model. This model assumes that the trend is piecewise linear, with changepoints @@ -71,29 +73,72 @@ def __init__( self.remove_seasonality_before_suggesting_initial_vals = ( remove_seasonality_before_suggesting_initial_vals ) - super().__init__(**kwargs) + super().__init__() - def initialize(self, y: pd.DataFrame): - """ - Initialize the piecewise trend model. + def _fit(self, y: pd.DataFrame, X: pd.DataFrame, scale: float = 1): + """Initialize the effect. + + Set the prior location for the trend. Parameters ---------- - y: pd.DataFrame - The input data. + y : pd.DataFrame + The timeseries dataframe - Returns - ------- - None + X : pd.DataFrame + The DataFrame to initialize the effect. + + scale : float, optional + The scale of the timeseries. For multivariate timeseries, this is + a dataframe. For univariate, it is a simple float. """ - super().initialize(y) + super()._fit(X=X, y=y, scale=scale) + t_scaled = self._index_to_scaled_timearray( y.index.get_level_values(-1).unique() ) self._setup_changepoints(t_scaled) self._setup_changepoint_prior_vectors(y) + self._index_names = y.index.names + self._series_idx = None + if y.index.nlevels > 1: + self._series_idx = y.index.droplevel(-1).unique() + + def _fh_to_index(self, fh: pd.Index) -> Union[pd.Index, pd.MultiIndex]: + """Convert an index representing the fcst horizon to multiindex if needed. + + If there's a single timeseries, just returns the fh. + + Parameters + ---------- + fh : pd.Index + The timeindex representing the forecasting horizon. + + Returns + ------- + Union[pd.Index, pd.MultiIndex] + The fh for all time series passed during fit + """ + if self._series_idx is None: + return fh + + idx_list = self._series_idx.to_list() + idx_list = [x if isinstance(x, tuple) else (x,) for x in idx_list] + # Create a new multi-index combining the existing levels with the new time index + new_idx_tuples = list( + map( + lambda x: ( + *x[0], + x[1], + ), + # Create a cross product of current indexes + # and dates in fh + itertools.product(idx_list, fh.to_list()), + ) + ) + return pd.MultiIndex.from_tuples(new_idx_tuples, names=self._index_names) - def fit(self, idx: pd.PeriodIndex) -> dict: + def _transform(self, X, fh) -> dict: """ Prepare the input data for the piecewise trend model. @@ -107,7 +152,53 @@ def fit(self, idx: pd.PeriodIndex) -> dict: dict A dictionary containing the prepared input data. """ - return {"changepoint_matrix": self.get_changepoint_matrix(idx)} + idx = self._fh_to_index(fh) + return self.get_changepoint_matrix(idx) + + def _predict( + self, data: jnp.ndarray, predicted_effects: Dict[str, jnp.ndarray] + ) -> jnp.ndarray: + """ + Compute the trend based on the given changepoint matrix. + + Parameters + ---------- + data: jnp.ndarray + The changepoint matrix. + predicted_effects: Dict[str, jnp.ndarray] + Dictionary of previously computed effects. For the trend, it is an empty + dict. + + Returns + ------- + jnp.ndarray + The computed trend. + """ + # alias for clarity + changepoint_matrix = data + offset = numpyro.sample( + "offset", + dist.Normal(self._offset_prior_loc, self._offset_prior_scale), + ) + + changepoint_coefficients = numpyro.sample( + "changepoint_coefficients", + dist.Laplace(self._changepoint_prior_loc, self._changepoint_prior_scale), + ) + + # If multivariate + if changepoint_matrix.ndim == 3: + changepoint_coefficients = changepoint_coefficients.reshape((1, -1, 1)) + offset = offset.reshape((-1, 1, 1)) + + trend = (changepoint_matrix) @ changepoint_coefficients + offset + + if trend.ndim == 1 or ( + trend.ndim == 3 and self.n_series == 1 and self.squeeze_if_single_series + ): + trend = trend.reshape((-1, 1)) + + return trend def get_changepoint_matrix(self, idx: pd.PeriodIndex) -> jnp.ndarray: """ @@ -338,46 +429,6 @@ def _suggest_global_trend_and_offset( return global_rate, offset_loc - def compute_trend( # type: ignore[override] - self, changepoint_matrix: jnp.ndarray - ) -> jnp.ndarray: - """ - Compute the trend based on the given changepoint matrix. - - Parameters - ---------- - changepoint_matrix: jnp.ndarray - The changepoint matrix. - - Returns - ------- - jnp.ndarray - The computed trend. - """ - offset = numpyro.sample( - "offset", - dist.Normal(self._offset_prior_loc, self._offset_prior_scale), - ) - - changepoint_coefficients = numpyro.sample( - "changepoint_coefficients", - dist.Laplace(self._changepoint_prior_loc, self._changepoint_prior_scale), - ) - - # If multivariate - if changepoint_matrix.ndim == 3: - changepoint_coefficients = changepoint_coefficients.reshape((1, -1, 1)) - offset = offset.reshape((-1, 1, 1)) - - trend = (changepoint_matrix) @ changepoint_coefficients + offset - - if trend.ndim == 1 or ( - trend.ndim == 3 and self.n_series == 1 and self.squeeze_if_single_series - ): - trend = trend.reshape((-1, 1)) - - return trend - class PiecewiseLogisticTrend(PiecewiseLinearTrend): """ @@ -480,8 +531,8 @@ def _suggest_global_trend_and_offset( return global_rates, offset - def compute_trend( # type: ignore[override] - self, changepoint_matrix: jnp.ndarray, **kwargs + def _predict( # type: ignore[override] + self, data: jnp.ndarray, predicted_effects=None ) -> jnp.ndarray: """ Compute the trend for the given changepoint matrix. @@ -499,7 +550,7 @@ def compute_trend( # type: ignore[override] with numpyro.plate("series", self.n_series, dim=-3): capacity = numpyro.sample("capacity", self.capacity_prior) - trend = super().compute_trend(changepoint_matrix) + trend = super()._predict(data=data, predicted_effects=predicted_effects) if self.n_series == 1: capacity = capacity.squeeze() diff --git a/src/prophetverse/models.py b/src/prophetverse/models.py index 00e09b1..411807e 100644 --- a/src/prophetverse/models.py +++ b/src/prophetverse/models.py @@ -8,12 +8,11 @@ from prophetverse.distributions import GammaReparametrized from prophetverse.effects.base import BaseEffect -from prophetverse.trend.base import TrendModel def multivariate_model( y, - trend_model: TrendModel, + trend_model: BaseEffect, trend_data: Dict[str, jnp.ndarray], data: Optional[Dict[str, jnp.ndarray]] = None, exogenous_effects: Optional[Dict[str, BaseEffect]] = None, @@ -31,7 +30,7 @@ def multivariate_model( Parameters ---------- y (jnp.ndarray): Array of time series data. - trend_model (TrendModel): Trend model. + trend_model (BaseEffect): Trend model. trend_data (dict): Dictionary containing the data needed for the trend model. data (dict): Dictionary containing the exogenous data. exogenous_effects (dict): Dictionary containing the exogenous effects. @@ -39,24 +38,11 @@ def multivariate_model( correlation_matrix_concentration (float): Concentration parameter for the LKJ distribution. """ - trend = trend_model(**trend_data) - - numpyro.deterministic("trend", trend) - - mean = trend - # Exogenous effects - if exogenous_effects is not None: - - for exog_effect_name, exog_effect in exogenous_effects.items(): - - exog_data = data[exog_effect_name] # type: ignore[index] - with numpyro.handlers.scope(prefix=exog_effect_name): - effect = exog_effect(trend=trend, **exog_data) - effect = numpyro.deterministic(exog_effect_name, effect) - mean += effect - - std_observation = numpyro.sample( - "std_observation", dist.HalfNormal(jnp.array([noise_scale] * mean.shape[0])) + mean = _compute_mean_univariate( + trend_model=trend_model, + trend_data=trend_data, + data=data, + exogenous_effects=exogenous_effects, ) if y is not None: @@ -64,12 +50,20 @@ def multivariate_model( if is_single_series: + mean = mean.reshape((-1, 1)) + std_observation = numpyro.sample( + "std_observation", dist.HalfNormal(jnp.array(noise_scale)) + ) + with numpyro.plate("time", mean.shape[-1], dim=-2): - numpyro.sample( - "obs", dist.Normal(mean.squeeze(-1).T, std_observation), obs=y - ) + numpyro.sample("obs", dist.Normal(mean, std_observation), obs=y) else: + + std_observation = numpyro.sample( + "std_observation", dist.HalfNormal(jnp.array([noise_scale] * mean.shape[0])) + ) + correlation_matrix = numpyro.sample( "corr_matrix", dist.LKJCholesky( @@ -94,7 +88,7 @@ def multivariate_model( def univariate_model( y, - trend_model: TrendModel, + trend_model: BaseEffect, trend_data: Dict[str, jnp.ndarray], data: Optional[Dict[str, jnp.ndarray]] = None, exogenous_effects: Optional[Dict[str, BaseEffect]] = None, @@ -107,7 +101,7 @@ def univariate_model( Parameters ---------- y (jnp.ndarray): Array of time series data. - trend_model (TrendModel): Trend model. + trend_model (BaseEffect): Trend model. trend_data (dict): Dictionary containing the data needed for the trend model. data (dict): Dictionary containing the exogenous data. exogenous_effects (dict): Dictionary containing the exogenous effects. @@ -132,7 +126,7 @@ def univariate_model( def univariate_gamma_model( y, - trend_model: TrendModel, + trend_model: BaseEffect, trend_data: Dict[str, jnp.ndarray], data: Optional[Dict[str, jnp.ndarray]] = None, exogenous_effects: Optional[Dict[str, BaseEffect]] = None, @@ -145,7 +139,7 @@ def univariate_gamma_model( Parameters ---------- y (jnp.ndarray): Array of time series data. - trend_model (TrendModel): Trend model. + trend_model (BaseEffect): Trend model. trend_data (dict): Dictionary containing the data needed for the trend model. data (dict): Dictionary containing the exogenous data. exogenous_effects (dict): Dictionary containing the exogenous effects. @@ -172,7 +166,7 @@ def univariate_gamma_model( def univariate_negbinomial_model( y, - trend_model: TrendModel, + trend_model: BaseEffect, trend_data: Dict[str, jnp.ndarray], data: Optional[Dict[str, jnp.ndarray]] = None, exogenous_effects: Optional[Dict[str, BaseEffect]] = None, @@ -186,7 +180,7 @@ def univariate_negbinomial_model( Parameters ---------- y (jnp.ndarray): Array of time series data. - trend_model (TrendModel): Trend model. + trend_model (BaseEffect): Trend model. trend_data (dict): Dictionary containing the data needed for the trend model. data (dict): Dictionary containing the exogenous data. exogenous_effects (dict): Dictionary containing the exogenous effects. @@ -245,12 +239,16 @@ def _to_positive( def _compute_mean_univariate( - trend_model: TrendModel, + trend_model: BaseEffect, trend_data: Dict[str, jnp.ndarray], data: Optional[Dict[str, jnp.ndarray]] = None, exogenous_effects: Optional[Dict[str, BaseEffect]] = None, ): - trend = trend_model(**trend_data) + + predicted_effects: Dict[str, jnp.ndarray] = {} + + trend = trend_model(data=trend_data, predicted_effects=predicted_effects) + predicted_effects["trend"] = trend numpyro.deterministic("trend", trend) @@ -259,9 +257,10 @@ def _compute_mean_univariate( if exogenous_effects is not None: for exog_effect_name, exog_effect in exogenous_effects.items(): - exog_data = data[exog_effect_name] # type: ignore[index] + transformed_data = data[exog_effect_name] # type: ignore[index] with numpyro.handlers.scope(prefix=exog_effect_name): - effect = exog_effect(trend=trend, **exog_data) + effect = exog_effect(transformed_data, predicted_effects) effect = numpyro.deterministic(exog_effect_name, effect) mean += effect + predicted_effects[exog_effect_name] = effect return mean diff --git a/src/prophetverse/sktime/base.py b/src/prophetverse/sktime/base.py index 0e7877a..e998aff 100644 --- a/src/prophetverse/sktime/base.py +++ b/src/prophetverse/sktime/base.py @@ -2,6 +2,7 @@ import itertools import warnings +from collections import OrderedDict from typing import Any, Dict, List, Optional, Tuple, Union import jax @@ -13,8 +14,13 @@ from sktime.base import _HeterogenousMetaEstimator from sktime.forecasting.base import BaseForecaster, ForecastingHorizon -from prophetverse.effects.base import BaseEffect, Stage +from prophetverse.effects.base import BaseEffect from prophetverse.effects.linear import LinearEffect +from prophetverse.effects.trend import ( + FlatTrend, + PiecewiseLinearTrend, + PiecewiseLogisticTrend, +) from prophetverse.engine import MAPInferenceEngine, MCMCInferenceEngine from prophetverse.utils import get_multiindex_loc @@ -88,7 +94,7 @@ def __init__( super().__init__() @property - def should_skip_scaling(self): + def _likelihood_is_discrete(self): """Property that indicates whether the forecaster uses a discrete likelihood. As a consequence, the target variable must be integer-valued and will not be @@ -495,7 +501,7 @@ def _scale_y(self, y: pd.DataFrame) -> pd.DataFrame: This method assumes that the scaling factor has already been computed and stored in the `_scale` attribute of the class. """ - if self.should_skip_scaling: + if self._likelihood_is_discrete: return y if isinstance(self._scale, float): @@ -531,7 +537,7 @@ def _inv_scale_y(self, y: pd.DataFrame) -> pd.DataFrame: This method assumes that the scaling factor has already been computed and stored in the `_scale` attribute of the class. """ - if self.should_skip_scaling: + if self._likelihood_is_discrete: return y if isinstance(self._scale, float): @@ -751,15 +757,48 @@ def _vectorize_predict_method( return pd.concat(outs, axis=0) -class BaseEffectsBayesianForecaster(_HeterogenousMetaEstimator, BaseBayesianForecaster): +class BaseProphetForecaster(_HeterogenousMetaEstimator, BaseBayesianForecaster): """Base class for Bayesian estimators with Effects objects. Parameters ---------- + trend : Union[str, BaseEffect], optional, one of "linear" (default) or "logistic" + Type of trend to use. Can also be a custom effect object. + + changepoint_interval : int, optional, default=25 + Number of potential changepoints to sample in the history. + + changepoint_range : float or int, optional, default=0.8 + Proportion of the history in which trend changepoints will be estimated. + + * if float, must be between 0 and 1. + The range will be that proportion of the training history. + + * if int, can be positive or negative. + Absolute value must be less than number of training points. + The range will be that number of points. + A negative int indicates number of points + counting from the end of the history, a positive int from the beginning. + + changepoint_prior_scale : float, optional, default=0.001 + Regularization parameter controlling the flexibility + of the automatic changepoint selection. + + offset_prior_scale : float, optional, default=0.1 + Scale parameter for the prior distribution of the offset. + The offset is the constant term in the piecewise trend equation. + + capacity_prior_scale : float, optional, default=0.2 + Scale parameter for the prior distribution of the capacity. + + capacity_prior_loc : float, optional, default=1.1 + Location parameter for the prior distribution of the capacity. + exogenous_effects : List[Tuple[str, BaseEffect, str]] List of exogenous effects to apply to the data. Each item of the list is a tuple with the name of the effect, the effect object, and the regex pattern to match the columns of the dataframe. + default_effect : Optional[BaseEffect] Default effect to apply to the columns that do not match any regex pattern. If None, a LinearEffect is used. @@ -770,7 +809,14 @@ class BaseEffectsBayesianForecaster(_HeterogenousMetaEstimator, BaseBayesianFore def __init__( self, - exogenous_effects: List[BaseEffect], + trend: Union[BaseEffect, str] = "linear", + changepoint_interval: int = 25, + changepoint_range: Union[float, int] = 0.8, + changepoint_prior_scale: float = 0.001, + offset_prior_scale: float = 0.1, + capacity_prior_scale=0.2, + capacity_prior_loc=1.1, + exogenous_effects: Optional[List[BaseEffect]] = None, default_effect: Optional[BaseEffect] = None, rng_key: jax.typing.ArrayLike = None, inference_method: str = "map", @@ -783,6 +829,16 @@ def __init__( scale=None, ): + # Trend related hyperparams + self.trend = trend + self.changepoint_interval = changepoint_interval + self.changepoint_range = changepoint_range + self.changepoint_prior_scale = changepoint_prior_scale + self.offset_prior_scale = offset_prior_scale + self.capacity_prior_scale = capacity_prior_scale + self.capacity_prior_loc = capacity_prior_loc + + # Exogenous variables related hyperparams self.exogenous_effects = exogenous_effects self.default_effect = default_effect super().__init__( @@ -824,7 +880,9 @@ def _exogenous_effects(self, value): for ((name, effect), (_, _, regex)) in zip(value, self.exogenous_effects) ] - def _fit_effects(self, X: Union[None, pd.DataFrame]): + def _fit_effects( + self, X: Union[None, pd.DataFrame], y: Optional[pd.DataFrame] = None + ): """ Set custom effects for the features. @@ -842,14 +900,16 @@ def _fit_effects(self, X: Union[None, pd.DataFrame]): for effect_name, effect, regex in exogenous_effects: if X is not None: - columns = self.match_columns(X.columns, regex) + columns = self._match_columns(X.columns, regex) X_columns = X[columns] else: X_columns = None effect = effect.clone() - effect.fit(X_columns, scale=self._scale) # type: ignore[attr-defined] + effect.fit( # type: ignore[attr-defined] + X=X_columns, y=y, scale=self._scale + ) if columns_with_effects.intersection(columns): msg = "Columns {} are already set".format( @@ -884,7 +944,8 @@ def _fit_effects(self, X: Union[None, pd.DataFrame]): default_effect = default_effect.clone() default_effect.fit( - X[features_without_effects], + X=X[features_without_effects], + y=y, scale=self._scale, # type: ignore[attr-defined] ) fitted_effects_list_.append( @@ -897,7 +958,7 @@ def _fit_effects(self, X: Union[None, pd.DataFrame]): self.exogenous_effects_ = fitted_effects_list_ - def _transform_effects(self, X: pd.DataFrame, stage: Stage = Stage.TRAIN): + def _transform_effects(self, X: pd.DataFrame, fh: pd.Index) -> OrderedDict: """ Get exogenous data array. @@ -905,19 +966,21 @@ def _transform_effects(self, X: pd.DataFrame, stage: Stage = Stage.TRAIN): ---------- X : pd.DataFrame Input data. + fh : pd.Index + Forecasting horizon as an index. Returns ------- dict Dictionary of exogenous data arrays. """ - out = {} + out = OrderedDict() for effect_name, effect, columns in self.exogenous_effects_: # If no columns are found, skip if effect.should_skip_predict: continue - data: Dict[str, jnp.ndarray] = effect.transform(X[columns], stage=stage) + data: Dict[str, jnp.ndarray] = effect.transform(X[columns], fh=fh) out[effect_name] = data return out @@ -938,7 +1001,71 @@ def non_skipped_exogenous_effect(self) -> dict[str, BaseEffect]: if not effect.should_skip_predict } - def match_columns( + def _get_trend_model(self): + """ + Return the trend model based on the specified trend parameter. + + Returns + ------- + BaseEffect + The trend model based on the specified trend parameter. + + Raises + ------ + ValueError + If the trend parameter is not one of 'linear', 'logistic', 'flat' + or a BaseEffect instance. + """ + # Changepoints and trend + if self.trend == "linear": + return PiecewiseLinearTrend( + changepoint_interval=self.changepoint_interval, + changepoint_range=self.changepoint_range, + changepoint_prior_scale=self.changepoint_prior_scale, + offset_prior_scale=self.offset_prior_scale, + ) + + elif self.trend == "logistic": + return PiecewiseLogisticTrend( + changepoint_interval=self.changepoint_interval, + changepoint_range=self.changepoint_range, + changepoint_prior_scale=self.changepoint_prior_scale, + offset_prior_scale=self.offset_prior_scale, + capacity_prior=dist.TransformedDistribution( + dist.HalfNormal(self.capacity_prior_scale), + dist.transforms.AffineTransform( + loc=self.capacity_prior_loc, scale=1 + ), + ), + ) + elif self.trend == "flat": + return FlatTrend(changepoint_prior_scale=self.changepoint_prior_scale) + + elif isinstance(self.trend, BaseEffect): + return self.trend + + raise ValueError( + "trend must be either 'linear', 'logistic' or a BaseEffect instance." + ) + + def _validate_hyperparams(self): + """Validate the hyperparameters.""" + if self.changepoint_interval <= 0: + raise ValueError("changepoint_interval must be greater than 0.") + if self.changepoint_prior_scale <= 0: + raise ValueError("changepoint_prior_scale must be greater than 0.") + if self.capacity_prior_scale <= 0: + raise ValueError("capacity_prior_scale must be greater than 0.") + if self.capacity_prior_loc <= 0: + raise ValueError("capacity_prior_loc must be greater than 0.") + if self.offset_prior_scale <= 0: + raise ValueError("offset_prior_scale must be greater than 0.") + if self.trend not in ["linear", "logistic", "flat"] and not isinstance( + self.trend, BaseEffect + ): + raise ValueError('trend must be either "linear" or "logistic".') + + def _match_columns( self, columns: Union[pd.Index, List[str]], regex: Union[str, None] ) -> pd.Index: """Match the columns of the DataFrame with the regex pattern. diff --git a/src/prophetverse/sktime/multivariate.py b/src/prophetverse/sktime/multivariate.py index bd49a82..92826ad 100644 --- a/src/prophetverse/sktime/multivariate.py +++ b/src/prophetverse/sktime/multivariate.py @@ -1,28 +1,21 @@ """Contains the implementation of the HierarchicalProphet forecaster.""" -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union -import jax.numpy as jnp import numpy as np import pandas as pd -from numpyro import distributions as dist from sktime.forecasting.base import ForecastingHorizon from sktime.transformations.base import BaseTransformer from sktime.transformations.hierarchical.aggregate import Aggregator from prophetverse.models import multivariate_model -from prophetverse.sktime.base import BaseEffectsBayesianForecaster, Stage -from prophetverse.trend.piecewise import ( - PiecewiseLinearTrend, - PiecewiseLogisticTrend, - TrendModel, -) +from prophetverse.sktime.base import BaseEffect, BaseProphetForecaster from prophetverse.utils import loc_bottom_series, reindex_time_series, series_to_tensor from ._expand_column_per_level import ExpandColumnPerLevel -class HierarchicalProphet(BaseEffectsBayesianForecaster): +class HierarchicalProphet(BaseProphetForecaster): """A Bayesian hierarchical time series forecasting model based on the Prophet. This class forecasts all series in a hierarchy at once, using a MultivariateNormal @@ -34,49 +27,77 @@ class HierarchicalProphet(BaseEffectsBayesianForecaster): Parameters ---------- - changepoint_interval : int - The number of points between each potential changepoint. - changepoint_range : float + trend : Union[str, BaseEffect], optional, one of "linear" (default) or "logistic" + Type of trend to use. Can also be a custom effect object. + + changepoint_interval : int, optional, default=25 + Number of potential changepoints to sample in the history. + + changepoint_range : float or int, optional, default=0.8 Proportion of the history in which trend changepoints will be estimated. - If a float between 0 and 1, the range will be that proportion of the history. - If an int, the range will be that number of points. A negative int indicates the - number of points counting from the end of the history. - changepoint_prior_scale : float - Parameter controlling the flexibility of the automatic changepoint selection. + + * if float, must be between 0 and 1. + The range will be that proportion of the training history. + + * if int, can be positive or negative. + Absolute value must be less than number of training points. + The range will be that number of points. + A negative int indicates number of points + counting from the end of the history, a positive int from the beginning. + + changepoint_prior_scale : float, optional, default=0.001 + Regularization parameter controlling the flexibility + of the automatic changepoint selection. + offset_prior_scale : float, optional, default=0.1 Scale parameter for the prior distribution of the offset. + The offset is the constant term in the piecewise trend equation. + capacity_prior_scale : float, optional, default=0.2 - Scale parameter for the capacity prior. + Scale parameter for the prior distribution of the capacity. + capacity_prior_loc : float, optional, default=1.1 - Location parameter for the capacity prior. - trend : str, optional, default='linear' - Type of trend. Either "linear" or "logistic". + Location parameter for the prior distribution of the capacity. + feature_transformer : BaseTransformer or None, optional, default=None A transformer to preprocess the exogenous features. + exogenous_effects : list of AbstractEffect, optional, default=None A list defining the exogenous effects to be used in the model. + default_effect : AbstractEffect, optional, default=None The default effect to be used when no effect is specified for a variable. + shared_features : list, optional, default=[] List of shared features across series. + mcmc_samples : int, optional, default=2000 Number of MCMC samples to draw. + mcmc_warmup : int, optional, default=200 Number of warmup steps for MCMC. + mcmc_chains : int, optional, default=4 Number of MCMC chains. + inference_method : str, optional, default='map' Inference method to use. Either "map" or "mcmc". + optimizer_name : str, optional, default='Adam' Name of the optimizer to use. + optimizer_kwargs : dict, optional, default={'step_size': 1e-4} Additional keyword arguments for the optimizer. + optimizer_steps : int, optional, default=100_000 Number of optimization steps. + noise_scale : float, optional, default=0.05 Scale parameter for the noise. + correlation_matrix_concentration : float, optional, default=1.0 Concentration parameter for the correlation matrix. + rng_key : jax.random.PRNGKey, optional, default=None Random number generator key. """ @@ -105,13 +126,13 @@ class HierarchicalProphet(BaseEffectsBayesianForecaster): def __init__( self, - changepoint_interval=25, - changepoint_range=0.8, - changepoint_prior_scale=0.1, - offset_prior_scale=0.1, + trend: Union[BaseEffect, str] = "linear", + changepoint_interval: int = 25, + changepoint_range: Union[float, int] = 0.8, + changepoint_prior_scale: float = 0.001, + offset_prior_scale: float = 0.1, capacity_prior_scale=0.2, capacity_prior_loc=1.1, - trend="linear", feature_transformer: BaseTransformer = None, exogenous_effects=None, default_effect=None, @@ -128,19 +149,24 @@ def __init__( rng_key=None, ): - self.changepoint_interval = changepoint_interval - self.changepoint_range = changepoint_range - self.changepoint_prior_scale = changepoint_prior_scale - self.offset_prior_scale = offset_prior_scale self.noise_scale = noise_scale - self.capacity_prior_scale = capacity_prior_scale - self.capacity_prior_loc = capacity_prior_loc - self.trend = trend self.shared_features = shared_features self.feature_transformer = feature_transformer self.correlation_matrix_concentration = correlation_matrix_concentration super().__init__( + # Trend + trend=trend, + changepoint_interval=changepoint_interval, + changepoint_range=changepoint_range, + changepoint_prior_scale=changepoint_prior_scale, + offset_prior_scale=offset_prior_scale, + capacity_prior_scale=capacity_prior_scale, + capacity_prior_loc=capacity_prior_loc, + # Exog effects + default_effect=default_effect, + exogenous_effects=exogenous_effects, + # Base Bayesian forecaster rng_key=rng_key, inference_method=inference_method, optimizer_name=optimizer_name, @@ -149,8 +175,6 @@ def __init__( mcmc_samples=mcmc_samples, mcmc_warmup=mcmc_warmup, mcmc_chains=mcmc_chains, - default_effect=default_effect, - exogenous_effects=exogenous_effects, ) self.model = multivariate_model # type: ignore[method-assign] @@ -158,19 +182,10 @@ def __init__( def _validate_hyperparams(self): """Validate the hyperparameters of the HierarchicalProphet forecaster.""" - if self.changepoint_interval <= 0: - raise ValueError("changepoint_interval must be greater than 0.") + super()._validate_hyperparams() - if self.changepoint_prior_scale <= 0: - raise ValueError("changepoint_prior_scale must be greater than 0.") if self.noise_scale <= 0: raise ValueError("noise_scale must be greater than 0.") - if self.capacity_prior_scale <= 0: - raise ValueError("capacity_prior_scale must be greater than 0.") - if self.capacity_prior_loc <= 0: - raise ValueError("capacity_prior_loc must be greater than 0.") - if self.offset_prior_scale <= 0: - raise ValueError("offset_prior_scale must be greater than 0.") if self.correlation_matrix_concentration <= 0: raise ValueError("correlation_matrix_concentration must be greater than 0.") @@ -198,6 +213,7 @@ def _get_fit_data(self, y, X, fh): # Handling series without __total indexes self.aggregator_ = Aggregator() self.original_y_indexes_ = y.index + fh = y.index.get_level_values(-1).unique() y = self.aggregator_.fit_transform(y) # Updating internal _y of sktime because BaseBayesianForecaster @@ -209,51 +225,19 @@ def _get_fit_data(self, y, X, fh): y_bottom = loc_bottom_series(y) y_bottom_arrays = series_to_tensor(y_bottom) - # Changepoints and trend - if self.trend == "linear": - self.trend_model_ = PiecewiseLinearTrend( - changepoint_interval=self.changepoint_interval, - changepoint_range=self.changepoint_range, - changepoint_prior_scale=self.changepoint_prior_scale, - offset_prior_scale=self.offset_prior_scale, - squeeze_if_single_series=False, - ) - - elif self.trend == "logistic": - self.trend_model_ = PiecewiseLogisticTrend( - changepoint_interval=self.changepoint_interval, - changepoint_range=self.changepoint_range, - changepoint_prior_scale=self.changepoint_prior_scale, - offset_prior_scale=self.offset_prior_scale, - capacity_prior=dist.TransformedDistribution( - dist.HalfNormal(self.capacity_prior_scale), - dist.transforms.AffineTransform( - loc=self.capacity_prior_loc, scale=1 - ), - ), - squeeze_if_single_series=False, - ) - - elif isinstance(self.trend, TrendModel): - self.trend_model_ = self.trend - else: - raise ValueError( - "trend must be either 'linear', 'logistic' or a TrendModel instance." - ) - - self.trend_model_.initialize(y_bottom) - fh = y.index.get_level_values(-1).unique() - trend_data = self.trend_model_.fit(fh) - - # Exog variables - # If no exogenous variables, create empty DataFrame # Else, aggregate exogenous variables and transform them if X is None or X.columns.empty: X = pd.DataFrame(index=y.index) + + X_bottom = loc_bottom_series(X) + if self.feature_transformer is not None: - X = self.feature_transformer.fit_transform(X) - self._has_exogenous_variables = X is not None and not X.columns.empty + X_bottom = self.feature_transformer.fit_transform(X_bottom) + + self._has_exogenous_variables = ( + X_bottom is not None and not X_bottom.columns.empty + ) if self._has_exogenous_variables: shared_features = self.shared_features @@ -261,19 +245,21 @@ def _get_fit_data(self, y, X, fh): shared_features = [] self.expand_columns_transformer_ = ExpandColumnPerLevel( - X.columns.difference(shared_features).to_list() - ).fit(X) - X = X.loc[y_bottom.index] - X = self.expand_columns_transformer_.transform(X) + X_bottom.columns.difference(shared_features).to_list() + ).fit(X_bottom) + X_bottom = self.expand_columns_transformer_.transform(X_bottom) else: self._exogenous_effects_and_columns = {} exogenous_data = {} - self._fit_effects(loc_bottom_series(X)) - exogenous_data = self._transform_effects( - loc_bottom_series(X), stage=Stage.TRAIN - ) + # Trend model + self.trend_model_ = self._get_trend_model() + self.trend_model_.fit(X=X_bottom, y=y_bottom, scale=self._scale) + trend_data = self.trend_model_.transform(X=X_bottom, fh=fh) + + self._fit_effects(X_bottom, y_bottom) + exogenous_data = self._transform_effects(X_bottom, fh=fh) self.fit_and_predict_data_ = { "trend_model": self.trend_model_, @@ -290,44 +276,6 @@ def _get_fit_data(self, y, X, fh): **self.fit_and_predict_data_, ) - def _get_exogenous_matrix_from_X(self, X: pd.DataFrame) -> jnp.ndarray: - """ - Convert the exogenous variables to a NumPyro matrix. - - Parameters - ---------- - X: pd.DataFrame - The exogenous variables. - - Return - ------ - jnp.ndarray - The NumPyro matrix of the exogenous variables. - """ - X_bottom = loc_bottom_series(X) - X_arrays = series_to_tensor(X_bottom) - - return X_arrays - - def predict_samples( - self, fh: ForecastingHorizon, X: Optional[pd.DataFrame] = None - ) -> np.ndarray: - """Generate samples for the given exogenous variables and forecasting horizon. - - Parameters - ---------- - X (pd.DataFrame): Exogenous variables. - fh (ForecastingHorizon): Forecasting horizon. - - Returns - ------- - np.ndarray - Predicted samples. - """ - samples = super().predict_samples(X=X, fh=fh) - - return self.aggregator_.transform(samples) - def _get_predict_data(self, X: pd.DataFrame, fh: ForecastingHorizon) -> np.ndarray: """Generate samples for the given exogenous variables and forecasting horizon. @@ -351,25 +299,24 @@ def _get_predict_data(self, X: pd.DataFrame, fh: ForecastingHorizon) -> np.ndarr if not isinstance(fh, ForecastingHorizon): fh = self._check_fh(fh) - trend_data = self.trend_model_.fit(fh_as_index) - if X is None or X.shape[1] == 0: idx = reindex_time_series(self._y, fh_as_index).index X = pd.DataFrame(index=idx) X = self.aggregator_.transform(X) - X = X.loc[X.index.get_level_values(-1).isin(fh_as_index)] + X_bottom = loc_bottom_series(X) + if self._has_exogenous_variables: - assert ( - X.index.get_level_values(-1).nunique() == fh_as_index.nunique() - ), "Missing exogenous variables for some series or dates." + + assert fh_as_index.isin( + X_bottom.index.get_level_values(-1) + ).all(), "Missing exogenous variables for some series or dates." if self.feature_transformer is not None: - X = self.feature_transformer.transform(X) - X = self.expand_columns_transformer_.transform(X) + X_bottom = self.feature_transformer.transform(X_bottom) + X_bottom = self.expand_columns_transformer_.transform(X_bottom) - exogenous_data = self._transform_effects( - loc_bottom_series(X), stage=Stage.PREDICT - ) + trend_data = self.trend_model_.transform(X=X_bottom, fh=fh_as_index) + exogenous_data = self._transform_effects(X=X_bottom, fh=fh_as_index) return dict( y=None, @@ -378,6 +325,25 @@ def _get_predict_data(self, X: pd.DataFrame, fh: ForecastingHorizon) -> np.ndarr **self.fit_and_predict_data_, ) + def predict_samples( + self, fh: ForecastingHorizon, X: Optional[pd.DataFrame] = None + ) -> np.ndarray: + """Generate samples for the given exogenous variables and forecasting horizon. + + Parameters + ---------- + X (pd.DataFrame): Exogenous variables. + fh (ForecastingHorizon): Forecasting horizon. + + Returns + ------- + np.ndarray + Predicted samples. + """ + samples = super().predict_samples(X=X, fh=fh) + + return self.aggregator_.transform(samples) + def _filter_series_tuples(self, levels: List[Tuple]) -> List[Tuple]: """Filter series tuples, returning only series of interest. diff --git a/src/prophetverse/sktime/univariate.py b/src/prophetverse/sktime/univariate.py index 3f19c17..5ef801d 100644 --- a/src/prophetverse/sktime/univariate.py +++ b/src/prophetverse/sktime/univariate.py @@ -5,23 +5,19 @@ """ +from typing import List, Optional, Union + import jax.numpy as jnp import pandas as pd -from numpyro import distributions as dist from sktime.forecasting.base import ForecastingHorizon +from prophetverse.effects import BaseEffect from prophetverse.models import ( univariate_gamma_model, univariate_model, univariate_negbinomial_model, ) -from prophetverse.sktime.base import BaseEffectsBayesianForecaster, Stage -from prophetverse.trend.flat import FlatTrend -from prophetverse.trend.piecewise import ( - PiecewiseLinearTrend, - PiecewiseLogisticTrend, - TrendModel, -) +from prophetverse.sktime.base import BaseProphetForecaster __all__ = ["Prophetverse", "Prophet", "ProphetGamma", "ProphetNegBinomial"] @@ -35,7 +31,7 @@ _DISCRETE_LIKELIHOODS = ["negbinomial"] -class Prophetverse(BaseEffectsBayesianForecaster): +class Prophetverse(BaseProphetForecaster): """Univariate ``Prophetverse`` forecaster, with support for multiple likelihoods. Differences to facebook's prophet: @@ -58,6 +54,9 @@ class Prophetverse(BaseEffectsBayesianForecaster): Parameters ---------- + trend : Union[str, BaseEffect], optional, one of "linear" (default) or "logistic" + Type of trend to use. Can also be a custom effect object. + changepoint_interval : int, optional, default=25 Number of potential changepoints to sample in the history. @@ -67,7 +66,7 @@ class Prophetverse(BaseEffectsBayesianForecaster): * if float, must be between 0 and 1. The range will be that proportion of the training history. - * if int, ca nbe positive or negative. + * if int, can be positive or negative. Absolute value must be less than number of training points. The range will be that number of points. A negative int indicates number of points @@ -81,21 +80,19 @@ class Prophetverse(BaseEffectsBayesianForecaster): Scale parameter for the prior distribution of the offset. The offset is the constant term in the piecewise trend equation. - feature_transformer : sktime transformer, BaseTransformer, optional, default=None - Transformer object to generate Fourier terms, holiday or other features. - If None, no additional features are used. - For multiple features, pass a ``FeatureUnion`` object with the transformers. - capacity_prior_scale : float, optional, default=0.2 Scale parameter for the prior distribution of the capacity. capacity_prior_loc : float, optional, default=1.1 Location parameter for the prior distribution of the capacity. + feature_transformer : sktime transformer, BaseTransformer, optional, default=None + Transformer object to generate Fourier terms, holiday or other features. + If None, no additional features are used. + For multiple features, pass a ``FeatureUnion`` object with the transformers. + noise_scale : float, optional, default=0.05 Scale parameter for the observation noise. - trend : str, optional, one of "linear" (default) or "logistic" - Type of trend to use. Can be "linear" or "logistic". mcmc_samples : int, optional, default=2000 Number of MCMC samples to draw. @@ -152,15 +149,17 @@ class Prophetverse(BaseEffectsBayesianForecaster): def __init__( self, - changepoint_interval=25, - changepoint_range=0.8, - changepoint_prior_scale=0.001, - offset_prior_scale=0.1, - feature_transformer=None, + trend: Union[BaseEffect, str] = "linear", + changepoint_interval: int = 25, + changepoint_range: Union[float, int] = 0.8, + changepoint_prior_scale: float = 0.001, + offset_prior_scale: float = 0.1, capacity_prior_scale=0.2, capacity_prior_loc=1.1, + exogenous_effects: Optional[List[BaseEffect]] = None, + default_effect: Optional[BaseEffect] = None, + feature_transformer=None, noise_scale=0.05, - trend="linear", mcmc_samples=2000, mcmc_warmup=200, mcmc_chains=4, @@ -168,27 +167,27 @@ def __init__( optimizer_name="Adam", optimizer_kwargs=None, optimizer_steps=100_000, - exogenous_effects=None, likelihood="normal", - default_effect=None, scale=None, rng_key=None, ): """Initialize the Prophet model.""" - self.changepoint_interval = changepoint_interval - self.changepoint_range = changepoint_range - self.changepoint_prior_scale = changepoint_prior_scale - self.offset_prior_scale = offset_prior_scale self.noise_scale = noise_scale self.feature_transformer = feature_transformer - self.capacity_prior_scale = capacity_prior_scale - self.capacity_prior_loc = capacity_prior_loc - self.trend = trend + self.likelihood = likelihood super().__init__( rng_key=rng_key, - # ExogenousEffectMixin + # Trend + trend=trend, + changepoint_interval=changepoint_interval, + changepoint_range=changepoint_range, + changepoint_prior_scale=changepoint_prior_scale, + offset_prior_scale=offset_prior_scale, + capacity_prior_scale=capacity_prior_scale, + capacity_prior_loc=capacity_prior_loc, + # Exog default_effect=default_effect, exogenous_effects=exogenous_effects, # BaseBayesianForecaster @@ -216,7 +215,7 @@ def model(self): return _LIKELIHOOD_MODEL_MAP[self.likelihood] @property - def should_skip_scaling(self) -> bool: + def _likelihood_is_discrete(self) -> bool: """Skip scaling if the likelihood is discrete. In the case of discrete likelihoods, the data is not scaled since this can @@ -226,23 +225,10 @@ def should_skip_scaling(self) -> bool: def _validate_hyperparams(self): """Validate the hyperparameters.""" - if self.changepoint_interval <= 0: - raise ValueError("changepoint_interval must be greater than 0.") + super()._validate_hyperparams() - if self.changepoint_prior_scale <= 0: - raise ValueError("changepoint_prior_scale must be greater than 0.") if self.noise_scale <= 0: raise ValueError("noise_scale must be greater than 0.") - if self.capacity_prior_scale <= 0: - raise ValueError("capacity_prior_scale must be greater than 0.") - if self.capacity_prior_loc <= 0: - raise ValueError("capacity_prior_loc must be greater than 0.") - if self.offset_prior_scale <= 0: - raise ValueError("offset_prior_scale must be greater than 0.") - if self.trend not in ["linear", "logistic", "flat"] and not isinstance( - self.trend, TrendModel - ): - raise ValueError('trend must be either "linear" or "logistic".') if self.likelihood not in _LIKELIHOOD_MODEL_MAP: raise ValueError( @@ -272,12 +258,12 @@ def _get_fit_data(self, y, X, fh): self.trend_model_ = self._get_trend_model() - if self.should_skip_scaling: - self.trend_model_.initialize(y / self._scale) + if self._likelihood_is_discrete: + # Scale the data, since _get_fit_data receives + # a non-scaled y for discrete likelihoods + self.trend_model_.fit(X=X, y=y / self._scale) else: - self.trend_model_.initialize(y) - - trend_data = self.trend_model_.fit(fh) + self.trend_model_.fit(X=X, y=y) # Exogenous features @@ -291,8 +277,10 @@ def _get_fit_data(self, y, X, fh): self._has_exogenous = ~X.columns.empty X = X.loc[y.index] - self._fit_effects(X) - exogenous_data = self._transform_effects(X, stage=Stage.TRAIN) + trend_data = self.trend_model_.transform(X=X, fh=fh) + + self._fit_effects(X, y) + exogenous_data = self._transform_effects(X, fh=fh) y_array = jnp.array(y.values.flatten()).reshape((-1, 1)) @@ -315,7 +303,9 @@ def _get_fit_data(self, y, X, fh): return inputs - def _get_predict_data(self, X: pd.DataFrame, fh: ForecastingHorizon) -> dict: + def _get_predict_data( + self, X: Union[pd.DataFrame, None], fh: ForecastingHorizon + ) -> dict: """ Prepare the data for making predictions. @@ -334,18 +324,16 @@ def _get_predict_data(self, X: pd.DataFrame, fh: ForecastingHorizon) -> dict: fh_dates = self.fh_to_index(fh) fh_as_index = pd.Index(list(fh_dates.to_numpy())) - trend_data = self.trend_model_.fit(fh_as_index) - if X is None: X = pd.DataFrame(index=fh_as_index) if self.feature_transformer is not None: X = self.feature_transformer.transform(X) + trend_data = self.trend_model_.transform(X=X, fh=fh_as_index) + exogenous_data = ( - self._transform_effects(X.loc[fh_as_index], stage=Stage.PREDICT) - if self._has_exogenous - else None + self._transform_effects(X, fh_as_index) if self._has_exogenous else None ) return dict( @@ -355,53 +343,6 @@ def _get_predict_data(self, X: pd.DataFrame, fh: ForecastingHorizon) -> dict: **self.fit_and_predict_data_, ) - def _get_trend_model(self): - """ - Return the trend model based on the specified trend parameter. - - Returns - ------- - TrendModel - The trend model based on the specified trend parameter. - - Raises - ------ - ValueError - If the trend parameter is not one of 'linear', 'logistic', 'flat' - or a TrendModel instance. - """ - # Changepoints and trend - if self.trend == "linear": - return PiecewiseLinearTrend( - changepoint_interval=self.changepoint_interval, - changepoint_range=self.changepoint_range, - changepoint_prior_scale=self.changepoint_prior_scale, - offset_prior_scale=self.offset_prior_scale, - ) - - elif self.trend == "logistic": - return PiecewiseLogisticTrend( - changepoint_interval=self.changepoint_interval, - changepoint_range=self.changepoint_range, - changepoint_prior_scale=self.changepoint_prior_scale, - offset_prior_scale=self.offset_prior_scale, - capacity_prior=dist.TransformedDistribution( - dist.HalfNormal(self.capacity_prior_scale), - dist.transforms.AffineTransform( - loc=self.capacity_prior_loc, scale=1 - ), - ), - ) - elif self.trend == "flat": - return FlatTrend(changepoint_prior_scale=self.changepoint_prior_scale) - - elif isinstance(self.trend, TrendModel): - return self.trend - - raise ValueError( - "trend must be either 'linear', 'logistic' or a TrendModel instance." - ) - @classmethod def get_test_params(cls, parameter_set="default"): # pragma: no cover """Params to be used in sktime unit tests. diff --git a/src/prophetverse/trend/__init__.py b/src/prophetverse/trend/__init__.py deleted file mode 100644 index b7c47b1..0000000 --- a/src/prophetverse/trend/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -""" -Trend models to be used in ProphetVerse. - -This module contains the following trend models: - -- FlatTrend: A flat trend model that does not change over time. -- PiecewiseLinearTrend: A piecewise linear trend model that changes linearly over time. -- PiecewiseLogisticTrend: A piecewise logistic trend model that changes logistically - over time. -""" - -from .base import TrendModel -from .flat import FlatTrend -from .piecewise import PiecewiseLinearTrend, PiecewiseLogisticTrend - -__all__ = [ - "TrendModel", - "FlatTrend", - "PiecewiseLinearTrend", - "PiecewiseLogisticTrend", -] diff --git a/src/prophetverse/trend/base.py b/src/prophetverse/trend/base.py deleted file mode 100644 index 7b02421..0000000 --- a/src/prophetverse/trend/base.py +++ /dev/null @@ -1,117 +0,0 @@ -"""Module containing the base class for trend models.""" - -from abc import ABC, abstractmethod - -import pandas as pd - -from prophetverse.utils.frame_to_array import convert_index_to_days_since_epoch - - -class TrendModel(ABC): - """ - Abstract base class for trend models. - - Attributes - ---------- - t_scale: float - The time scale of the trend model. - t_start: float - The starting time of the trend model. - n_series: int - The number of series in the time series data. - """ - - def initialize(self, y: pd.DataFrame) -> None: - """Initialize trend model with the timeseries data. - - This method is close to what "fit" is in sktime/sklearn estimators. - Child classes should implement this method to initialize the model and - may call super().initialize() to perform common initialization steps. - - Parameters - ---------- - y: pd.DataFrame - time series dataframe, may be multiindex - """ - # Set time scale - t_days = convert_index_to_days_since_epoch( - y.index.get_level_values(-1).unique() - ) - self.t_scale = (t_days[1:] - t_days[:-1]).mean() - self.t_start = t_days.min() / self.t_scale - if y.index.nlevels > 1: - self.n_series = y.index.droplevel(-1).nunique() - else: - self.n_series = 1 - - @abstractmethod - def fit(self, idx: pd.PeriodIndex) -> dict: - """Return a dictionary containing the data needed for the trend model. - - All arguments in the signature of compute_trend should be keys in the - dictionary. - - For example, given t, a possible implementation would return - ```python - { - "changepoint_matrix": jnp.array([[1, 0], [0, 1]]), - } - ``` - - And compute_trend would be defined as - - ```python - def compute_trend(self, changepoint_matrix): - ... - ``` - - Parameters - ---------- - idx: pd.PeriodIndex - The index of the time series data. - - Returns - ------- - dict: A dictionary containing the data needed for the trend model. - - """ - ... - - @abstractmethod - def compute_trend(self, **kwargs): - """Compute the trend. - - Receive the output of fit as keyword arguments. - - Returns - ------- - jnp.ndarray: array with trend data for each time step and series. - """ - ... - - def _index_to_scaled_timearray(self, idx): - """ - Convert the index to a scaled time array. - - Parameters - ---------- - idx: int - The index to be converted. - - Returns - ------- - float - The scaled time array value. - """ - t_days = convert_index_to_days_since_epoch(idx) - return (t_days) / self.t_scale - self.t_start - - def __call__(self, **kwargs): - """Compute the trend. - - Parameters - ---------- - **kwargs : dict - The keyword arguments to be passed to the compute_trend method. - """ - return self.compute_trend(**kwargs) diff --git a/src/prophetverse/utils/multiindex.py b/src/prophetverse/utils/multiindex.py index 70e7e56..5d68ddc 100644 --- a/src/prophetverse/utils/multiindex.py +++ b/src/prophetverse/utils/multiindex.py @@ -26,7 +26,12 @@ def get_bottom_series_idx(y): pd.Index The index of the bottom series. """ - return _get_s_matrix(y).columns + if y.index.nlevels == 1: + raise ValueError("y must be a multi-index DataFrame") + if y.index.nlevels == 2: + return pd.Index([x for x in y.index.droplevel(-1).unique() if x != "__total"]) + series = pd.Index([x for x in y.index.droplevel(-1).unique() if x[-1] != "__total"]) + return series def loc_bottom_series(y): diff --git a/tests/conftest.py b/tests/conftest.py index dd886c1..cec4f64 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -35,9 +35,9 @@ class ConcreteEffect(BaseAdditiveOrMultiplicativeEffect): _tags = {"skip_predict_if_no_match": False} - def _predict(self, trend: jnp.ndarray, data: jnp.ndarray) -> jnp.ndarray: + def _predict(self, data: jnp.ndarray, predicted_effects=None) -> jnp.ndarray: """Calculate simple effect.""" - return jnp.mean(data, axis=0) + return jnp.mean(data, axis=1) @pytest.fixture(name="effect_with_regex") diff --git a/tests/effects/test_base.py b/tests/effects/test_base.py index a8338e1..6a93917 100644 --- a/tests/effects/test_base.py +++ b/tests/effects/test_base.py @@ -7,19 +7,19 @@ @pytest.mark.smoke def test__predict(effect_with_regex): - trend = jnp.array([1.0, 2.0, 3.0]) - data = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) - result = effect_with_regex._predict(trend, data) - expected_result = jnp.mean(data, axis=0) + trend = jnp.array([1.0, 2.0, 3.0]).reshape((-1, 1)) + data = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]).reshape((-1, 2)) + result = effect_with_regex.predict(data, predicted_effects={"trend": trend}) + expected_result = jnp.mean(data, axis=1).reshape((-1, 1)) assert jnp.allclose(result, expected_result) @pytest.mark.smoke def test_call(effect_with_regex): - trend = jnp.array([1.0, 2.0, 3.0]) - data = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) - result = effect_with_regex(trend, data=data) - expected_result = jnp.mean(data, axis=0) + trend = jnp.array([1.0, 2.0, 3.0]).reshape((-1, 1)) + data = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]).reshape((-1, 2)) + result = effect_with_regex(data=data, predicted_effects={"trend": trend}) + expected_result = jnp.mean(data, axis=1).reshape((-1, 1)) assert jnp.allclose(result, expected_result) @@ -30,4 +30,4 @@ def test_bad_effect_mode(): def test_not_fitted(): with pytest.raises(ValueError): - BaseEffect().transform(pd.DataFrame()) + BaseEffect().transform(pd.DataFrame(), fh=pd.Index([])) diff --git a/tests/effects/test_fourier.py b/tests/effects/test_fourier.py index 4678bdc..5edf943 100644 --- a/tests/effects/test_fourier.py +++ b/tests/effects/test_fourier.py @@ -37,7 +37,7 @@ def test_linear_fourier_seasonality_initialization(fourier_effect_instance): def test_linear_fourier_seasonality_fit(fourier_effect_instance, exog_data): - fourier_effect_instance.fit(exog_data) + fourier_effect_instance.fit(X=exog_data, y=None) assert hasattr(fourier_effect_instance, "fourier_features_") assert hasattr(fourier_effect_instance, "linear_effect_") assert isinstance(fourier_effect_instance.fourier_features_, FourierFeatures) @@ -45,20 +45,23 @@ def test_linear_fourier_seasonality_fit(fourier_effect_instance, exog_data): def test_linear_fourier_seasonality_transform(fourier_effect_instance, exog_data): - fourier_effect_instance.fit(exog_data) - transformed = fourier_effect_instance.transform(exog_data, stage="train") + fh = exog_data.index.get_level_values(-1).unique() + fourier_effect_instance.fit(X=exog_data, y=None) + transformed = fourier_effect_instance.transform(X=exog_data, fh=fh) fourier_transformed = fourier_effect_instance.fourier_features_.transform(exog_data) - assert isinstance(transformed, dict) - assert "data" in transformed - assert transformed["data"].shape == fourier_transformed.shape + assert isinstance(transformed, jnp.ndarray) + assert transformed.shape == fourier_transformed.shape def test_linear_fourier_seasonality_predict(fourier_effect_instance, exog_data): - fourier_effect_instance.fit(exog_data) + fh = exog_data.index.get_level_values(-1).unique() + fourier_effect_instance.fit(X=exog_data, y=None) trend = jnp.array([1.0] * len(exog_data)) - data = fourier_effect_instance.transform(exog_data, stage="predict") + data = fourier_effect_instance.transform(exog_data, fh=fh) with numpyro.handlers.seed(numpyro.handlers.seed, 0): - prediction = fourier_effect_instance.predict(trend, **data) + prediction = fourier_effect_instance.predict( + data, predicted_effects={"trend": trend} + ) assert prediction is not None assert isinstance(prediction, jnp.ndarray) diff --git a/tests/effects/test_hill.py b/tests/effects/test_hill.py index abc09c3..a55d8d7 100644 --- a/tests/effects/test_hill.py +++ b/tests/effects/test_hill.py @@ -37,11 +37,13 @@ def test_initialization_defaults(): def test__predict_multiplicative(hill_effect_multiplicative): - trend = jnp.array([1.0, 2.0, 3.0]) - data = jnp.array([0.5, 1.0, 1.5]) + trend = jnp.array([1.0, 2.0, 3.0]).reshape((-1, 1)) + data = jnp.array([0.5, 1.0, 1.5]).reshape((-1, 1)) with seed(numpyro.handlers.seed, 0): - result = hill_effect_multiplicative.predict(trend, data=data) + result = hill_effect_multiplicative.predict( + data=data, predicted_effects={"trend": trend} + ) half_max, slope, max_effect = 0.5, 1.0, 1.5 x = _exponent_safe(data / half_max, -slope) @@ -52,11 +54,13 @@ def test__predict_multiplicative(hill_effect_multiplicative): def test__predict_additive(hill_effect_additive): - trend = jnp.array([1.0, 2.0, 3.0]) - data = jnp.array([0.5, 1.0, 1.5]) + trend = jnp.array([1.0, 2.0, 3.0]).reshape((-1, 1)) + data = jnp.array([0.5, 1.0, 1.5]).reshape((-1, 1)) with seed(numpyro.handlers.seed, 0): - result = hill_effect_additive.predict(trend, data=data) + result = hill_effect_additive.predict( + data=data, predicted_effects={"trend": trend} + ) half_max, slope, max_effect = 0.5, 1.0, 1.5 x = _exponent_safe(data / half_max, -slope) diff --git a/tests/effects/test_lift_experiment.py b/tests/effects/test_lift_experiment.py index f782132..9f755cd 100644 --- a/tests/effects/test_lift_experiment.py +++ b/tests/effects/test_lift_experiment.py @@ -34,6 +34,11 @@ def X(): ) +@pytest.fixture +def y(X): + return pd.DataFrame(index=X.index, data=[1] * len(X)) + + def test_liftexperimentlikelihood_initialization( lift_experiment_effect_instance, inner_effect, lift_test_results ): @@ -44,36 +49,35 @@ def test_liftexperimentlikelihood_initialization( def test_liftexperimentlikelihood_fit(X, lift_experiment_effect_instance): - lift_experiment_effect_instance.fit(X, scale=1) + lift_experiment_effect_instance.fit(y=y, X=X, scale=1) assert lift_experiment_effect_instance.timeseries_scale == 1 assert lift_experiment_effect_instance.effect._is_fitted -def test_liftexperimentlikelihood_transform_train(X, lift_experiment_effect_instance): - - lift_experiment_effect_instance.fit(X) - transformed = lift_experiment_effect_instance.transform(X, stage="train") +def test_liftexperimentlikelihood_transform_train( + X, y, lift_experiment_effect_instance +): + fh = y.index.get_level_values(-1).unique() + lift_experiment_effect_instance.fit(X=X, y=y) + transformed = lift_experiment_effect_instance.transform( + X, + fh=fh, + ) assert "observed_lift" in transformed assert transformed["observed_lift"] is not None -def test_liftexperimentlikelihood_transform_predict(X, lift_experiment_effect_instance): - lift_experiment_effect_instance.fit(X) - transformed = lift_experiment_effect_instance.transform(X, stage="predict") - assert "observed_lift" in transformed - assert transformed["observed_lift"] is None - +def test_liftexperimentlikelihood_predict(X, y, lift_experiment_effect_instance): + fh = X.index.get_level_values(-1).unique() -def test_liftexperimentlikelihood_predict(X, lift_experiment_effect_instance): trend = jnp.array([1, 2, 3, 4, 5, 6]) - - lift_experiment_effect_instance.fit(X) - data = lift_experiment_effect_instance.transform(X, stage="train") - predicted = lift_experiment_effect_instance.predict(trend, **data) - inner_effect_data = lift_experiment_effect_instance.effect.transform( - X, stage="train" + lift_experiment_effect_instance.fit(X=X, y=y) + data = lift_experiment_effect_instance.transform(X=X, fh=fh) + predicted = lift_experiment_effect_instance.predict( + data=data, predicted_effects={"trend": trend} ) + inner_effect_data = lift_experiment_effect_instance.effect.transform(X, fh=fh) inner_effect_predict = lift_experiment_effect_instance.effect.predict( - trend, **inner_effect_data + data=inner_effect_data, predicted_effects={"trend": trend} ) assert jnp.all(predicted == inner_effect_predict) diff --git a/tests/effects/test_linear.py b/tests/effects/test_linear.py index 179931f..0a10d41 100644 --- a/tests/effects/test_linear.py +++ b/tests/effects/test_linear.py @@ -26,11 +26,13 @@ def test_initialization_defaults(): def test__predict_multiplicative(linear_effect_multiplicative): - trend = jnp.array([1.0, 2.0, 3.0]) + trend = jnp.array([1.0, 2.0, 3.0]).reshape((-1, 1)) data = jnp.array([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]]) with seed(numpyro.handlers.seed, 0): - result = linear_effect_multiplicative.predict(trend, data=data) + result = linear_effect_multiplicative.predict( + data=data, predicted_effects={"trend": trend} + ) expected_result = trend * (data @ jnp.array([1.0, 1.0]).reshape((-1, 1))) @@ -38,11 +40,13 @@ def test__predict_multiplicative(linear_effect_multiplicative): def test__predict_additive(linear_effect_additive): - trend = jnp.array([1.0, 2.0, 3.0]) + trend = jnp.array([1.0, 2.0, 3.0]).reshape((-1, 1)) data = jnp.array([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]]) with seed(numpyro.handlers.seed, 0): - result = linear_effect_additive.predict(trend, data=data) + result = linear_effect_additive.predict( + data=data, predicted_effects={"trend": trend} + ) expected_result = data @ jnp.array([1.0, 1.0]).reshape((-1, 1)) diff --git a/tests/effects/test_log.py b/tests/effects/test_log.py index 11ce1f8..526f5f1 100644 --- a/tests/effects/test_log.py +++ b/tests/effects/test_log.py @@ -33,11 +33,13 @@ def test_initialization_defaults(): def test__predict_multiplicative(log_effect_multiplicative): - trend = jnp.array([1.0, 2.0, 3.0]) - data = jnp.array([1.0, 2.0, 3.0]) + trend = jnp.array([1.0, 2.0, 3.0]).reshape((-1, 1)) + data = jnp.array([1.0, 2.0, 3.0]).reshape((-1, 1)) with seed(numpyro.handlers.seed, 0): - result = log_effect_multiplicative.predict(trend, data=data) + result = log_effect_multiplicative.predict( + data=data, predicted_effects={"trend": trend} + ) scale, rate = 0.5, 2.0 expected_effect = scale * jnp.log(rate * data + 1) @@ -47,11 +49,13 @@ def test__predict_multiplicative(log_effect_multiplicative): def test__predict_additive(log_effect_additive): - trend = jnp.array([1.0, 2.0, 3.0]) - data = jnp.array([1.0, 2.0, 3.0]) + trend = jnp.array([1.0, 2.0, 3.0]).reshape((-1, 1)) + data = jnp.array([1.0, 2.0, 3.0]).reshape((-1, 1)) with seed(numpyro.handlers.seed, 0): - result = log_effect_additive.predict(trend, data=data) + result = log_effect_additive.predict( + data=data, predicted_effects={"trend": trend} + ) scale, rate = 0.5, 2.0 expected_result = scale * jnp.log(rate * data + 1) @@ -64,7 +68,9 @@ def test__predict_with_zero_data(log_effect_multiplicative): data = jnp.array([0.0, 0.0, 0.0]) with seed(numpyro.handlers.seed, 0): - result = log_effect_multiplicative.predict(trend, data=data) + result = log_effect_multiplicative.predict( + data=data, predicted_effects={"trend": trend} + ) scale, rate = 0.5, 2.0 expected_effect = scale * jnp.log(rate * data + 1) @@ -78,7 +84,9 @@ def test__predict_with_empty_data(log_effect_multiplicative): data = jnp.array([]) with seed(numpyro.handlers.seed, 0): - result = log_effect_multiplicative.predict(trend, data=data) + result = log_effect_multiplicative.predict( + data=data, predicted_effects={"trend": trend} + ) scale, rate = 0.5, 2.0 expected_effect = scale * jnp.log(rate * data + 1) diff --git a/tests/models/multivariate_model/test_frame_to_array.py b/tests/models/multivariate_model/test_frame_to_array.py index c8c7070..b6872cd 100644 --- a/tests/models/multivariate_model/test_frame_to_array.py +++ b/tests/models/multivariate_model/test_frame_to_array.py @@ -2,6 +2,8 @@ import pandas as pd import pytest from jax import numpy as jnp +from sktime.transformations.hierarchical.aggregate import Aggregator +from sktime.utils._testing.hierarchical import _make_hierarchical from prophetverse.utils import ( convert_dataframe_to_tensors, @@ -44,11 +46,20 @@ def test_get_multiindex_loc(sample_hierarchical_data): # Test for fetching bottom series -def test_loc_bottom_series(sample_hierarchical_data): - result = loc_bottom_series(sample_hierarchical_data) - assert isinstance( - result, pd.DataFrame - ), "Should return a DataFrame of the bottom series" +@pytest.mark.parametrize("hierarchical_levels", [(2, 4, 4), (2,)]) +def test_loc_bottom_series(hierarchical_levels): + hierarchical_data = _make_hierarchical(hierarchical_levels) + aggregated = Aggregator(flatten_single_levels=False).fit_transform( + hierarchical_data + ) + result = loc_bottom_series(aggregated) + pd.testing.assert_frame_equal(result.sort_index(), hierarchical_data.sort_index()) + + +def test_get_bottom_series_idx_raises_error(): + y = pd.DataFrame(data=[1, 2, 3], index=[1, 2, 3], columns=["A"]) + with pytest.raises(ValueError): + get_bottom_series_idx(y) # Test for iterating all series diff --git a/tests/sktime/test_base.py b/tests/sktime/test_base.py index 39a85fa..beaa6be 100644 --- a/tests/sktime/test_base.py +++ b/tests/sktime/test_base.py @@ -2,12 +2,12 @@ import pytest from prophetverse.effects import LinearEffect -from prophetverse.sktime.base import BaseEffectsBayesianForecaster +from prophetverse.sktime.base import BaseProphetForecaster @pytest.fixture def base_effects_bayesian_forecaster(): - return BaseEffectsBayesianForecaster( + return BaseProphetForecaster( exogenous_effects=[ ("effect1", LinearEffect(prior=dist.Normal(10, 2)), r"(x1).*"), ] diff --git a/tests/sktime/test_univariate.py b/tests/sktime/test_univariate.py index 81b1da6..d671ecc 100644 --- a/tests/sktime/test_univariate.py +++ b/tests/sktime/test_univariate.py @@ -2,6 +2,7 @@ from numpyro import distributions as dist from prophetverse.effects.linear import LinearEffect +from prophetverse.effects.trend.flat import FlatTrend from prophetverse.sktime.seasonality import seasonal_transformer from prophetverse.sktime.univariate import ( _DISCRETE_LIKELIHOODS, @@ -11,7 +12,6 @@ ProphetNegBinomial, Prophetverse, ) -from prophetverse.trend.flat import FlatTrend from ._utils import ( execute_extra_predict_methods_tests, @@ -162,4 +162,4 @@ def test_prophetverse_likelihood_behaviour(likelihood): assert model.model == _LIKELIHOOD_MODEL_MAP[likelihood] if likelihood in _DISCRETE_LIKELIHOODS: - assert model.should_skip_scaling + assert model._likelihood_is_discrete diff --git a/tests/trend/test_flat.py b/tests/trend/test_flat.py index a4b8688..88d333c 100644 --- a/tests/trend/test_flat.py +++ b/tests/trend/test_flat.py @@ -4,7 +4,9 @@ import pytest from numpy.testing import assert_almost_equal -from prophetverse.trend.flat import FlatTrend # Assuming this is the import path +from prophetverse.effects.trend.flat import ( + FlatTrend, # Assuming this is the import path +) @pytest.fixture @@ -26,27 +28,28 @@ def test_initialization(trend_model): def test_initialize(trend_model, timeseries_data): - trend_model.initialize(timeseries_data) + trend_model.fit(X=None, y=timeseries_data) expected_loc = timeseries_data["data"].mean() assert_almost_equal(trend_model.changepoint_prior_loc, expected_loc) def test_fit(trend_model, timeseries_data): - idx = timeseries_data.index.to_period("D") - result = trend_model.fit(idx) - assert "constant_vector" in result - assert result["constant_vector"].shape == (len(idx), 1) - assert jnp.all(result["constant_vector"] == 1) + idx = timeseries_data.index + trend_model.fit(X=None, y=timeseries_data) + result = trend_model.transform(X=pd.DataFrame(index=timeseries_data.index), fh=idx) + assert result.shape == (len(idx), 1) + assert jnp.all(result == 1) def test_compute_trend(trend_model, timeseries_data): - idx = timeseries_data.index.to_period("D") - trend_model.initialize(timeseries_data) - data = trend_model.fit(idx) - constant_vector = data["constant_vector"] + idx = timeseries_data.index + trend_model.fit(X=None, y=timeseries_data) + constant_vector = trend_model.transform( + X=pd.DataFrame(index=timeseries_data.index), fh=idx + ) with numpyro.handlers.seed(rng_seed=0): - trend_result = trend_model.compute_trend(constant_vector) + trend_result = trend_model.predict(constant_vector, None) assert jnp.unique(trend_result).shape == (1,) assert trend_result.shape == (len(idx), 1) diff --git a/tests/trend/test_piecewise.py b/tests/trend/test_piecewise.py index f7569f5..75bbe8c 100644 --- a/tests/trend/test_piecewise.py +++ b/tests/trend/test_piecewise.py @@ -2,10 +2,8 @@ import numpyro import pandas as pd import pytest -from numpyro.distributions import Normal -from prophetverse.trend.base import TrendModel -from prophetverse.trend.piecewise import ( +from prophetverse.effects.trend.piecewise import ( PiecewiseLinearTrend, PiecewiseLogisticTrend, _enforce_array_if_zero_dim, @@ -13,7 +11,6 @@ _get_changepoint_timeindexes, _suggest_logistic_rate_and_offset, _to_list_if_scalar, - series_to_tensor, ) @@ -63,7 +60,7 @@ def piecewise_logistic_trend(): # Tests for PiecewiseLinearTrend def test_piecewise_linear_initialize(piecewise_linear_trend, mock_dataframe): - piecewise_linear_trend.initialize(mock_dataframe) + piecewise_linear_trend.fit(mock_dataframe, mock_dataframe) assert hasattr( piecewise_linear_trend, "_changepoint_ts" ), "Changepoint ts not set during initialization." @@ -71,7 +68,10 @@ def test_piecewise_linear_initialize(piecewise_linear_trend, mock_dataframe): # Tests for PiecewiseLogisticTrend def test_piecewise_logistic_initialize(piecewise_logistic_trend, mock_dataframe): - piecewise_logistic_trend.initialize(mock_dataframe) + piecewise_logistic_trend.fit(mock_dataframe, mock_dataframe) + piecewise_logistic_trend.transform( + mock_dataframe, fh=mock_dataframe.index.get_level_values(-1) + ) assert hasattr( piecewise_logistic_trend, "_changepoint_ts" ), "Changepoint ts not set during initialization." @@ -87,11 +87,11 @@ def test_piecewise_compute_trend( df = make_df() for trend_model in [piecewise_linear_trend, piecewise_logistic_trend]: - trend_model.initialize(df) + trend_model.fit(df, df) period_index = pd.period_range(start="2020-01-01", periods=100, freq="D") changepoint_matrix = trend_model.get_changepoint_matrix(period_index) with numpyro.handlers.seed(rng_seed=0): - trend = trend_model.compute_trend(changepoint_matrix) + trend = trend_model.predict(changepoint_matrix, predicted_effects={}) assert ( trend.ndim == expected_ndim ), f"Dimensions are incorrect for trend_model {trend_model.__class__.__name__}" @@ -124,7 +124,7 @@ def test_get_changepoint_timeindexes(): def test_piecewise_linear_get_changepoint_matrix( piecewise_linear_trend, mock_dataframe ): - piecewise_linear_trend.initialize(mock_dataframe) + piecewise_linear_trend.fit(mock_dataframe, mock_dataframe) period_index = pd.period_range(start="2020-01-01", periods=100, freq="D") result = piecewise_linear_trend.get_changepoint_matrix(period_index)