Update conversion of observations and disable log likelihood
OriolAbril authored and brandonwillard committed Mar 25, 2021
1 parent 7a3cf11 commit 3af0a00
Showing 1 changed file with 9 additions and 55 deletions.
64 changes: 9 additions & 55 deletions pymc3/backends/arviz.py
@@ -118,7 +118,6 @@ def __init__(
         dims: Optional[DimSpec] = None,
         model=None,
         save_warmup: Optional[bool] = None,
-        density_dist_obs: bool = True,
         index_origin: Optional[int] = None,
     ):

@@ -190,28 +189,18 @@ def arbitrary_element(dct: Dict[Any, np.ndarray]) -> np.ndarray:
             model_dims = {k: list(v) for k, v in self.model.RV_dims.items()}
             self.dims = {**model_dims, **self.dims}
 
-        self.density_dist_obs = density_dist_obs
-        self.observations, self.multi_observations = self.find_observations()
+        self.observations = self.find_observations()
 
-    def find_observations(self) -> Tuple[Optional[Dict[str, Var]], Optional[Dict[str, Var]]]:
+    def find_observations(self) -> Optional[Dict[str, Var]]:
         """If there are observations available, return them as a dictionary."""
         if self.model is None:
-            return (None, None)
+            return None
         observations = {}
-        multi_observations = {}
         for obs in self.model.observed_RVs:
-            if hasattr(obs, "observations"):
-                aux_obs = obs.observations
-                observations[obs.name] = (
-                    aux_obs.get_value() if hasattr(aux_obs, "get_value") else aux_obs
-                )
-            elif hasattr(obs, "data") and self.density_dist_obs:
-                for key, val in obs.data.items():
-                    aux_obs = val.eval() if hasattr(val, "eval") else val
-                    multi_observations[key] = (
-                        aux_obs.get_value() if hasattr(aux_obs, "get_value") else aux_obs
-                    )
-        return observations, multi_observations
+            if hasattr(obs.tag, "observations"):
+                aux_obs = obs.tag.observations
+                observations[obs.name] = aux_obs.data if hasattr(aux_obs, "data") else aux_obs
+        return observations
 
     def split_trace(self) -> Tuple[Union[None, "MultiTrace"], Union[None, "MultiTrace"]]:
         """Split MultiTrace object into posterior and warmup.
@@ -233,41 +222,6 @@ def split_trace(self) -> Tuple[Union[None, "MultiTrace"], Union[None, "MultiTrace"]]:
         trace_posterior = self.trace[self.ntune :]
         return trace_posterior, trace_warmup
 
-    def log_likelihood_vals_point(self, point, var, log_like_fun):
-        """Compute log likelihood for each observed point."""
-        log_like_val = utils.one_de(log_like_fun(point))
-        if var.missing_values:
-            mask = var.observations.mask
-            if np.ndim(mask) > np.ndim(log_like_val):
-                mask = np.any(mask, axis=-1)
-            log_like_val = np.where(mask, np.nan, log_like_val)
-        return log_like_val
-
-    def _extract_log_likelihood(self, trace):
-        """Compute log likelihood of each observation."""
-        if self.trace is None:
-            return None
-        if self.model is None:
-            return None
-
-        if self.log_likelihood is True:
-            cached = [(var, var.logp_elemwise) for var in self.model.observed_RVs]
-        else:
-            cached = [
-                (var, var.logp_elemwise)
-                for var in self.model.observed_RVs
-                if var.name in self.log_likelihood
-            ]
-        log_likelihood_dict = _DefaultTrace(len(trace.chains))
-        for var, log_like_fun in cached:
-            for k, chain in enumerate(trace.chains):
-                log_like_chain = [
-                    self.log_likelihood_vals_point(point, var, log_like_fun)
-                    for point in trace.points([chain])
-                ]
-                log_likelihood_dict.insert(var.name, np.stack(log_like_chain), k)
-        return log_likelihood_dict.trace_dict
-
     @requires("trace")
     def posterior_to_xarray(self):
         """Convert the posterior to an xarray dataset."""
@@ -348,6 +302,8 @@ def sample_stats_to_xarray(self):
     @requires("model")
     def log_likelihood_to_xarray(self):
         """Extract log likelihood and log_p data from PyMC3 trace."""
+        # TODO: add pointwise log likelihood extraction to the converter
+        return None
         if self.predictions or not self.log_likelihood:
             return None
         data_warmup = {}
@@ -540,7 +496,6 @@ def to_inference_data(
     dims: Optional[DimSpec] = None,
     model: Optional["Model"] = None,
     save_warmup: Optional[bool] = None,
-    density_dist_obs: bool = True,
 ) -> InferenceData:
     """Convert pymc3 data into an InferenceData object.
@@ -590,7 +545,6 @@ def to_inference_data(
         dims=dims,
         model=model,
         save_warmup=save_warmup,
-        density_dist_obs=density_dist_obs,
     ).to_inference_data()
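
With `density_dist_obs` removed from the signature, callers of `to_inference_data` simply drop that keyword. A hedged usage sketch, assuming this sits on the PyMC3 v4 development branch and that a simple Normal model samples there as shown; the model, data, and sampler settings are illustrative and not part of this commit:

    # Illustrative call site after this change: no density_dist_obs keyword.
    import numpy as np
    import pymc3 as pm
    from pymc3.backends.arviz import to_inference_data

    y = np.random.default_rng(0).normal(size=50)
    with pm.Model() as model:
        mu = pm.Normal("mu", 0.0, 1.0)
        pm.Normal("y", mu=mu, sigma=1.0, observed=y)
        trace = pm.sample(500, tune=500, return_inferencedata=False)

    idata = to_inference_data(trace=trace, model=model)
    # While log_likelihood_to_xarray returns None, the resulting InferenceData
    # carries no log_likelihood group.
    print(idata.groups())
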

