Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WBIC for KTR #719

Merged
merged 11 commits into from
Mar 9, 2022
180 changes: 113 additions & 67 deletions docs/tutorials/wbic.ipynb

Large diffs are not rendered by default.

6,762 changes: 6,629 additions & 133 deletions examples/ktr.ipynb

Large diffs are not rendered by default.

180 changes: 113 additions & 67 deletions examples/wbic.ipynb

Large diffs are not rendered by default.

9 changes: 8 additions & 1 deletion orbit/estimators/pyro_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ def __init__(self, num_sample=100, num_particles=100, init_scale=0.1, **kwargs):
self.num_particles = num_particles
self.init_scale = init_scale

def fit(self, model_name, model_param_names, data_input, fitter=None, init_values=None):
def fit(self, model_name, model_param_names, data_input, sampling_temperature, fitter=None, init_values=None):
data_input.update({'T_STAR': sampling_temperature})
# verbose is passed through from orbit.template.base_estimator
verbose = self.verbose
message = self.message
Expand Down Expand Up @@ -133,5 +134,11 @@ def fit(self, model_name, model_param_names, data_input, fitter=None, init_value
# filter out unnecessary keys
posteriors = {param: extract[param] for param in model_param_names}
training_metrics = {'loss_elbo': np.array(loss_elbo)}

log_p = extract['log_prob']
training_metrics.update({'log_probability': log_p})
training_metrics.update({'sampling_temperature': sampling_temperature})



return posteriors, training_metrics
28 changes: 26 additions & 2 deletions orbit/forecaster/svi.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ def set_forecaster_training_meta(self, data_input):
data_input.update({'WITH_MCMC': 0})
return data_input

def fit(self, df, point_method=None, keep_samples=True):
super().fit(df)
def fit(self, df, point_method=None, keep_samples=True, sampling_temperature=1.0):
super().fit(df, sampling_temperature=sampling_temperature)
self._point_method = point_method

if point_method is not None:
Expand Down Expand Up @@ -183,3 +183,27 @@ def load_extra_methods(self):
self.get_point_posteriors(),
self.get_posterior_samples()
))

def get_wbic_value(self):
    """Return the Widely Applicable Bayesian Information Criterion (WBIC).

    Only valid when the model was fit with ``sampling_temperature = log(n)``
    (Watanabe, 2013); otherwise the tempered posterior samples do not
    correspond to the WBIC estimator and an exception is raised.

    Returns
    -------
    float
        WBIC estimate: ``-2 * n * mean(log_probability)`` over the
        posterior samples stored in the training metrics.

    Raises
    ------
    ForecasterException
        If the stored sampling temperature is not ``log(num_of_obs)``.
    """
    training_metrics = self.get_training_metrics()  # per-sample log probabilities and temperature
    training_meta = self.get_training_meta()        # meta data, including the number of observations
    sampling_temp = training_metrics['sampling_temperature']
    nobs = training_meta['num_of_obs']
    # Tolerance-based comparison: the temperature may round-trip through
    # tensor/float conversions during fitting, so exact `!=` is fragile.
    if not np.isclose(sampling_temp, np.log(nobs)):
        raise ForecasterException('Sampling temperature is not log(n); WBIC calculation is not valid!')
    return -2 * np.nanmean(training_metrics['log_probability']) * nobs

def fit_wbic(self, df):
    """Fit the model with ``sampling_temperature = log(n)`` and return its WBIC.

    Note that if sampling has not been done with sampling_temperature = log(n)
    then the MCMC sampling is redone so that the WBIC calculation is valid.

    Parameters
    ----------
    df : pd.DataFrame
        Training data frame; its row count ``n`` determines the
        sampling temperature ``log(n)``.

    Returns
    -------
    float
        The WBIC value computed from the freshly fitted model.
    """
    nobs = df.shape[0]
    self.fit(df, sampling_temperature=np.log(nobs))
    return self.get_wbic_value()





12 changes: 11 additions & 1 deletion orbit/pyro/ktr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pyro
import pyro.distributions as dist


# FIXME: this is sort of dangerous; consider better implementation later
torch.set_default_tensor_type('torch.DoubleTensor')
pyro.enable_validation(True)
Expand All @@ -28,6 +29,7 @@ def __init__(self, data):
value = torch.tensor(value, dtype=torch.double)
self.__dict__[key] = value


def __call__(self):
"""
Notes
Expand All @@ -51,6 +53,8 @@ def __call__(self):
dof = self.dof
lev_knot_loc = self.lev_knot_loc
seas_term = self.seas_term
# added for tempered sampling
T = self.t_star

pr = self.pr
nr = self.nr
Expand Down Expand Up @@ -209,10 +213,15 @@ def __call__(self):
obs_scale_base = pyro.sample("obs_scale_base", dist.Beta(5, 1)).unsqueeze(-1)
obs_scale = obs_scale_base * resid_scale_ub

pyro.sample("response",
# this line adds a temperature to the observation fit
with pyro.poutine.scale(scale=1.0/T):
pyro.sample("response",
dist.StudentT(dof, yhat[..., which_valid], obs_scale).to_event(1),
obs=response_tran[which_valid])


log_prob = dist.StudentT(dof, yhat[..., which_valid], obs_scale).log_prob(response_tran[which_valid])

lev_knot = lev_knot_tran + meany

extra_out.update({
Expand All @@ -223,5 +232,6 @@ def __call__(self):
'coef_knot': coef_knot,
'coef_init_knot': coef_init_knot,
'obs_scale': obs_scale,
'log_prob': log_prob,
})
return extra_out
15 changes: 13 additions & 2 deletions orbit/pyro/lgt.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ def __call__(self):
response = self.response
num_of_obs = self.num_of_obs
extra_out = {}

# added for tempered sampling
T = self.t_star

# smoothing params
if self.lev_sm_input < 0:
Expand Down Expand Up @@ -187,11 +190,19 @@ def __call__(self):
yhat = lgt_sum + s[..., :num_of_obs] + r

with pyro.plate("response_plate", num_of_obs - 1):
pyro.sample("response", dist.StudentT(nu, yhat[..., 1:], obs_sigma),
with pyro.poutine.scale(scale=1.0/T):
pyro.sample("response", dist.StudentT(nu, yhat[..., 1:], obs_sigma),
obs=response[1:])

log_prob = dist.StudentT(nu, yhat[..., 1:], obs_sigma).log_prob(response[1:])


# we care beta not the pr_beta, nr_beta, ...
extra_out['beta'] = torch.cat([pr_beta, nr_beta, rr_beta], dim=-1)

extra_out.update({'b': b, 'l': l, 's': s, 'lgt_sum': lgt_sum})
extra_out.update({'b': b,
'l': l,
's': s,
'lgt_sum': lgt_sum,
'log_prob': log_prob})
return extra_out
1 change: 1 addition & 0 deletions tests/orbit/estimators/test_pyro_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def test_pyro_estimator_vi(stan_estimator_lgt_model_input):
model_name=stan_model_name,
model_param_names=model_param_names,
data_input=data_input,
sampling_temperature=1.0
)

assert set(model_param_names) == set(posteriors.keys())
Expand Down