From 1b798b3eed50824fd9d8beaedb7ab6cd102f3919 Mon Sep 17 00:00:00 2001 From: Junpeng Lao Date: Tue, 13 Jun 2017 10:17:35 +0200 Subject: [PATCH 1/2] suppress error raise in #2258 This is a temporary solution!!! Behavior: 1, The user supply start value is taken as priority (same as before the fix) 2, if there is transformed RV conditioned on another free_RV, the start value will be ignore (due to error in #2258) 3, No more error using default init in situation 2. --- pymc3/sampling.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/pymc3/sampling.py b/pymc3/sampling.py index 4caf44d583..ee9aab1e09 100644 --- a/pymc3/sampling.py +++ b/pymc3/sampling.py @@ -480,7 +480,12 @@ def _update_start_vals(a, b, model): if is_transformed_name(tname) and get_untransformed_name(tname) == name: transform_func = [d.transformation for d in model.deterministics if d.name == name] if transform_func: - b[tname] = transform_func[0].forward(a[name]).eval() + # TODO: Fix issue #2258 properly, bypassing for now + # (same as previous behavior) + try: + b[tname] = transform_func[0].forward(a[name]).eval() + except: + pass a.update({k: v for k, v in b.items() if k not in a}) @@ -594,9 +599,9 @@ def init_nuts(init='ADVI', njobs=1, n_init=500000, model=None, n=n_init, method='advi', model=model, callbacks=cb, progressbar=progressbar, - obj_optimizer=pm.adagrad_window + obj_optimizer=pm.adagrad_window, ) # type: pm.MeanField - start = approx.sample(draws=njobs) + start = approx.sample(draws=njobs, include_transformed=True) stds = approx.gbij.rmap(approx.std.eval()) cov = model.dict_to_array(stds) ** 2 if njobs == 1: @@ -611,7 +616,7 @@ def init_nuts(init='ADVI', njobs=1, n_init=500000, model=None, progressbar=progressbar, obj_optimizer=pm.adagrad_window ) - start = approx.sample(draws=njobs) + start = approx.sample(draws=njobs, include_transformed=True) stds = approx.gbij.rmap(approx.std.eval()) cov = model.dict_to_array(stds) ** 2 if njobs == 1: From 47a27c84a3d9a5bf0097893ed83f64fe1c3b3128 Mon Sep 17 00:00:00 2001 From: Junpeng Lao Date: Tue, 13 Jun 2017 10:20:26 +0200 Subject: [PATCH 2/2] tidy up --- pymc3/sampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc3/sampling.py b/pymc3/sampling.py index ee9aab1e09..fe1ddceb72 100644 --- a/pymc3/sampling.py +++ b/pymc3/sampling.py @@ -599,7 +599,7 @@ def init_nuts(init='ADVI', njobs=1, n_init=500000, model=None, n=n_init, method='advi', model=model, callbacks=cb, progressbar=progressbar, - obj_optimizer=pm.adagrad_window, + obj_optimizer=pm.adagrad_window ) # type: pm.MeanField start = approx.sample(draws=njobs, include_transformed=True) stds = approx.gbij.rmap(approx.std.eval())