diff --git a/bambi/backend/pymc.py b/bambi/backend/pymc.py index 655ac912..311a8b94 100644 --- a/bambi/backend/pymc.py +++ b/bambi/backend/pymc.py @@ -132,9 +132,7 @@ def run( ) # NOTE: Methods return different types of objects (idata, approximation, and dictionary) - if inference_method in ( - self.pymc_methods["mcmc"] + self.bayeux_methods["mcmc"] - ): + if inference_method in (self.pymc_methods["mcmc"] + self.bayeux_methods["mcmc"]): result = self._run_mcmc( draws, tune, @@ -154,9 +152,7 @@ def run( elif inference_method == "laplace": result = self._run_laplace(draws, omit_offsets, include_response_params) else: - raise NotImplementedError( - f"'{inference_method}' method has not been implemented" - ) + raise NotImplementedError(f"'{inference_method}' method has not been implemented") self.fit = True return result @@ -255,6 +251,17 @@ def _run_mcmc( import bayeux as bx # pylint: disable=import-outside-toplevel import jax # pylint: disable=import-outside-toplevel + # pylint: disable=import-outside-toplevel + from pymc.sampling.parallel import ( + _cpu_count, + ) + + # handle case where cores and chains are not provided + if cores is None: + cores = min(4, _cpu_count()) + if chains is None: + chains = max(2, cores) + # Set the seed for reproducibility if provided if random_seed is not None: if not isinstance(random_seed, int): @@ -285,9 +292,7 @@ def _run_mcmc( f" {self.pymc_methods['mcmc'] + self.bayeux_methods['mcmc']}" ) - idata = self._clean_results( - idata, omit_offsets, include_response_params, idata_from - ) + idata = self._clean_results(idata, omit_offsets, include_response_params, idata_from) return idata def _clean_results(self, idata, omit_offsets, include_response_params, idata_from): @@ -321,9 +326,7 @@ def _clean_results(self, idata, omit_offsets, include_response_params, idata_fro getattr(idata, group).attrs["modeling_interface_version"] = __version__ if omit_offsets: - offset_vars = [ - var for var in idata.posterior.data_vars if var.endswith("_offset") - ] + offset_vars = [var for var in idata.posterior.data_vars if var.endswith("_offset")] idata.posterior = idata.posterior.drop_vars(offset_vars) dims_original = list(self.model.coords) @@ -369,9 +372,7 @@ def _clean_results(self, idata, omit_offsets, include_response_params, idata_fro dims += tuple(response_coords) posterior = idata.posterior.stack(samples=dims) - coefs = np.vstack( - [np.atleast_2d(posterior[name].values) for name in common_terms] - ) + coefs = np.vstack([np.atleast_2d(posterior[name].values) for name in common_terms]) name = get_aliased_name(bambi_component.intercept_term) center_factor = np.dot(X.mean(0), coefs).reshape(shape) idata.posterior[name] = idata.posterior[name] - center_factor @@ -434,24 +435,16 @@ def _run_laplace(self, draws, omit_offsets, include_response_params): samples = np.random.multivariate_normal(modes, cov, size=draws) idata = _posterior_samples_to_idata(samples, self.model) - idata = self._clean_results( - idata, omit_offsets, include_response_params, idata_from="pymc" - ) + idata = self._clean_results(idata, omit_offsets, include_response_params, idata_from="pymc") return idata @property def constant_components(self): - return { - k: v for k, v in self.components.items() if isinstance(v, ConstantComponent) - } + return {k: v for k, v in self.components.items() if isinstance(v, ConstantComponent)} @property def distributional_components(self): - return { - k: v - for k, v in self.components.items() - if isinstance(v, DistributionalComponent) - } + return {k: v for k, v in self.components.items() if isinstance(v, DistributionalComponent)} def _posterior_samples_to_idata(samples, model): @@ -519,9 +512,7 @@ def create_posterior_bayeux(posterior, pm_model): data_vars_dims = {} for data_var_name in data_vars_names: if data_var_name in vars_to_dims: - data_vars_dims[data_var_name] = ["chain", "draw"] + list( - vars_to_dims[data_var_name] - ) + data_vars_dims[data_var_name] = ["chain", "draw"] + list(vars_to_dims[data_var_name]) else: data_vars_dims[data_var_name] = ["chain", "draw"] @@ -536,13 +527,9 @@ def create_posterior_bayeux(posterior, pm_model): # Get coords dims_in_use = set(dim for dims in data_vars_dims.values() for dim in dims) - coords_in_use = { - coord_name: np.array(coords[coord_name]) for coord_name in dims_in_use - } + coords_in_use = {coord_name: np.array(coords[coord_name]) for coord_name in dims_in_use} - return xr.Dataset( - data_vars=data_vars_values, coords=coords_in_use, attrs=posterior.attrs - ) + return xr.Dataset(data_vars=data_vars_values, coords=coords_in_use, attrs=posterior.attrs) def create_observed_data_bayeux(pm_model):