Skip to content

Commit

Permalink
Add WeeklyFourier (pymc-labs#1443)
Browse files Browse the repository at this point in the history
  • Loading branch information
PabloRoque authored Jan 29, 2025
1 parent bdab387 commit 759bfd2
Show file tree
Hide file tree
Showing 4 changed files with 407 additions and 65 deletions.
1 change: 1 addition & 0 deletions pymc_marketing/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@

DAYS_IN_YEAR: float = 365.25
DAYS_IN_MONTH: float = DAYS_IN_YEAR / 12
DAYS_IN_WEEK: int = 7
3 changes: 2 additions & 1 deletion pymc_marketing/mmm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
TanhSaturationBaselined,
saturation_from_dict,
)
from pymc_marketing.mmm.fourier import MonthlyFourier, YearlyFourier
from pymc_marketing.mmm.fourier import MonthlyFourier, WeeklyFourier, YearlyFourier
from pymc_marketing.mmm.hsgp import (
HSGP,
CovFunc,
Expand Down Expand Up @@ -85,6 +85,7 @@
"SaturationTransformation",
"TanhSaturation",
"TanhSaturationBaselined",
"WeeklyFourier",
"WeibullCDFAdstock",
"WeibullPDFAdstock",
"YearlyFourier",
Expand Down
133 changes: 120 additions & 13 deletions pymc_marketing/mmm/fourier.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
- Yearly Fourier: A yearly seasonality with a period of 365.25 days
- Monthly Fourier: A monthly seasonality with a period of 365.25 / 12 days
- Weekly Fourier: A weekly seasonality with a period of 7 days
.. plot::
:context: close-figs
Expand Down Expand Up @@ -221,7 +222,7 @@
from pydantic import BaseModel, Field, InstanceOf, field_serializer, model_validator
from typing_extensions import Self

from pymc_marketing.constants import DAYS_IN_MONTH, DAYS_IN_YEAR
from pymc_marketing.constants import DAYS_IN_MONTH, DAYS_IN_WEEK, DAYS_IN_YEAR
from pymc_marketing.deserialize import deserialize, register_deserialization
from pymc_marketing.plot import SelToString, plot_curve, plot_hdi, plot_samples
from pymc_marketing.prior import Prior, VariableFactory, create_dim_handler
Expand Down Expand Up @@ -383,9 +384,20 @@ def _get_default_start_date(self) -> datetime.datetime:
"""
pass # pragma: no cover

@abstractmethod
def _get_days_in_period(self, dates: pd.DatetimeIndex) -> pd.Index:
"""Return the relevant day within the characteristic periodicity.
Returns
-------
int or float
The relevant period within the characteristic periodicity
"""
pass

def apply(
self,
dayofyear: pt.TensorLike,
dayofperiod: pt.TensorLike,
result_callback: Callable[[pt.TensorVariable], None] | None = None,
) -> pt.TensorVariable:
"""Apply fourier seasonality to day of year.
Expand All @@ -394,8 +406,8 @@ def apply(
Parameters
----------
dayofyear : pt.TensorLike
Day of year.
dayofperiod : pt.TensorLike
Day of year or weekday
result_callback : Callable[[pt.TensorVariable], None], optional
Callback function to apply to the result, by default None
Expand Down Expand Up @@ -431,7 +443,7 @@ def callback(result):
fourier.apply(dayofyear, result_callback=callback)
"""
periods = dayofyear / self.days_in_period
periods = dayofperiod / self.days_in_period

model = pm.modelcontext(None)
model.add_coord(self.prefix, self.nodes)
Expand Down Expand Up @@ -506,15 +518,15 @@ def sample_curve(
start_date = self.get_default_start_date(start_date=start_date)
date_range = pd.date_range(
start=start_date,
periods=int(self.days_in_period) + 1,
periods=np.ceil(self.days_in_period) + 1,
freq="D",
)
coords["date"] = date_range.to_numpy()
dayofyear = date_range.dayofyear.to_numpy()
dayofperiod = self._get_days_in_period(date_range).to_numpy()

else:
coords["day"] = full_period
dayofyear = full_period
dayofperiod = full_period

for key, values in parameters[self.variable_name].coords.items():
if key in {"chain", "draw", self.prefix}:
Expand All @@ -525,7 +537,7 @@ def sample_curve(
name = f"{self.prefix}_trend"
pm.Deterministic(
name,
self.apply(dayofyear=dayofyear),
self.apply(dayofperiod=dayofperiod),
dims=tuple(coords.keys()),
)

Expand Down Expand Up @@ -777,6 +789,16 @@ def _get_default_start_date(self) -> datetime.datetime:
current_year = datetime.datetime.now().year
return datetime.datetime(year=current_year, month=1, day=1)

def _get_days_in_period(self, dates: pd.DatetimeIndex) -> pd.Index:
"""Return the dayofyear within the yearly periodicity.
Returns
-------
int or float
The relevant period within the characteristic periodicity
"""
return dates.dayofyear


class MonthlyFourier(FourierBase):
"""Monthly fourier seasonality.
Expand All @@ -799,11 +821,11 @@ class MonthlyFourier(FourierBase):
mu = np.array([0, 0, 0.5, 0])
b = 0.075
dist = Prior("Laplace", mu=mu, b=b, dims="fourier")
yearly = MonthlyFourier(n_order=2, prior=dist)
prior = yearly.sample_prior(samples=100)
curve = yearly.sample_curve(prior)
monthly = MonthlyFourier(n_order=2, prior=dist)
prior = monthly.sample_prior(samples=100)
curve = monthly.sample_curve(prior)
_, axes = yearly.plot_curve(curve)
_, axes = monthly.plot_curve(curve)
axes[0].set(title="Monthly Fourier Seasonality")
plt.show()
Expand Down Expand Up @@ -832,6 +854,83 @@ def _get_default_start_date(self) -> datetime.datetime:
now = datetime.datetime.now()
return datetime.datetime(year=now.year, month=now.month, day=1)

def _get_days_in_period(self, dates: pd.DatetimeIndex) -> pd.Index:
"""Return the dayofyear within the yearly periodicity.
Returns
-------
int or float
The relevant period within the characteristic periodicity
"""
return dates.dayofyear


class WeeklyFourier(FourierBase):
"""Weekly fourier seasonality.
.. plot::
:context: close-figs
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
from pymc_marketing.mmm import WeeklyFourier
from pymc_marketing.prior import Prior
az.style.use("arviz-white")
seed = sum(map(ord, "Weekly"))
rng = np.random.default_rng(seed)
mu = np.array([0, 0, 0.5, 0])
b = 0.075
dist = Prior("Laplace", mu=mu, b=b, dims="fourier")
weekly = WeeklyFourier(n_order=2, prior=dist)
prior = weekly.sample_prior(samples=100)
curve = weekly.sample_curve(prior)
_, axes = weekly.plot_curve(curve)
axes[0].set(title="Weekly Fourier Seasonality")
plt.show()
n_order : int
Number of fourier modes to use.
prefix : str, optional
Alternative prefix for the fourier seasonality, by default None or
"fourier"
prior : Prior | VariableFactory, optional
Prior distribution or VariableFactory for the fourier seasonality beta parameters, by
default `Prior("Laplace", mu=0, b=1)`
name : str, optional
Name of the variable that multiplies the fourier modes, by default None
variable_name : str, optional
Name of the variable that multiplies the fourier modes, by default None
"""

days_in_period: float = DAYS_IN_WEEK

def _get_default_start_date(self) -> datetime.datetime:
"""Get the default start date for weekly seasonality.
Returns the first day of the current month.
"""
now = datetime.datetime.now()
return datetime.datetime.fromisocalendar(
year=now.year, week=now.isocalendar().week, day=1
)

def _get_days_in_period(self, dates: pd.DatetimeIndex) -> pd.Index:
"""Return the weekday within the weekly periodicity.
Returns
-------
int or float
The relevant period within the characteristic periodicity
"""
return dates.weekday


def _is_yearly_fourier(data: Any) -> bool:
return data.get("class") == "YearlyFourier"
Expand All @@ -841,6 +940,10 @@ def _is_monthly_fourier(data: Any) -> bool:
return data.get("class") == "MonthlyFourier"


def _is_weekly_fourier(data: Any) -> bool:
return data.get("class") == "WeeklyFourier"


register_deserialization(
is_type=_is_yearly_fourier,
deserialize=lambda data: YearlyFourier.from_dict(data),
Expand All @@ -850,3 +953,7 @@ def _is_monthly_fourier(data: Any) -> bool:
is_type=_is_monthly_fourier,
deserialize=lambda data: MonthlyFourier.from_dict(data),
)

register_deserialization(
is_type=_is_weekly_fourier, deserialize=lambda data: WeeklyFourier.from_dict(data)
)
Loading

0 comments on commit 759bfd2

Please sign in to comment.