Skip to content

Commit

Permalink
Handle case when cores and chains are None for nutpie
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexAndorra committed Dec 11, 2024
1 parent ebb735a commit 5fe7947
Showing 1 changed file with 22 additions and 35 deletions.
57 changes: 22 additions & 35 deletions bambi/backend/pymc.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,7 @@ def run(
)

# NOTE: Methods return different types of objects (idata, approximation, and dictionary)
if inference_method in (
self.pymc_methods["mcmc"] + self.bayeux_methods["mcmc"]
):
if inference_method in (self.pymc_methods["mcmc"] + self.bayeux_methods["mcmc"]):
result = self._run_mcmc(
draws,
tune,
Expand All @@ -154,9 +152,7 @@ def run(
elif inference_method == "laplace":
result = self._run_laplace(draws, omit_offsets, include_response_params)
else:
raise NotImplementedError(
f"'{inference_method}' method has not been implemented"
)
raise NotImplementedError(f"'{inference_method}' method has not been implemented")

self.fit = True
return result
Expand Down Expand Up @@ -255,6 +251,17 @@ def _run_mcmc(
import bayeux as bx # pylint: disable=import-outside-toplevel
import jax # pylint: disable=import-outside-toplevel

# pylint: disable=import-outside-toplevel
from pymc.sampling.parallel import (
_cpu_count,
)

# handle case where cores and chains are not provided
if cores is None:
cores = min(4, _cpu_count())
if chains is None:
chains = max(2, cores)

# Set the seed for reproducibility if provided
if random_seed is not None:
if not isinstance(random_seed, int):
Expand Down Expand Up @@ -285,9 +292,7 @@ def _run_mcmc(
f" {self.pymc_methods['mcmc'] + self.bayeux_methods['mcmc']}"
)

idata = self._clean_results(
idata, omit_offsets, include_response_params, idata_from
)
idata = self._clean_results(idata, omit_offsets, include_response_params, idata_from)
return idata

def _clean_results(self, idata, omit_offsets, include_response_params, idata_from):
Expand Down Expand Up @@ -321,9 +326,7 @@ def _clean_results(self, idata, omit_offsets, include_response_params, idata_fro
getattr(idata, group).attrs["modeling_interface_version"] = __version__

if omit_offsets:
offset_vars = [
var for var in idata.posterior.data_vars if var.endswith("_offset")
]
offset_vars = [var for var in idata.posterior.data_vars if var.endswith("_offset")]
idata.posterior = idata.posterior.drop_vars(offset_vars)

dims_original = list(self.model.coords)
Expand Down Expand Up @@ -369,9 +372,7 @@ def _clean_results(self, idata, omit_offsets, include_response_params, idata_fro
dims += tuple(response_coords)

posterior = idata.posterior.stack(samples=dims)
coefs = np.vstack(
[np.atleast_2d(posterior[name].values) for name in common_terms]
)
coefs = np.vstack([np.atleast_2d(posterior[name].values) for name in common_terms])
name = get_aliased_name(bambi_component.intercept_term)
center_factor = np.dot(X.mean(0), coefs).reshape(shape)
idata.posterior[name] = idata.posterior[name] - center_factor
Expand Down Expand Up @@ -434,24 +435,16 @@ def _run_laplace(self, draws, omit_offsets, include_response_params):
samples = np.random.multivariate_normal(modes, cov, size=draws)

idata = _posterior_samples_to_idata(samples, self.model)
idata = self._clean_results(
idata, omit_offsets, include_response_params, idata_from="pymc"
)
idata = self._clean_results(idata, omit_offsets, include_response_params, idata_from="pymc")
return idata

@property
def constant_components(self):
return {
k: v for k, v in self.components.items() if isinstance(v, ConstantComponent)
}
return {k: v for k, v in self.components.items() if isinstance(v, ConstantComponent)}

@property
def distributional_components(self):
return {
k: v
for k, v in self.components.items()
if isinstance(v, DistributionalComponent)
}
return {k: v for k, v in self.components.items() if isinstance(v, DistributionalComponent)}


def _posterior_samples_to_idata(samples, model):
Expand Down Expand Up @@ -519,9 +512,7 @@ def create_posterior_bayeux(posterior, pm_model):
data_vars_dims = {}
for data_var_name in data_vars_names:
if data_var_name in vars_to_dims:
data_vars_dims[data_var_name] = ["chain", "draw"] + list(
vars_to_dims[data_var_name]
)
data_vars_dims[data_var_name] = ["chain", "draw"] + list(vars_to_dims[data_var_name])
else:
data_vars_dims[data_var_name] = ["chain", "draw"]

Expand All @@ -536,13 +527,9 @@ def create_posterior_bayeux(posterior, pm_model):

# Get coords
dims_in_use = set(dim for dims in data_vars_dims.values() for dim in dims)
coords_in_use = {
coord_name: np.array(coords[coord_name]) for coord_name in dims_in_use
}
coords_in_use = {coord_name: np.array(coords[coord_name]) for coord_name in dims_in_use}

return xr.Dataset(
data_vars=data_vars_values, coords=coords_in_use, attrs=posterior.attrs
)
return xr.Dataset(data_vars=data_vars_values, coords=coords_in_use, attrs=posterior.attrs)


def create_observed_data_bayeux(pm_model):
Expand Down

0 comments on commit 5fe7947

Please sign in to comment.