Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add thin wrapper around advi functionality #1365

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
21 changes: 20 additions & 1 deletion pymc_marketing/clv/models/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@
Method used to fit the model. Options are:
- "mcmc": Samples from the posterior via `pymc.sample` (default)
- "map": Finds maximum a posteriori via `pymc.find_MAP`
- "demz": Samples from the posterior via `pymc.sample` using DEMetropolisZ
- "advi": Samples from the posterior via `pymc.fit(method="advi")` and `pymc.sample`
kwargs:
Other keyword arguments passed to the underlying PyMC routines

Expand All @@ -120,9 +122,12 @@
idata = self._fit_MAP(**kwargs)
case "demz":
idata = self._fit_DEMZ(**kwargs)
case "advi":
idata = self._fit_advi(**kwargs)

case _:
raise ValueError(
f"Fit method options are ['mcmc', 'map', 'demz'], got: {fit_method}"
f"Fit method options are ['mcmc', 'map', 'demz', 'advi'], got: {fit_method}"
)

self.idata = idata
Expand Down Expand Up @@ -164,6 +169,20 @@
with self.model:
return pm.sample(step=pm.DEMetropolisZ(), **sampler_config)

def _fit_advi(self, **kwargs) -> az.InferenceData:
"""Fit a model with ADVI."""
sampler_config = {}
if self.sampler_config is not None:
sampler_config = self.sampler_config.copy()
sampler_config.update(**kwargs)
if sampler_config.get("method") is not None:
PabloRoque marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(

Check warning on line 179 in pymc_marketing/clv/models/basic.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/clv/models/basic.py#L179

Added line #L179 was not covered by tests
"The 'method' parameter is set in sampler_config. Cannot be called with 'advi'."
)
with self.model:
pm.fit(**{"method": "advi"})
return pm.sample(**sampler_config)
wd60622 marked this conversation as resolved.
Show resolved Hide resolved

@classmethod
def load(cls, fname: str):
"""Create a ModelBuilder instance from a file.
Expand Down
13 changes: 13 additions & 0 deletions tests/clv/models/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,19 @@ def test_fit_demz(self, mocker):
assert len(idata.posterior.draw) == 10
assert model.fit_result is idata.posterior

def test_fit_advi(self, mocker):
model = CLVModelTest()
# mocker.patch("pymc.sample", mock_sample)
idata = model.fit(
fit_method="advi",
tune=5,
chains=2,
draws=10,
)
assert isinstance(idata, InferenceData)
assert len(idata.posterior.chain) == 2
assert len(idata.posterior.draw) == 10

def test_wrong_fit_method(self):
model = CLVModelTest()
with pytest.raises(
Expand Down
1 change: 1 addition & 0 deletions tests/clv/models/test_beta_geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ def test_customer_id_duplicate(self):
[
("mcmc", 0.1),
("map", 0.2),
("advi", 0.1),
],
)
def test_model_convergence(self, fit_method, rtol, model_config):
Expand Down
1 change: 1 addition & 0 deletions tests/clv/models/test_modified_beta_geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ def test_numerically_stable_logp(
[
("mcmc", 0.075),
("map", 0.15),
("advi", 0.075),
],
)
def test_model_convergence(self, fit_method, rtol, model_config):
Expand Down
2 changes: 1 addition & 1 deletion tests/clv/models/test_pareto_nbd.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ def test_expected_purchase_probability(self, n_purchases, future_t):
rtol=0.001,
)

@pytest.mark.parametrize("fit_type", ("map", "mcmc"))
@pytest.mark.parametrize("fit_type", ("map", "mcmc", "advi"))
def test_posterior_distributions(self, fit_type) -> None:
rng = np.random.default_rng(42)
dim_T = 2357
Expand Down
Loading