Skip to content

Commit

Permalink
Move fit_result to model_builder.py and remove redundancies from …
Browse files Browse the repository at this point in the history
…CLV and MMM (#1344)

* Added setter property and fit result attribute

* Removed fit_result from CLV basic.py

* Removed fit_result from MMM  base.py

* Defining fit_result separately

* Move setter test from CLV to model_builder

* move test_fit_result from CLV to model_builder

* Fix the errors in tests in model_builder

* Fixed build_model in the model builder test

* make use of the data fixtures

---------

Co-authored-by: Will Dean <wd60622@gmail.com>
  • Loading branch information
sreekailash and wd60622 authored Jan 8, 2025
1 parent b8f2e8e commit 9684821
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 44 deletions.
18 changes: 0 additions & 18 deletions pymc_marketing/clv/models/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from pymc.backends import NDArray
from pymc.backends.base import MultiTrace
from pymc.model.core import Model
from xarray import Dataset

from pymc_marketing.model_builder import ModelBuilder
from pymc_marketing.model_config import ModelConfig, parse_model_config
Expand Down Expand Up @@ -256,23 +255,6 @@ def default_sampler_config(self) -> dict:
def _serializable_model_config(self) -> dict:
return self.model_config

@property
def fit_result(self) -> Dataset:
"""Get the fit result."""
if self.idata is None or "posterior" not in self.idata:
raise RuntimeError("The model hasn't been fit yet, call .fit() first")
return self.idata["posterior"]

@fit_result.setter
def fit_result(self, res: az.InferenceData) -> None:
if self.idata is None:
self.idata = res
elif "posterior" in self.idata:
warnings.warn("Overriding pre-existing fit_result", stacklevel=1)
self.idata.posterior = res
else:
self.idata.posterior = res

def fit_summary(self, **kwargs):
"""Compute the summary of the fit result."""
res = self.fit_result
Expand Down
7 changes: 0 additions & 7 deletions pymc_marketing/mmm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,13 +271,6 @@ def get_target_transformer(self) -> Pipeline:
identity_transformer = FunctionTransformer()
return Pipeline(steps=[("scaler", identity_transformer)])

@property
def fit_result(self) -> Dataset:
"""Get the posterior data."""
if self.idata is None or "posterior" not in self.idata:
raise RuntimeError("The model hasn't been fit yet, call .fit() first")
return self.idata["posterior"]

def _get_group_predictive_data(
self,
group: Literal["prior_predictive", "posterior_predictive"],
Expand Down
37 changes: 37 additions & 0 deletions pymc_marketing/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,6 +684,42 @@ def fit(
self.set_idata_attrs(self.idata)
return self.idata # type: ignore

@property
def fit_result(self) -> xr.Dataset:
"""Get the posterior fit_result.
Returns
-------
InferenceData object.
"""
return create_idata_accessor(
"posterior", "The model hasn't been fit yet, call .fit() first"
).__get__(self)

@fit_result.setter
def fit_result(self, res: az.InferenceData) -> None:
"""Create a setter method to overwrite the pre-existing fit_result.
Parameters
----------
res : az.InferenceData
The inferencedata object to be set
Returns
-------
property
The property setter for the InferenceData object.
"""
if self.idata is None:
self.idata = res
elif "posterior" in self.idata:
warnings.warn("Overriding pre-existing fit_result", stacklevel=1)
self.idata.posterior = res
else:
self.idata.posterior = res

def predict(
self,
X_pred: np.ndarray | pd.DataFrame | pd.Series,
Expand Down Expand Up @@ -959,6 +995,7 @@ def graphviz(self, **kwargs):
posterior = create_idata_accessor(
"posterior", "The model hasn't been fit yet, call .fit() first"
)

posterior_predictive = create_idata_accessor(
"posterior_predictive",
"The model hasn't been fit yet, call .sample_posterior_predictive() first",
Expand Down
19 changes: 0 additions & 19 deletions tests/clv/models/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,11 +132,6 @@ def test_wrong_fit_method(self):
):
model.fit(fit_method="wrong_method")

def test_fit_result_error(self):
model = CLVModelTest()
with pytest.raises(RuntimeError, match="The model hasn't been fit yet"):
model.fit_result

def test_load(self, mocker):
model = CLVModelTest()

Expand All @@ -153,20 +148,6 @@ def test_default_sampler_config(self):
model = CLVModelTest()
assert model.sampler_config == {}

def test_set_fit_result(self):
model = CLVModelTest()
model.build_model()
model.idata = None
fake_fit = pm.sample_prior_predictive(
samples=50, model=model.model, random_seed=1234
)
fake_fit.add_groups(dict(posterior=fake_fit.prior))
model.fit_result = fake_fit
with pytest.warns(UserWarning, match="Overriding pre-existing fit_result"):
model.fit_result = fake_fit
model.idata = None
model.fit_result = fake_fit

def test_fit_summary_for_mcmc(self, mocker):
model = CLVModelTest()

Expand Down
21 changes: 21 additions & 0 deletions tests/test_model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,27 @@ def test_fit_dup_Y(toy_X, toy_y):
model_builder.fit(X=toy_X, chains=1, draws=100, tune=100)


def test_fit_result_error():
model = ModelBuilderTest()
with pytest.raises(RuntimeError, match="The model hasn't been fit yet"):
model.fit_result


def test_set_fit_result(toy_X, toy_y):
model = ModelBuilderTest()
model.build_model(X=toy_X, y=toy_y)
model.idata = None
fake_fit = pm.sample_prior_predictive(
samples=50, model=model.model, random_seed=1234
)
fake_fit.add_groups(dict(posterior=fake_fit.prior))
model.fit_result = fake_fit
with pytest.warns(UserWarning, match="Overriding pre-existing fit_result"):
model.fit_result = fake_fit
model.idata = None
model.fit_result = fake_fit


@pytest.mark.skipif(
sys.platform == "win32",
reason="Permissions for temp files not granted on windows CI.",
Expand Down

0 comments on commit 9684821

Please sign in to comment.