Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions pymc3/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,12 @@ def _update_start_vals(a, b, model):
if is_transformed_name(tname) and get_untransformed_name(tname) == name:
transform_func = [d.transformation for d in model.deterministics if d.name == name]
if transform_func:
b[tname] = transform_func[0].forward(a[name]).eval()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use tag.test_value instead

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wrong

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Try it out please. I was confused by the shortcut pymc3.sampling above the codeblock and thought it is internals of sampling. Seems like it can be the right solution as it is executed once.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure I understand what you mean - if you find a solution you can push to this branch.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the best solution is recompiled the graph using something like _compile_theano_function similar to distribution.draw_value. However, we always need to feed a dictwith the right input to the new compiled function as well.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @aseyboldt as he mention he might refactor _compile_theano_function for #2296

# TODO: Fix issue #2258 properly, bypassing for now
# (same as previous behavior)
try:
b[tname] = transform_func[0].forward(a[name]).eval()
except:
pass

a.update({k: v for k, v in b.items() if k not in a})

Expand Down Expand Up @@ -596,7 +601,7 @@ def init_nuts(init='ADVI', njobs=1, n_init=500000, model=None,
progressbar=progressbar,
obj_optimizer=pm.adagrad_window
) # type: pm.MeanField
start = approx.sample(draws=njobs)
start = approx.sample(draws=njobs, include_transformed=True)
stds = approx.gbij.rmap(approx.std.eval())
cov = model.dict_to_array(stds) ** 2
if njobs == 1:
Expand All @@ -611,7 +616,7 @@ def init_nuts(init='ADVI', njobs=1, n_init=500000, model=None,
progressbar=progressbar,
obj_optimizer=pm.adagrad_window
)
start = approx.sample(draws=njobs)
start = approx.sample(draws=njobs, include_transformed=True)
stds = approx.gbij.rmap(approx.std.eval())
cov = model.dict_to_array(stds) ** 2
if njobs == 1:
Expand Down