diff --git a/docs/source/api/index.md b/docs/source/api/index.md index ce11968eb..82541f3a2 100644 --- a/docs/source/api/index.md +++ b/docs/source/api/index.md @@ -8,7 +8,9 @@ :toctree: generated/ clv + hsgp_kwargs mmm model_config + model_builder prior ``` diff --git a/docs/source/notebooks/mmm/mmm_example.ipynb b/docs/source/notebooks/mmm/mmm_example.ipynb index c59abb16e..0674217bd 100644 --- a/docs/source/notebooks/mmm/mmm_example.ipynb +++ b/docs/source/notebooks/mmm/mmm_example.ipynb @@ -1119,7 +1119,7 @@ "source": [ "dummy_model = MMM(\n", " date_column=\"\",\n", - " channel_columns=\"\",\n", + " channel_columns=[\"\"],\n", " adstock=\"geometric\",\n", " saturation=\"logistic\",\n", " adstock_max_lag=4,\n", diff --git a/pymc_marketing/clv/models/basic.py b/pymc_marketing/clv/models/basic.py index d8dae9280..b53c224ed 100644 --- a/pymc_marketing/clv/models/basic.py +++ b/pymc_marketing/clv/models/basic.py @@ -20,6 +20,7 @@ import arviz as az import pandas as pd import pymc as pm +from pydantic import ConfigDict, InstanceOf, validate_call from pymc.backends import NDArray from pymc.backends.base import MultiTrace from pymc.model.core import Model @@ -32,11 +33,12 @@ class CLVModel(ModelBuilder): _model_type = "CLVModel" + @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) def __init__( self, data: pd.DataFrame, *, - model_config: ModelConfig | None = None, + model_config: InstanceOf[ModelConfig] | None = None, sampler_config: dict | None = None, non_distributions: list[str] | None = None, ): @@ -65,7 +67,7 @@ def _validate_cols( if data[required_col].nunique() != n: raise ValueError(f"Column {required_col} has duplicate entries") - def __repr__(self): + def __repr__(self) -> str: if not hasattr(self, "model"): return self._model_type else: diff --git a/pymc_marketing/hsgp_kwargs.py b/pymc_marketing/hsgp_kwargs.py new file mode 100644 index 000000000..4a7c4c7c6 --- /dev/null +++ b/pymc_marketing/hsgp_kwargs.py @@ -0,0 +1,82 @@ +# Copyright 2024 The PyMC Labs Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Class to store and validate keyword argument for the Hilbert Space Gaussian Process (HSGP) components.""" + +from typing import Annotated + +import pymc as pm +from pydantic import BaseModel, Field, InstanceOf + + +class HSGPKwargs(BaseModel): + """HSGP keyword arguments for the time-varying prior. + + See [1]_ and [2]_ for the theoretical background on the Hilbert Space Gaussian Process (HSGP). + See , [6]_ for a practical guide through the method using code examples. + See the :class:`~pymc.gp.HSGP` class for more information on the Hilbert Space Gaussian Process in PyMC. + We also recommend the following resources for a more practical introduction to HSGP: [3]_, [4]_, [5]_. + + References + ---------- + .. [1] Solin, A., Sarkka, S. (2019) Hilbert Space Methods for Reduced-Rank Gaussian Process Regression. + .. [2] Ruitort-Mayol, G., and Anderson, M., and Solin, A., and Vehtari, A. (2022). 
Practical Hilbert Space Approximate Bayesian Gaussian Processes for Probabilistic Programming. + .. [3] PyMC Example Gallery: `"Gaussian Processes: HSGP Reference & First Steps" `_. + .. [4] PyMC Example Gallery: `"Gaussian Processes: HSGP Advanced Usage" `_. + .. [5] PyMC Example Gallery: `"Baby Births Modelling with HSGPs" `_. + .. [6] Orduz, J. `"A Conceptual and Practical Introduction to Hilbert Space GPs Approximation Methods" `_. + + Parameters + ---------- + m : int + Number of basis functions. Default is 200. + L : float, optional + Extent of basis functions. Set this to reflect the expected range of in+out-of-sample data + (considering that time-indices are zero-centered).Default is `X_mid * 2` (identical to `c=2` in HSGP). + By default it is None. + eta_lam : float + Exponential prior for the variance. Default is 1. + ls_mu : float + Mean of the inverse gamma prior for the lengthscale. Default is 5. + ls_sigma : float + Standard deviation of the inverse gamma prior for the lengthscale. Default is 5. + cov_func : ~pymc.gp.cov.Covariance, optional + Gaussian process Covariance function. By default it is None. + """ # noqa E501 + + m: int = Field(200, description="Number of basis functions") + L: ( + Annotated[ + float, + Field( + gt=0, + description=""" + Extent of basis functions. Set this to reflect the expected range of in+out-of-sample data + (considering that time-indices are zero-centered).Default is `X_mid * 2` (identical to `c=2` in HSGP) + """, + ), + ] + | None + ) = None + eta_lam: float = Field(1, gt=0, description="Exponential prior for the variance") + ls_mu: float = Field( + 5, gt=0, description="Mean of the inverse gamma prior for the lengthscale" + ) + ls_sigma: float = Field( + 5, + gt=0, + description="Standard deviation of the inverse gamma prior for the lengthscale", + ) + cov_func: InstanceOf[pm.gp.cov.Covariance] | None = Field( + None, description="Gaussian process Covariance function" + ) diff --git a/pymc_marketing/mmm/budget_optimizer.py b/pymc_marketing/mmm/budget_optimizer.py index 8fa4bd162..3b6ba1114 100644 --- a/pymc_marketing/mmm/budget_optimizer.py +++ b/pymc_marketing/mmm/budget_optimizer.py @@ -17,6 +17,7 @@ from typing import Any import numpy as np +from pydantic import BaseModel, ConfigDict, Field from scipy.optimize import minimize from pymc_marketing.mmm.components.adstock import AdstockTransformation @@ -30,7 +31,7 @@ def __init__(self, message: str): super().__init__(message) -class BudgetOptimizer: +class BudgetOptimizer(BaseModel): """ A class for optimizing budget allocation in a marketing mix model. @@ -58,19 +59,21 @@ class BudgetOptimizer: Default is True. """ - def __init__( - self, - adstock: AdstockTransformation, - saturation: SaturationTransformation, - num_days: int, - parameters: dict[str, dict[str, dict[str, float]]], - adstock_first: bool = True, - ): - self.adstock = adstock - self.saturation = saturation - self.num_days = num_days - self.parameters = parameters - self.adstock_first = adstock_first + adstock: AdstockTransformation = Field( + ..., description="The adstock transformation class." + ) + saturation: SaturationTransformation = Field( + ..., description="The saturation transformation class." + ) + num_days: int = Field(..., gt=0, description="The number of days.") + parameters: dict[str, dict[str, dict[str, float]]] = Field( + ..., description="A dictionary of parameters for each channel." 
+ ) + adstock_first: bool = Field( + True, + description="Whether to apply adstock transformation first or saturation transformation first.", + ) + model_config = ConfigDict(arbitrary_types_allowed=True) def objective(self, budgets: list[float]) -> float: """ diff --git a/pymc_marketing/mmm/components/adstock.py b/pymc_marketing/mmm/components/adstock.py index 89a09d6f1..498327330 100644 --- a/pymc_marketing/mmm/components/adstock.py +++ b/pymc_marketing/mmm/components/adstock.py @@ -54,6 +54,7 @@ def function(self, x, alpha): import numpy as np import xarray as xr +from pydantic import Field, InstanceOf, validate_call from pymc_marketing.mmm.components.base import Transformation from pymc_marketing.mmm.transformers import ( @@ -81,13 +82,20 @@ class AdstockTransformation(Transformation): prefix: str = "adstock" lookup_name: str + @validate_call def __init__( self, - l_max: int, - normalize: bool = True, - mode: ConvMode = ConvMode.After, - priors: dict | None = None, - prefix: str | None = None, + l_max: int = Field( + ..., gt=0, description="Maximum lag for the adstock transformation." + ), + normalize: bool = Field( + True, description="Whether to normalize the adstock values." + ), + mode: ConvMode = Field(ConvMode.After, description="Convolution mode."), + priors: dict[str, str | InstanceOf[Prior]] | None = Field( + default=None, description="Priors for the parameters." + ), + prefix: str | None = Field(None, description="Prefix for the parameters."), ) -> None: self.l_max = l_max self.normalize = normalize @@ -368,16 +376,22 @@ def _get_adstock_function( if isinstance(function, AdstockTransformation): return function - if function not in ADSTOCK_TRANSFORMATIONS: + elif isinstance(function, str): + if function not in ADSTOCK_TRANSFORMATIONS: + raise ValueError( + f"Unknown adstock function: {function}. Choose from {list(ADSTOCK_TRANSFORMATIONS.keys())}" + ) + + if kwargs: + warnings.warn( + "The preferred method of initializing a lagging function is to use the class directly.", + DeprecationWarning, + stacklevel=1, + ) + + return ADSTOCK_TRANSFORMATIONS[function](**kwargs) + + else: raise ValueError( f"Unknown adstock function: {function}. Choose from {list(ADSTOCK_TRANSFORMATIONS.keys())}" ) - - if kwargs: - warnings.warn( - "The preferred method of initializing a lagging function is to use the class directly.", - DeprecationWarning, - stacklevel=1, - ) - - return ADSTOCK_TRANSFORMATIONS[function](**kwargs) diff --git a/pymc_marketing/mmm/components/saturation.py b/pymc_marketing/mmm/components/saturation.py index 225912e3f..d93f8ca60 100644 --- a/pymc_marketing/mmm/components/saturation.py +++ b/pymc_marketing/mmm/components/saturation.py @@ -71,6 +71,7 @@ def function(self, x, b): import numpy as np import xarray as xr +from pydantic import Field, InstanceOf, validate_call from pymc_marketing.mmm.components.base import Transformation from pymc_marketing.mmm.transformers import ( @@ -130,10 +131,13 @@ class InfiniteReturns(SaturationTransformation): prefix: str = "saturation" + @validate_call def sample_curve( self, - parameters: xr.Dataset, - max_value: float = 1.0, + parameters: InstanceOf[xr.Dataset] = Field( + ..., description="Parameters of the saturation transformation." + ), + max_value: float = Field(1.0, gt=0, description="Maximum range value."), ) -> xr.DataArray: """Sample the curve of the saturation transformation given parameters. 
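With `@validate_call` and pydantic `Field` metadata on the component constructors above, bad arguments now fail fast with a `ValidationError` instead of surfacing later during model building; likewise `BudgetOptimizer`, now a pydantic `BaseModel`, must be constructed with keyword arguments. A minimal sketch of the new behaviour, mirroring the test added in `tests/mmm/components/test_adstock.py` further below:

```python
# Sketch only: invalid constructor arguments are rejected by pydantic
# before any model code runs (DelayedAdstock mirrors the new adstock test).
from pydantic import ValidationError

from pymc_marketing.mmm.components.adstock import DelayedAdstock

try:
    DelayedAdstock(l_max=-1)  # l_max is declared with gt=0
except ValidationError as err:
    print(err)  # reports "Input should be greater than 0"
```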
diff --git a/pymc_marketing/mmm/delayed_saturated_mmm.py b/pymc_marketing/mmm/delayed_saturated_mmm.py index 1f04d217c..7e894ac24 100644 --- a/pymc_marketing/mmm/delayed_saturated_mmm.py +++ b/pymc_marketing/mmm/delayed_saturated_mmm.py @@ -16,7 +16,7 @@ import json import warnings from pathlib import Path -from typing import Any +from typing import Annotated, Any import arviz as az import matplotlib.pyplot as plt @@ -26,8 +26,10 @@ import pymc as pm import pytensor.tensor as pt import seaborn as sns +from pydantic import Field, InstanceOf, validate_call from xarray import DataArray, Dataset +from pymc_marketing.hsgp_kwargs import HSGPKwargs from pymc_marketing.mmm.base import BaseValidateMMM from pymc_marketing.mmm.budget_optimizer import BudgetOptimizer from pymc_marketing.mmm.components.adstock import ( @@ -69,33 +71,65 @@ class BaseMMM(BaseValidateMMM): _model_type: str = "BaseValidateMMM" version: str = "0.0.3" + @validate_call def __init__( self, - date_column: str, - channel_columns: list[str], - adstock_max_lag: int, - adstock: str | AdstockTransformation, - saturation: str | SaturationTransformation, - time_varying_intercept: bool = False, - time_varying_media: bool = False, - model_config: dict | None = None, - sampler_config: dict | None = None, - validate_data: bool = True, - control_columns: list[str] | None = None, - yearly_seasonality: int | None = None, - adstock_first: bool = True, + date_column: str = Field(..., description="Column name of the date variable."), + channel_columns: list[str] = Field( + min_length=1, description="Column names of the media channel variables." + ), + adstock_max_lag: int = Field( + ..., + gt=0, + description="Number of lags to consider in the adstock transformation.", + ), + adstock: str | InstanceOf[AdstockTransformation] = Field( + ..., description="Type of adstock transformation to apply." + ), + saturation: str | InstanceOf[SaturationTransformation] = Field( + ..., description="Type of saturation transformation to apply." + ), + time_varying_intercept: bool = Field( + False, description="Whether to consider time-varying intercept." + ), + time_varying_media: bool = Field( + False, description="Whether to consider time-varying media contributions." + ), + model_config: dict | None = Field(None, description="Model configuration."), + sampler_config: dict | None = Field(None, description="Sampler configuration."), + validate_data: bool = Field( + True, description="Whether to validate the data before fitting to model" + ), + control_columns: Annotated[ + list[str], + Field( + min_length=1, + description="Column names of control variables to be added as additional regressors", + ), + ] + | None = None, + yearly_seasonality: Annotated[ + int, + Field( + gt=0, description="Number of Fourier modes to model yearly seasonality." + ), + ] + | None = None, + adstock_first: bool = Field( + True, description="Whether to apply adstock first." + ), **kwargs, ) -> None: """Constructor method. - Parameters - ---------- + Parameter + --------- date_column : str Column name of the date variable. channel_columns : List[str] Column names of the media channel variables. adstock_max_lag : int, optional - Number of lags to consider in the adstock transformation, by default 4 + Number of lags to consider in the adstock transformation. adstock : str | AdstockTransformation Type of adstock transformation to apply. saturation : str | SaturationTransformation @@ -108,12 +142,12 @@ def __init__( Whether to consider time-varying media contributions, by default False. 
The `time-varying-media` creates a time media variable centered around 1, this variable acts as a global multiplier (scaling factor) for all channels, - meaning all media channels share the same latent fluctiation. + meaning all media channels share the same latent fluctuation. model_config : Dictionary, optional - dictionary of parameters that initialise model configuration. + Dictionary of parameters that initialise model configuration. Class-default defined by the user default_model_config method. sampler_config : Dictionary, optional - dictionary of parameters that initialise sampler configuration. + Dictionary of parameters that initialise sampler configuration. Class-default defined by the user default_sampler_config method. validate_data : bool, optional Whether to validate the data before fitting to model, by default True. @@ -121,6 +155,8 @@ def __init__( Column names of control variables to be added as additional regressors, by default None yearly_seasonality : Optional[int], optional Number of Fourier modes to model yearly seasonality, by default None. + adstock_first : bool, optional + Whether to apply adstock first, by default True. """ self.control_columns = control_columns self.adstock_max_lag = adstock_max_lag @@ -136,7 +172,7 @@ def __init__( model_config = model_config or {} model_config = parse_model_config( model_config, # type: ignore - non_distributions=["intercept_tvp_config", "media_tvp_config"], + hsgp_kwargs_fields=["intercept_tvp_config", "media_tvp_config"], ) if model_config is not None: @@ -157,7 +193,7 @@ def __init__( n_order=self.yearly_seasonality, prefix="fourier_mode", prior=self.model_config["gamma_fourier"], - name="gamma_fourier", + variable_name="gamma_fourier", ) @property @@ -377,7 +413,7 @@ def build_model( time_index=time_index, time_index_mid=self._time_index_mid, time_resolution=self._time_resolution, - model_config=self.model_config, + hsgp_kwargs=self.model_config["intercept_tvp_config"], ) intercept = pm.Deterministic( name="intercept", @@ -402,7 +438,7 @@ def build_model( time_index=time_index, time_index_mid=self._time_index_mid, time_resolution=self._time_resolution, - model_config=self.model_config, + hsgp_kwargs=self.model_config["media_tvp_config"], ) channel_contributions = pm.Deterministic( name="channel_contributions", @@ -493,23 +529,23 @@ def default_model_config(self) -> dict: } if self.time_varying_intercept: - base_config["intercept_tvp_config"] = { # type: ignore - "m": 200, - "L": None, - "eta_lam": 1, - "ls_mu": None, - "ls_sigma": 10, - "cov_func": None, - } + base_config["intercept_tvp_config"] = HSGPKwargs( + m=200, + L=None, + eta_lam=1, + ls_mu=5, + ls_sigma=10, + cov_func=None, + ) if self.time_varying_media: - base_config["media_tvp_config"] = { # type: ignore - "m": 200, - "L": None, - "eta_lam": 1, - "ls_mu": None, - "ls_sigma": 10, - "cov_func": None, - } + base_config["media_tvp_config"] = HSGPKwargs( + m=200, + L=None, + eta_lam=1, + ls_mu=5, + ls_sigma=10, + cov_func=None, + ) for media_transform in [self.adstock, self.saturation]: for dist in media_transform.function_priors.values(): @@ -883,8 +919,8 @@ class MMM( .. [2] Orduz, J. `"Media Effect Estimation with PyMC: Adstock, Saturation & Diminishing Returns" `_. 
""" # noqa: E501 - _model_type = "MMM" - version = "0.0.1" + _model_type: str = "MMM" + version: str = "0.0.1" def channel_contributions_forward_pass( self, channel_data: npt.NDArray[np.float64] @@ -2146,10 +2182,11 @@ def plot_allocated_contribution_by_channel( class DelayedSaturatedMMM(MMM): - _model_type = "MMM" - _model_name = "DelayedSaturatedMMM" - version = "0.0.3" + _model_type: str = "MMM" + _model_name: str = "DelayedSaturatedMMM" + version: str = "0.0.3" + @validate_call def __init__( self, date_column: str, diff --git a/pymc_marketing/mmm/fourier.py b/pymc_marketing/mmm/fourier.py index 1b60e2188..c59ae7db6 100644 --- a/pymc_marketing/mmm/fourier.py +++ b/pymc_marketing/mmm/fourier.py @@ -215,13 +215,11 @@ import pymc as pm import pytensor.tensor as pt import xarray as xr +from pydantic import BaseModel, Field, InstanceOf, field_serializer, model_validator +from typing_extensions import Self from pymc_marketing.constants import DAYS_IN_MONTH, DAYS_IN_YEAR -from pymc_marketing.mmm.plot import ( - plot_curve, - plot_hdi, - plot_samples, -) +from pymc_marketing.mmm.plot import plot_curve, plot_hdi, plot_samples from pymc_marketing.prior import Prior, create_dim_handler X_NAME: str = "day" @@ -261,63 +259,56 @@ def generate_fourier_modes( ) -class FourierBase: +class FourierBase(BaseModel): """Base class for Fourier seasonality transformations. Parameters ---------- n_order : int Number of fourier modes to use. + days_in_period : float + Number of days in a period. prefix : str, optional Alternative prefix for the fourier seasonality, by default None or "fourier" prior : Prior, optional Prior distribution for the fourier seasonality beta parameters, by - default None - name : str, optional - Name of the variable that multiplies the fourier modes, by default None - - Attributes - ---------- - days_in_period : float - Number of days in a period. - prefix : str - Name of model coordinates - default_prior : Prior - Default prior distribution for the fourier seasonality - beta parameters. + default `Prior("Laplace", mu=0, b=1)` + variable_name : str, optional + Name of the variable that multiplies the fourier modes. By default None, + in which case it is set to the `{prefix}_beta`. """ - days_in_period: float - prefix: str = "fourier" - - default_prior = Prior("Laplace", mu=0, b=1) + n_order: int = Field(..., gt=0) + days_in_period: float = Field(..., gt=0) + prefix: str = Field("fourier") + prior: InstanceOf[Prior] = Field(Prior("Laplace", mu=0, b=1)) + variable_name: str | None = Field(None) - def __init__( - self, - n_order: int, - prefix: str | None = None, - prior: Prior | None = None, - name: str | None = None, - ) -> None: - if not isinstance(n_order, int) or n_order < 1: - raise ValueError(f"n_order must be a positive integer. 
Not {n_order}") - - self.n_order = n_order - self.prefix = prefix or self.prefix - self.prior = prior or self.default_prior - self.variable_name = name or f"{self.prefix}_beta" - - if self.variable_name == self.prefix: - raise ValueError("Variable name cannot be the same as the prefix") + def model_post_init(self, __context: Any) -> None: + if self.variable_name is None: + self.variable_name = f"{self.prefix}_beta" if not self.prior.dims: self.prior = self.prior.deepcopy() self.prior.dims = self.prefix + @model_validator(mode="after") + def _check_variable_name(self) -> Self: + if self.variable_name == self.prefix: + raise ValueError("Variable name cannot be the same as the prefix") + return self + + @model_validator(mode="after") + def _check_prior_has_right_dimensions(self) -> Self: if self.prefix not in self.prior.dims: raise ValueError(f"Prior distribution must have dimension {self.prefix}") + return self + + @field_serializer("prior", when_used="json") + def serialize_prior(prior: Prior) -> dict[str, Any]: + return prior.to_json() @property def nodes(self) -> list[str]: @@ -597,8 +588,6 @@ class YearlyFourier(FourierBase): axes[0].set(title="Yearly Fourier Seasonality") plt.show() - Parameters - ---------- n_order : int Number of fourier modes to use. prefix : str, optional @@ -606,21 +595,15 @@ class YearlyFourier(FourierBase): "fourier" prior : Prior, optional Prior distribution for the fourier seasonality beta parameters, by - default None - - Attributes - ---------- - days_in_period : float - Number of days in a period. - prefix : str - Name of model coordinates - default_prior : Prior - Default prior distribution for the fourier seasonality - beta parameters. + default `Prior("Laplace", mu=0, b=1)` + name : str, optional + Name of the variable that multiplies the fourier modes, by default None + variable_name : str, optional + Name of the variable that multiplies the fourier modes, by default None """ - days_in_period = DAYS_IN_YEAR + days_in_period: float = DAYS_IN_YEAR class MonthlyFourier(FourierBase): @@ -652,8 +635,6 @@ class MonthlyFourier(FourierBase): axes[0].set(title="Monthly Fourier Seasonality") plt.show() - Parameters - ---------- n_order : int Number of fourier modes to use. prefix : str, optional @@ -661,18 +642,12 @@ class MonthlyFourier(FourierBase): "fourier" prior : Prior, optional Prior distribution for the fourier seasonality beta parameters, by - default None - - Attributes - ---------- - days_in_period : float - Number of days in a period. - prefix : str - Name of model coordinates - default_prior : Prior - Default prior distribution for the fourier seasonality - beta parameters. 
+ default `Prior("Laplace", mu=0, b=1)` + name : str, optional + Name of the variable that multiplies the fourier modes, by default None + variable_name : str, optional + Name of the variable that multiplies the fourier modes, by default None """ - days_in_period = DAYS_IN_MONTH + days_in_period: float = DAYS_IN_MONTH diff --git a/pymc_marketing/mmm/tvp.py b/pymc_marketing/mmm/tvp.py index 9b8c84237..0d4537622 100644 --- a/pymc_marketing/mmm/tvp.py +++ b/pymc_marketing/mmm/tvp.py @@ -93,6 +93,7 @@ from pymc.distributions.shape_utils import Dims from pymc_marketing.constants import DAYS_IN_YEAR +from pymc_marketing.hsgp_kwargs import HSGPKwargs def time_varying_prior( @@ -100,12 +101,7 @@ def time_varying_prior( X: pt.sharedvar.TensorSharedVariable, dims: Dims, X_mid: int | float | None = None, - m: int = 200, - L: int | float | None = None, - eta_lam: float = 1, - ls_mu: float = 5, - ls_sigma: float = 5, - cov_func: pm.gp.cov.Covariance | None = None, + hsgp_kwargs: HSGPKwargs | None = None, ) -> pt.TensorVariable: """Time varying prior, based on the Hilbert Space Gaussian Process (HSGP). @@ -124,20 +120,9 @@ def time_varying_prior( the time dimension, and the second may be any other dimension, across which independent time varying priors for each coordinate are desired (e.g. channels). - m : int - Number of basis functions. - L : int - Extent of basis functions. Set this to reflect the expected range of - in+out-of-sample data (considering that time-indices are zero-centered). - Default is `X_mid * 2` (identical to `c=2` in HSGP). - eta_lam : float - Exponential prior for the variance. - ls_mu : float - Mean of the inverse gamma prior for the lengthscale. - ls_sigma : float - Standard deviation of the inverse gamma prior for the lengthscale. - cov_func : pm.gp.cov.Covariance - Covariance function. + hsgp_kwargs : HSGPKwargs + Keyword arguments for the Hilbert Space Gaussian Process. By default it is None, + in which case the default parameters are used. See `HSGPKwargs` for more information. Returns ------- @@ -153,24 +138,29 @@ def time_varying_prior( Regression. """ + if hsgp_kwargs is None: + hsgp_kwargs = HSGPKwargs() + if X_mid is None: X_mid = float(X.mean().eval()) - if L is None: - L = X_mid * 2 + if hsgp_kwargs.L is None: + hsgp_kwargs.L = X_mid * 2 model = pm.modelcontext(None) - if cov_func is None: - eta = pm.Exponential(f"{name}_eta", lam=eta_lam) - ls = pm.InverseGamma(f"{name}_ls", mu=ls_mu, sigma=ls_sigma) + if hsgp_kwargs.cov_func is None: + eta = pm.Exponential(f"{name}_eta", lam=hsgp_kwargs.eta_lam) + ls = pm.InverseGamma( + f"{name}_ls", mu=hsgp_kwargs.ls_mu, sigma=hsgp_kwargs.ls_sigma + ) cov_func = eta**2 * pm.gp.cov.Matern52(1, ls=ls) - model.add_coord("m", np.arange(m)) # type: ignore + model.add_coord("m", np.arange(hsgp_kwargs.m)) # type: ignore hsgp_dims: str | tuple[str, str] = "m" if isinstance(dims, tuple): hsgp_dims = (dims[1], "m") - gp = pm.gp.HSGP(m=[m], L=[L], cov_func=cov_func) + gp = pm.gp.HSGP(m=[hsgp_kwargs.m], L=[hsgp_kwargs.L], cov_func=cov_func) phi, sqrt_psd = gp.prior_linearized(Xs=X[:, None] - X_mid) hsgp_coefs = pm.Normal(f"{name}_hsgp_coefs", dims=hsgp_dims) f = phi @ (hsgp_coefs * sqrt_psd).T @@ -185,7 +175,7 @@ def create_time_varying_gp_multiplier( time_index: pt.sharedvar.TensorSharedVariable, time_index_mid: int, time_resolution: int, - model_config: dict, + hsgp_kwargs: HSGPKwargs, ) -> pt.TensorVariable: """Create a time-varying Gaussian Process multiplier. 
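The hunk above collapses the separate HSGP keyword arguments of `time_varying_prior` into a single `hsgp_kwargs` container. A minimal sketch of the new call pattern, following the updated `tests/mmm/test_tvp.py` (the variable name `trend` and the toy coordinates are illustrative):

```python
# Sketch of calling time_varying_prior with the new HSGPKwargs container.
import numpy as np
import pymc as pm

from pymc_marketing.hsgp_kwargs import HSGPKwargs
from pymc_marketing.mmm.tvp import time_varying_prior

coords = {"date": np.arange(5)}
with pm.Model(coords=coords):
    X = pm.Data("X", np.array([0, 1, 2, 3, 4]), dims="date")
    # Instead of m=3, L=10, ... pass one validated object:
    hsgp_kwargs = HSGPKwargs(m=3, L=10)
    trend = time_varying_prior(
        name="trend", X=X, X_mid=2, dims="date", hsgp_kwargs=hsgp_kwargs
    )
```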
@@ -203,30 +193,26 @@ def create_time_varying_gp_multiplier( Midpoint of the time points. time_resolution : int Resolution of time points. - model_config : dict - Configuration dictionary for the model. + hsgp_kwargsg : HSGPKwargs + Keyword arguments for the Hilbert Space Gaussian Process (HSGP) component. Returns ------- pt.TensorVariable Time-varying Gaussian Process multiplier for a given variable. """ + if hsgp_kwargs.L is None: + hsgp_kwargs.L = time_index_mid + DAYS_IN_YEAR / time_resolution + if hsgp_kwargs.ls_mu is None: + hsgp_kwargs.ls_mu = DAYS_IN_YEAR / time_resolution * 2 - tvp_config = model_config[f"{name}_tvp_config"] - - if tvp_config["L"] is None: - tvp_config["L"] = time_index_mid + DAYS_IN_YEAR / time_resolution - if tvp_config["ls_mu"] is None: - tvp_config["ls_mu"] = DAYS_IN_YEAR / time_resolution * 2 - - multiplier = time_varying_prior( + return time_varying_prior( name=f"{name}_temporal_latent_multiplier", X=time_index, X_mid=time_index_mid, dims=dims, - **tvp_config, + hsgp_kwargs=hsgp_kwargs, ) - return multiplier def infer_time_index( diff --git a/pymc_marketing/model_builder.py b/pymc_marketing/model_builder.py index a8702a564..021285491 100644 --- a/pymc_marketing/model_builder.py +++ b/pymc_marketing/model_builder.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +"""Base class responsible of the high level API for model building, fitting saving and loading.""" import hashlib import json @@ -27,6 +27,7 @@ import xarray as xr from pymc.util import RandomState +from pymc_marketing.hsgp_kwargs import HSGPKwargs from pymc_marketing.prior import Prior # If scikit-learn is available, use its data validator @@ -64,8 +65,6 @@ def __init__( Parameters ---------- - data : Dictionary, optional - It is the data we need to train the model on. model_config : Dictionary, optional dictionary of parameters that initialise model configuration. Class-default defined by the user default_model_config method. @@ -262,7 +261,9 @@ def build_model( None """ - def set_idata_attrs(self, idata=None): + def set_idata_attrs( + self, idata: az.InferenceData | None = None + ) -> az.InferenceData: """ Set attributes on an InferenceData object. @@ -294,7 +295,8 @@ def set_idata_attrs(self, idata=None): def default(x): if isinstance(x, Prior): return x.to_json() - + elif isinstance(x, HSGPKwargs): + return x.model_dump(mode="json") return x.__dict__ idata.attrs["id"] = self.id diff --git a/pymc_marketing/model_config.py b/pymc_marketing/model_config.py index 129abeb8e..d565c1913 100644 --- a/pymc_marketing/model_config.py +++ b/pymc_marketing/model_config.py @@ -16,6 +16,7 @@ import warnings from typing import Any +from pymc_marketing.hsgp_kwargs import HSGPKwargs from pymc_marketing.prior import Prior @@ -23,11 +24,13 @@ class ModelConfigError(Exception): """Exception raised for errors in model configuration.""" -ModelConfig = dict[str, Prior | Any] +ModelConfig = dict[str, HSGPKwargs | Prior | Any] def parse_model_config( - model_config: ModelConfig, non_distributions: list[str] | None = None + model_config: ModelConfig, + hsgp_kwargs_fields: list[str] | None = None, + non_distributions: list[str] | None = None, ) -> ModelConfig: """Parse the model config dictionary. @@ -35,6 +38,8 @@ def parse_model_config( ---------- model_config : dict The model configuration dictionary. 
+ hsgp_kwargs_fields : list[str], optional + A list of keys to parse as HSGP kwargs. non_distributions : list[str], optional A list of keys to ignore when parsing the model configuration dictionary due to them not being distributions. @@ -50,6 +55,7 @@ def parse_model_config( .. code-block:: python + from pymc_marketing.hsgp_kwargs import HSGPKwargs from pymc_marketing.model_config import parse_model_config from pymc_marketing.prior import Prior @@ -62,18 +68,28 @@ def parse_model_config( }, }, "beta": Prior("HalfNormal"), - "tvp_intercept": { + "intercept_tvp_config": { + "m": 200, + "L": 119.17, + "eta_lam": 1.0, + "ls_mu": 5.0, + "ls_sigma": 10.0, + "cov_func": None, + }, + "other_intercept": { "key": "Some other non-distribution configuration", }, } parsed_model_config = parse_model_config( model_config, - non_distributions=["tvp_intercept"], + hsgp_kwargs_fields=["intercept_tvp_config"], + non_distributions=["other_intercept"], ) # {'alpha': Prior("Normal", mu=0, sigma=1), # 'beta': Prior("HalfNormal"), - # 'tvp_intercept': {'key': 'Some other non-distribution configuration'}} + # 'intercept_tvp_config': HSGPKwargs(m=200, L=119.17, eta_lam=1.0, ls_mu=5.0, ls_sigma=10.0, cov_func=None), + # 'other_intercept': {'key': 'Some other non-distribution configuration'}} Parsing with an error: @@ -97,11 +113,12 @@ def parse_model_config( """ non_distributions = non_distributions or [] + hsgp_kwargs_fields = hsgp_kwargs_fields or [] parse_errors = [] def handle_prior_config(name, prior_config): - if name in non_distributions: + if name in non_distributions or name in hsgp_kwargs_fields: return prior_config if isinstance(prior_config, Prior): @@ -120,10 +137,27 @@ def handle_prior_config(name, prior_config): return dist - result = { + def handle_hggp_kwargs(name, config): + if name not in hsgp_kwargs_fields: + return config + + if isinstance(config, HSGPKwargs): + return config + + try: + hsgp_kwargs = HSGPKwargs.model_validate(config) + return hsgp_kwargs + except Exception as e: + parse_errors.append(f"Parameter {name}: {e}") + + # Parse the model configuration to extrat the `Prior` objects. + result: ModelConfig = { name: handle_prior_config(name, prior_config) for name, prior_config in model_config.items() } + # Parse the model configuration to extract the `HSGPKwargs` objects. 
+ result = {name: handle_hggp_kwargs(name, config) for name, config in result.items()} + if parse_errors: combined_errors = ", ".join(parse_errors) msg = ( diff --git a/pymc_marketing/prior.py b/pymc_marketing/prior.py index 2c301a237..18cd419f1 100644 --- a/pymc_marketing/prior.py +++ b/pymc_marketing/prior.py @@ -91,6 +91,7 @@ import pymc as pm import pytensor.tensor as pt import xarray as xr +from pydantic import validate_call from pymc.distributions.shape_utils import Dims @@ -254,6 +255,7 @@ class Prior: pymc_distribution: type[pm.Distribution] pytensor_transform: Callable[[pt.TensorLike], pt.TensorLike] | None + @validate_call def __init__( self, distribution: str, @@ -281,9 +283,6 @@ def distribution(self, distribution: str) -> None: if hasattr(self, "_distribution"): raise AttributeError("Can't change the distribution") - if not isinstance(distribution, str): - raise ValueError("Distribution must be a string") - self._distribution = distribution self.pymc_distribution = _get_pymc_distribution(distribution) @@ -294,9 +293,6 @@ def transform(self) -> str | None: @transform.setter def transform(self, transform: str | None) -> None: - if not isinstance(transform, str) and transform is not None: - raise ValueError("Transform must be a string or None") - self._transform = transform self.pytensor_transform = not transform or _get_transform(transform) # type: ignore diff --git a/pyproject.toml b/pyproject.toml index 0c899c71c..1969af145 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ dependencies = [ "matplotlib>=3.5.1", "numpy>=1.17", "pandas", + "pydantic>=2.1.0", # NOTE: Used as minimum pymc version with ci.yml `OLDEST_PYMC_VERSION` "pymc>=5.13.0,<5.16.0", "scikit-learn>=1.1.1", @@ -93,6 +94,7 @@ repository = "https://github.com/pymc-labs/pymc-marketing" [tool.ruff.lint] select = ["B", "E", "F", "I", "RUF", "S", "UP", "W"] ignore = [ + "B008", # Do not perform calls in argument defaults (this is ok with Field from pydantic) "B904", # raise-without-from-inside-except "RUF001", # String contains ambiguous character (such as Greek letters) "RUF002", # Docstring contains ambiguous character (such as Greek letters) diff --git a/tests/mmm/components/test_adstock.py b/tests/mmm/components/test_adstock.py index 492c1c392..096c40f40 100644 --- a/tests/mmm/components/test_adstock.py +++ b/tests/mmm/components/test_adstock.py @@ -16,6 +16,7 @@ import pytensor.tensor as pt import pytest import xarray as xr +from pydantic import ValidationError from pymc_marketing.mmm.components.adstock import ( AdstockTransformation, @@ -94,6 +95,14 @@ def test_get_adstock_function(name, adstock_cls, kwargs): assert isinstance(adstock, adstock_cls) +def test_adstock_no_negative_lmax(): + with pytest.raises( + ValidationError, + match="1 validation error for __init__\\nl_max\\n Input should be greater than 0", + ): + DelayedAdstock(l_max=-1) + + @pytest.mark.parametrize( "adstock", adstocks(), @@ -112,6 +121,11 @@ def test_get_adstock_function_unknown(): _get_adstock_function(function="Unknown") +def test_get_adstock_function_unknown_wrong_type(): + with pytest.raises(ValueError, match="Unknown adstock function: 1."): + _get_adstock_function(function=1) + + @pytest.mark.parametrize( "adstock", adstocks(), diff --git a/tests/mmm/components/test_saturation.py b/tests/mmm/components/test_saturation.py index c4bb3d275..fc78b3628 100644 --- a/tests/mmm/components/test_saturation.py +++ b/tests/mmm/components/test_saturation.py @@ -18,6 +18,7 @@ import pytensor.tensor as pt import pytest import xarray 
as xr +from pydantic import ValidationError from pymc_marketing.mmm.components.saturation import ( HillSaturation, @@ -209,3 +210,20 @@ def test_sample_curve_with_additional_dims( assert curve.coords["channel"].to_numpy().tolist() == ["C1", "C2", "C3"] assert "random_dim" not in curve.coords + + +@pytest.mark.parametrize( + argnames="max_value", argvalues=[0, -1], ids=["zero", "negative"] +) +def test_sample_curve_with_bad_max_value(max_value) -> None: + dummy_distribution = Prior("HalfNormal", dims="channel") + priors = { + "alpha": dummy_distribution, + "lam": dummy_distribution, + } + saturation = MichaelisMentenSaturation(priors=priors) + + with pytest.raises(ValidationError): + saturation.sample_curve( + parameters=mock_menten_parameters_with_additional_dim, max_value=max_value + ) diff --git a/tests/mmm/test_budget_optimizer.py b/tests/mmm/test_budget_optimizer.py index 97ff6e701..244c1b639 100644 --- a/tests/mmm/test_budget_optimizer.py +++ b/tests/mmm/test_budget_optimizer.py @@ -78,7 +78,13 @@ def test_allocate_budget( saturation = _get_saturation_function(function="logistic") # Create BudgetOptimizer Instance - optimizer = BudgetOptimizer(adstock, saturation, 30, parameters, adstock_first=True) + optimizer = BudgetOptimizer( + adstock=adstock, + saturation=saturation, + num_days=30, + parameters=parameters, + adstock_first=True, + ) # Allocate Budget optimal_budgets, total_response = optimizer.allocate_budget( @@ -118,7 +124,13 @@ def test_allocate_budget_zero_total( ): adstock = _get_adstock_function(function="geometric", l_max=4) saturation = _get_saturation_function(function="logistic") - optimizer = BudgetOptimizer(adstock, saturation, 30, parameters, adstock_first=True) + optimizer = BudgetOptimizer( + adstock=adstock, + saturation=saturation, + num_days=30, + parameters=parameters, + adstock_first=True, + ) optimal_budgets, total_response = optimizer.allocate_budget( total_budget, budget_bounds ) @@ -147,7 +159,13 @@ def test_allocate_budget_custom_minimize_args(minimize_mock) -> None: adstock = _get_adstock_function(function="geometric", l_max=4) saturation = _get_saturation_function(function="logistic") - optimizer = BudgetOptimizer(adstock, saturation, 30, parameters, adstock_first=True) + optimizer = optimizer = BudgetOptimizer( + adstock=adstock, + saturation=saturation, + num_days=30, + parameters=parameters, + adstock_first=True, + ) optimizer.allocate_budget( total_budget, budget_bounds, minimize_kwargs=minimize_kwargs ) @@ -196,7 +214,13 @@ def test_allocate_budget_infeasible_constraints( ): adstock = _get_adstock_function(function="geometric", l_max=4) saturation = _get_saturation_function(function="logistic") - optimizer = BudgetOptimizer(adstock, saturation, 30, parameters, adstock_first=True) + optimizer = optimizer = BudgetOptimizer( + adstock=adstock, + saturation=saturation, + num_days=30, + parameters=parameters, + adstock_first=True, + ) with pytest.raises(MinimizeException, match="Optimization failed"): optimizer.allocate_budget(total_budget, budget_bounds, custom_constraints) diff --git a/tests/mmm/test_delayed_saturated_mmm.py b/tests/mmm/test_delayed_saturated_mmm.py index a6b7cc0c6..63e4ed4a2 100644 --- a/tests/mmm/test_delayed_saturated_mmm.py +++ b/tests/mmm/test_delayed_saturated_mmm.py @@ -1046,7 +1046,6 @@ def test_save_load_with_tvp( file = "tmp-model" mmm.save(file) loaded_mmm = MMM.load(file) - assert mmm.time_varying_intercept == loaded_mmm.time_varying_intercept assert mmm.time_varying_intercept == time_varying_intercept assert 
mmm.time_varying_media == loaded_mmm.time_varying_media diff --git a/tests/mmm/test_fourier.py b/tests/mmm/test_fourier.py index 3da8bdea7..86c90ea98 100644 --- a/tests/mmm/test_fourier.py +++ b/tests/mmm/test_fourier.py @@ -146,9 +146,25 @@ def test_plot_curve() -> None: assert axes.shape == (2, 2) -@pytest.mark.parametrize("n_order", [0, -1, -100, 2.5]) -def test_bad_order(n_order) -> None: - with pytest.raises(ValueError, match="n_order must be a positive integer"): +@pytest.mark.parametrize("n_order", [0, -1, -100]) +def test_bad_negative_order(n_order) -> None: + with pytest.raises( + ValueError, + match="1 validation error for YearlyFourier\\nn_order\\n Input should be greater than 0", + ): + YearlyFourier(n_order=n_order) + + +@pytest.mark.parametrize( + argnames="n_order", + argvalues=[2.5, 100.001, "m", None], + ids=["neg_float", "neg_float_2", "str", "None"], +) +def test_bad_non_integer_order(n_order) -> None: + with pytest.raises( + ValueError, + match="1 validation error for YearlyFourier\nn_order\n Input should be a valid integer", + ): YearlyFourier(n_order=n_order) @@ -240,15 +256,19 @@ def result_callback(x): assert model["components"].eval().shape == (365, n_order * 2) -def test_error_with_prefix_and_name() -> None: +def test_error_with_prefix_and_variable_name() -> None: name = "variable_name" with pytest.raises(ValueError, match="Variable name cannot"): - YearlyFourier(n_order=2, name=name, prefix=name) + YearlyFourier(n_order=2, prefix=name, variable_name=name) def test_change_name() -> None: variable_name = "variable_name" - fourier = YearlyFourier(n_order=2, name=variable_name) - assert fourier.variable_name == variable_name + fourier = YearlyFourier(n_order=2, variable_name=variable_name) prior = fourier.sample_prior(samples=10) assert variable_name in prior + + +def test_serialization_to_json() -> None: + fourier = YearlyFourier(n_order=2) + fourier.model_dump_json() diff --git a/tests/mmm/test_tvp.py b/tests/mmm/test_tvp.py index 534f2c559..10dce78ae 100644 --- a/tests/mmm/test_tvp.py +++ b/tests/mmm/test_tvp.py @@ -17,6 +17,7 @@ import pytensor.tensor as pt import pytest +from pymc_marketing.hsgp_kwargs import HSGPKwargs from pymc_marketing.mmm.tvp import ( create_time_varying_gp_multiplier, infer_time_index, @@ -33,22 +34,25 @@ def coords(): @pytest.fixture -def model_config(): +def model_config() -> dict[str, HSGPKwargs]: return { - "intercept_tvp_config": { - "m": 200, - "eta_lam": 1, - "ls_mu": None, - "ls_sigma": 5, - "L": None, - }, + "intercept_tvp_config": HSGPKwargs( + m=200, + L=None, + eta_lam=1, + ls_mu=5, + ls_sigma=5, + ) } def test_time_varying_prior(coords): with pm.Model(coords=coords) as model: X = pm.Data("X", np.array([0, 1, 2, 3, 4]), dims="date") - prior = time_varying_prior(name="test", X=X, X_mid=2, dims="date", m=3, L=10) + hsgp_kwargs = HSGPKwargs(m=3, L=10, eta_lam=1, ls_sigma=5) + prior = time_varying_prior( + name="test", X=X, X_mid=2, dims="date", hsgp_kwargs=hsgp_kwargs + ) # Assert output verification assert isinstance(prior, pt.TensorVariable) @@ -89,8 +93,9 @@ def test_calling_without_default_args(coords): def test_multidimensional(coords): with pm.Model(coords=coords) as model: X = pm.Data("X", np.array([0, 1, 2, 3, 4]), dims="date") + hsgp_kwargs = HSGPKwargs(m=7) prior = time_varying_prior( - name="test", X=X, X_mid=2, dims=("date", "channel"), m=7 + name="test", X=X, X_mid=2, dims=("date", "channel"), hsgp_kwargs=hsgp_kwargs ) # Assert internal parameters are created correctly @@ -107,7 +112,10 @@ def 
test_multidimensional(coords): def test_calling_without_model(): with pytest.raises(TypeError, match="No model on context stack."): X = pm.Data("X", np.array([0, 1, 2, 3, 4]), dims="date") - time_varying_prior(name="test", X=X, X_mid=2, dims="date", m=5, L=10) + hsgp_kwargs = HSGPKwargs(m=5, L=10) + time_varying_prior( + name="test", X=X, X_mid=2, dims="date", hsgp_kwargs=hsgp_kwargs + ) def test_create_time_varying_intercept(coords, model_config): @@ -121,7 +129,7 @@ def test_create_time_varying_intercept(coords, model_config): time_index=time_index, time_index_mid=time_index_mid, time_resolution=time_resolution, - model_config=model_config, + hsgp_kwargs=model_config["intercept_tvp_config"], ) assert isinstance(result, pt.TensorVariable) diff --git a/tests/test_model_config.py b/tests/test_model_config.py index c3e73be2b..2f00841e5 100644 --- a/tests/test_model_config.py +++ b/tests/test_model_config.py @@ -17,10 +17,8 @@ import numpy as np import pytest -from pymc_marketing.model_config import ( - ModelConfigError, - parse_model_config, -) +from pymc_marketing.hsgp_kwargs import HSGPKwargs +from pymc_marketing.model_config import ModelConfigError, parse_model_config from pymc_marketing.prior import Prior @@ -142,6 +140,15 @@ def model_config(): "dims": ("channel", "geo"), "centered": False, }, + # TVP Intercept + "intercept_tvp_config": { + "m": 200, + "L": 119.17, + "eta_lam": 1.0, + "ls_mu": 5.0, + "ls_sigma": 10.0, + "cov_func": None, + }, # Incorrect config "error": { "dist": "Normal", @@ -166,6 +173,7 @@ def test_parse_model_config(model_config) -> None: result = parse_model_config( to_parse, + hsgp_kwargs_fields=["intercept_tvp_config"], non_distributions=non_distributions, ) @@ -203,6 +211,14 @@ def test_parse_model_config(model_config) -> None: dims=("channel", "geo"), centered=False, ), + "intercept_tvp_config": HSGPKwargs( + m=200, + L=119.17, + eta_lam=1.0, + ls_mu=5.0, + ls_sigma=10.0, + cov_func=None, + ), "error": { "dist": "Normal", "kwargs": {"mu": "wrong"}, diff --git a/tests/test_prior.py b/tests/test_prior.py index e1e36edcb..3ebae2816 100644 --- a/tests/test_prior.py +++ b/tests/test_prior.py @@ -20,6 +20,7 @@ import xarray as xr from graphviz.graphs import Digraph from preliz.distributions.distributions import Distribution +from pydantic import ValidationError from pymc.model_graph import fast_eval from pymc_marketing.prior import ( @@ -576,7 +577,10 @@ def test_cant_reset_distribution() -> None: def test_nonstring_distribution() -> None: - with pytest.raises(ValueError, match="Distribution must be a string"): + with pytest.raises( + ValidationError, + match="1 validation error for __init__\\n1\\n Input should be a valid string", + ): Prior(pm.Normal) @@ -587,7 +591,10 @@ def test_change_the_transform() -> None: def test_nonstring_transform() -> None: - with pytest.raises(ValueError, match="Transform must be a string"): + with pytest.raises( + ValidationError, + match="1 validation error for __init__\\ntransform\\n Input should be a valid string", + ): Prior("Normal", transform=pm.math.log)
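As a closing illustration of how the pieces fit together, the sketch below round-trips a plain `*_tvp_config` dictionary through `parse_model_config` into the new `HSGPKwargs` container and dumps it back to JSON, the same path `ModelBuilder.set_idata_attrs` now takes when saving a model; the keys and values follow the docstring example in `model_config.py`.

```python
# Sketch: plain dicts keep working in model_config, but are normalised
# into validated HSGPKwargs objects (values taken from the docstring example).
from pymc_marketing.hsgp_kwargs import HSGPKwargs
from pymc_marketing.model_config import parse_model_config

model_config = {
    "intercept_tvp_config": {
        "m": 200,
        "L": 119.17,
        "eta_lam": 1.0,
        "ls_mu": 5.0,
        "ls_sigma": 10.0,
        "cov_func": None,
    },
}

parsed = parse_model_config(
    model_config, hsgp_kwargs_fields=["intercept_tvp_config"]
)
assert isinstance(parsed["intercept_tvp_config"], HSGPKwargs)

# Pydantic models serialise cleanly, which is what set_idata_attrs relies on.
print(parsed["intercept_tvp_config"].model_dump(mode="json"))
```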