Skip to content

Commit

Permalink
support xarray in ModelBuilder (#1459)
Browse files Browse the repository at this point in the history
* remove argument that does ... nothing

* skin some fat

* remove the method from CLV

* handle the ndarray case

* simplify the tests

* define the name if unnamed

* change to xarray then use merge

* make a mocking fixture

* make a specific test with xarray

* include DataArray for X as well

* test for DataArray

* regenerate the license

* additional type hint for _data_setter
  • Loading branch information
wd60622 authored Feb 3, 2025
1 parent 6c06554 commit 79a8008
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 96 deletions.
4 changes: 0 additions & 4 deletions pymc_marketing/clv/models/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,10 +273,6 @@ def output_var(self):
"""Output variable of the model."""
pass

def _generate_and_preprocess_model_data(self, *args, **kwargs):
"""Generate and preprocess model data."""
pass

def _data_setter(self):
    """Set the data for the model.

    No-op placeholder — presumably out-of-sample data swapping is not
    supported for this model family. NOTE(review): confirm whether
    subclasses are expected to override this with a real implementation.
    """
    pass
88 changes: 36 additions & 52 deletions pymc_marketing/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,8 @@ def _validate_data(self, X, y=None):
@abstractmethod
def _data_setter(
self,
X: np.ndarray | pd.DataFrame,
y: np.ndarray | pd.Series | None = None,
X: np.ndarray | pd.DataFrame | xr.Dataset | xr.DataArray,
y: np.ndarray | pd.Series | xr.DataArray | None = None,
) -> None:
"""Set new data in the model.
Expand Down Expand Up @@ -304,44 +304,11 @@ def default_sampler_config(self) -> dict:
"""

@abstractmethod
def _generate_and_preprocess_model_data(
self, X: pd.DataFrame | pd.Series, y: np.ndarray
) -> None:
"""Apply preprocessing to the data before fitting the model.
if validate is True, it will check if the data is valid for the model.
sets self.model_coords based on provided dataset
In case of optional parameters being passed into the model, this method should implement the conditional
logic responsible for correct handling of the optional parameters, and including them into the dataset.
Parameters
----------
X : array, shape (n_obs, n_features)
y : array, shape (n_obs,)
Examples
--------
>>> @classmethod
>>> def _generate_and_preprocess_model_data(self, X, y):
coords = {
'x_dim': X.dim_variable,
} #only include if applicable for your model
>>> self.X = X
>>> self.y = y
Returns
-------
None
"""

@abstractmethod
def build_model(
self,
X: pd.DataFrame,
y: pd.Series | np.ndarray,
X: pd.DataFrame | xr.Dataset | xr.DataArray,
y: pd.Series | np.ndarray | xr.DataArray,
**kwargs,
) -> None:
"""Create an instance of `pm.Model` based on provided data and model_config.
Expand Down Expand Up @@ -656,10 +623,30 @@ def load(cls, fname: str):
)
raise DifferentModelError(error_msg) from e

def create_fit_data(
    self,
    X: pd.DataFrame | xr.Dataset | xr.DataArray,
    y: np.ndarray | pd.Series | xr.DataArray,
) -> xr.Dataset:
    """Create the fit_data group based on the input data.

    Parameters
    ----------
    X : pd.DataFrame, xr.Dataset or xr.DataArray
        Input features. DataFrames are converted via ``to_xarray``.
    y : np.ndarray, pd.Series or xr.DataArray
        Target variable. If unnamed, it is stored under ``self.output_var``.

    Returns
    -------
    xr.Dataset
        X and y merged into a single dataset.
    """
    if isinstance(y, np.ndarray):
        # A bare ndarray carries no index. Align it with X only when X is
        # a DataFrame; the previous code read X.index unconditionally,
        # which raises AttributeError for xarray inputs allowed by the
        # signature. For xarray X, fall back to a positional index.
        if isinstance(X, pd.DataFrame):
            index = X.index
        else:
            index = pd.RangeIndex(len(y))
        y = pd.Series(y, index=index, name=self.output_var)

    # xr.merge requires every variable to be named.
    if y.name is None:
        y.name = self.output_var

    if isinstance(X, pd.DataFrame):
        X = X.to_xarray()

    if isinstance(y, pd.Series):
        y = y.to_xarray()

    return xr.merge([X, y])

def fit(
self,
X: pd.DataFrame,
y: pd.Series | np.ndarray | None = None,
X: pd.DataFrame | xr.Dataset | xr.DataArray,
y: pd.Series | xr.DataArray | np.ndarray | None = None,
progressbar: bool | None = None,
random_seed: RandomState | None = None,
**kwargs: Any,
Expand Down Expand Up @@ -694,23 +681,23 @@ def fit(
Initializing NUTS using jitter+adapt_diag...
"""
if isinstance(y, pd.Series) and not X.index.equals(y.index):
if (
isinstance(y, pd.Series)
and isinstance(X, pd.DataFrame)
and not X.index.equals(y.index)
):
raise ValueError("Index of X and y must match.")

if y is None:
y = np.zeros(X.shape[0])

y_df = pd.DataFrame({self.output_var: y}, index=X.index)
self._generate_and_preprocess_model_data(X, y_df.values.flatten())
if self.X is None or self.y is None:
raise ValueError("X and y must be set before calling build_model!")
if self.output_var in X.columns:
if self.output_var in X:
raise ValueError(
f"X includes a column named '{self.output_var}', which conflicts with the target variable."
)

if not hasattr(self, "model"):
self.build_model(self.X, self.y)
self.build_model(X, y)

sampler_kwargs = create_sample_kwargs(
self.sampler_config,
Expand All @@ -727,21 +714,18 @@ def fit(
else:
self.idata = idata

X_df = pd.DataFrame(X, columns=X.columns)
combined_data = pd.concat([X_df, y_df], axis=1)
if not all(combined_data.columns):
raise ValueError("All columns must have non-empty names")

if "fit_data" in self.idata:
del self.idata.fit_data

fit_data = self.create_fit_data(X, y)

with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
category=UserWarning,
message="The group fit_data is not defined in the InferenceData scheme",
)
self.idata.add_groups(fit_data=combined_data.to_xarray()) # type: ignore
self.idata.add_groups(fit_data=fit_data)
self.set_idata_attrs(self.idata)
return self.idata # type: ignore

Expand Down
22 changes: 16 additions & 6 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,12 +124,12 @@ def mock_sample(*args, **kwargs):
"""This is a mock of pm.sample that returns the prior predictive samples as the posterior."""
random_seed = kwargs.get("random_seed", None)
model = kwargs.get("model", None)
samples = kwargs.get("draws", 10)
draws = kwargs.get("draws", 10)
n_chains = kwargs.get("chains", 1)
idata: InferenceData = pm.sample_prior_predictive(
model=model,
random_seed=random_seed,
samples=samples,
draws=draws,
)

expanded_chains = DataArray(
Expand All @@ -147,6 +147,16 @@ def mock_sample(*args, **kwargs):
return idata


@pytest.fixture
def mock_pymc_sample():
    """Temporarily replace ``pm.sample`` with ``mock_sample``.

    The patch is reverted on teardown. Restoration is wrapped in
    ``try``/``finally`` so ``pm.sample`` is guaranteed to be restored
    even if the teardown path is reached via an exception, preventing
    the mock from leaking into unrelated tests.
    """
    original_sample = pm.sample
    pm.sample = mock_sample
    try:
        yield
    finally:
        pm.sample = original_sample


def mock_fit_MAP(self, *args, **kwargs):
draws = 1
chains = 1
Expand All @@ -173,9 +183,7 @@ def fitted_bg(test_summary_data) -> BetaGeoModel:
model_config=model_config,
)
model.build_model()
fake_fit = pm.sample_prior_predictive(
samples=50, model=model.model, random_seed=rng
)
fake_fit = pm.sample_prior_predictive(draws=50, model=model.model, random_seed=rng)
# posterior group required to pass L80 assert check
fake_fit.add_groups(posterior=fake_fit.prior)
set_model_fit(model, fake_fit)
Expand Down Expand Up @@ -205,7 +213,9 @@ def fitted_pnbd(test_summary_data) -> ParetoNBDModel:
# Mock an idata object for tests requiring a fitted model
# TODO: This is quite slow. Check similar fixtures in the model tests to speed this up.
fake_fit = pm.sample_prior_predictive(
samples=50, model=pnbd_model.model, random_seed=rng
draws=50,
model=pnbd_model.model,
random_seed=rng,
)
# posterior group required to pass L80 assert check
fake_fit.add_groups(posterior=fake_fit.prior)
Expand Down
Loading

0 comments on commit 79a8008

Please sign in to comment.