Support scan markov with history > 1 #848
Conversation
@eb8680 Could you help me review this PR? Most of the changes are generalized from history=1 to history>1. There is a complicated issue around returning the correct shape for the last carry: assume history=2; then we run step 0 and step 1 outside of `lax.scan`. We can instead run step 0, step 1, step 2 outside of `lax.scan`, then use the carry at the last unrolled step, which has the correct shape.
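For context, a minimal usage sketch of the feature (assuming the `history` keyword this PR adds to `numpyro.contrib.control_flow.scan`; the transition probabilities and site names below are hypothetical): a second-order chain whose state depends on the two previous states, with the user-managed carry holding those two states.

```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.control_flow import scan

def model(ys):
    # trans[i, j] gives the transition probabilities given x_{t-2} = i, x_{t-1} = j
    trans = jnp.array([[[0.9, 0.1], [0.4, 0.6]],
                       [[0.3, 0.7], [0.8, 0.2]]])

    def transition_fn(carry, y):
        x_prev, x_curr = carry  # last two states; shapes must match across steps
        x = numpyro.sample("x", dist.Categorical(trans[x_prev, x_curr]))
        numpyro.sample("y", dist.Normal(x.astype(jnp.float32), 1.0), obs=y)
        return (x_curr, x), None

    init = (jnp.asarray(0), jnp.asarray(1))
    scan(transition_fn, init, ys, history=2)
```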
numpyro/contrib/control_flow/scan.py
Outdated
return wrapped_carry, (PytreeTrace({}), ys)
wrapped_carry, (pytree_trace, ys) = lax.scan(body_fn, wrapped_carry, xs_, length - 1, reverse)
y0s = []
for i in range(history):
For clarification, after changing this to `for i in range(2 * history)`, what other changes would be necessary to get the correct final `carry_shape`?
numpyro/contrib/control_flow/scan.py
Outdated
# XXX: unless we unroll more steps, we only know the correct shapes starting from
# the `history`-th iteration; that means we only know carry shapes when
# i == history - 1 or i == history, which corresponds to input carry or output carry
# at the `history`-th iteration.
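For readers less familiar with JAX, here is a minimal illustration (not the PR's code) of why carry shapes matter at all: `jax.lax.scan` requires the carry to keep exactly the same structure and shapes at every iteration, so the scanned body can only take over once the carry shape has stabilized, and the first iterations must be unrolled in Python.

```python
import jax
import jax.numpy as jnp

def bad_body(carry, _):
    # the carry doubles in size each step, which lax.scan rejects
    return jnp.concatenate([carry, carry]), None

try:
    jax.lax.scan(bad_body, jnp.zeros(2), None, length=3)
except TypeError as err:
    print("lax.scan rejects a shape-changing carry:", err)
```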
@fehiepsi I think I understand this issue, thanks for explaining. How much compilation overhead does unrolling `2 * history` steps add in the HMM examples? Unrolling `history` steps in exchange for incorrect output carry shapes is a very subtle optimization that only produces constant-time savings in program length, and may cause maintenance difficulties down the line, so it's important to be sure that it's worth it.
Thinking about it again, I guess it won't take much compilation time because history is usually small (1 or 2). Let me address the issue in this PR.
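A rough, hypothetical way to check that claim is to time the first jitted call (which includes XLA compilation) for different unroll lengths; the model below is a stand-in, not the PR's HMM examples.

```python
import time
import jax
import jax.numpy as jnp

def make_fn(unroll_steps, length=100):
    def fn(x):
        for _ in range(unroll_steps):        # unrolled prefix, traced step by step
            x = jnp.tanh(x) + 1.0
        x, _ = jax.lax.scan(                 # remainder handled by one scanned body
            lambda c, _: (jnp.tanh(c) + 1.0, None), x, None,
            length=length - unroll_steps)
        return x
    return jax.jit(fn)

for k in (1, 3):  # e.g. history vs 2 * history - 1 for history = 2
    f = make_fn(k)
    t0 = time.perf_counter()
    f(jnp.zeros(3)).block_until_ready()      # first call includes compilation
    print(f"unroll_steps={k}: {time.perf_counter() - t0:.3f}s")
```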
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM after latest changes.
Thanks for reviewing, @eb8680! Looks like there is a bug when I unrolled more steps. Looking into it...
# + msg: fn.batch_shape = (3,), value.shape = (2, 3) + fn.event_shape
# process_message(msg): promote fn so that fn.batch_shape = (1, 3).
def process_message(self, msg):
def postprocess_message(self, msg):
we only want to do this right before getting the trace.
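A stripped-down sketch of the hook choice (a hypothetical handler, not the PR's actual `promote_shapes`): `process_message` runs before a site's value is computed, while `postprocess_message` runs after the value exists, i.e. right before the trace records it, which is when the promotion described in the quoted comment can be applied.

```python
import jax.numpy as jnp
from numpyro.primitives import Messenger

class promote_shapes_sketch(Messenger):
    def postprocess_message(self, msg):
        if msg["type"] == "sample" and msg["value"] is not None:
            fn, value = msg["fn"], msg["value"]
            # left-pad fn.batch_shape with 1s so it lines up with value's shape,
            # e.g. batch_shape (3,) vs value shape (2, 3) -> batch_shape (1, 3)
            extra_dims = jnp.ndim(value) - len(fn.event_shape) - len(fn.batch_shape)
            if extra_dims > 0:
                msg["fn"] = fn.expand((1,) * extra_dims + tuple(fn.batch_shape))
```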
trace = {}
else:
with handlers.block(), packed_trace() as trace, promote_shapes(), enum(), markov():
The usage of `enum()`, `markov()` here still gives correct log density but does not recycle the dimensions at this step. To resolve this, I have moved `block` right before the generator `markov(unroll_steps + 1, history)` and removed the `markov` handler here.
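A hedged sketch of what "recycling dimensions" means (hypothetical model; requires the funsor backend, and it assumes `markov` accepts an iterable as in Pyro): under `markov(history=1)`, enumerated sites more than one step apart may reuse the same enumeration dim, so the number of dims stays bounded by roughly `history + 1` instead of growing with the chain length.

```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.contrib.funsor import config_enumerate, enum, markov

@config_enumerate
def chain(T=5):
    probs = jnp.array([[0.8, 0.2], [0.3, 0.7]])
    x = 0
    for t in markov(range(T), history=1):
        x = numpyro.sample(f"x_{t}", dist.Categorical(probs[x]))

with handlers.seed(rng_seed=0):
    tr = handlers.trace(enum(chain, first_available_dim=-1)).get_trace()
# the value shapes reveal the enum dims in use; with history=1 only two dims alternate
print({name: site["value"].shape for name, site in tr.items()})
```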
# Note that `funsor.sum_product.sarkka_bilmes_product` does support history > 1.
# number of steps to unroll
history = min(history, length)
unroll_steps = min(2 * history - 1, length)
This is the minimum number of `unroll_steps` to get a correct last carry shape, but it is safe to use any number greater than this.
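A quick worked example of that formula (values chosen for illustration):

```python
history, length = 2, 10
unroll_steps = min(2 * history - 1, length)   # = 3: iterations 0, 1, 2 run unrolled
assert unroll_steps == 3
# the scanned body then handles the remaining length - unroll_steps = 7 iterations
```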
@eb8680 Though the previous implementation gives correct log density, the enumeration dimensions are not recycled with the

    with markov(history=history):
        for i in range(unroll_steps):
            transition_fn()
        lax.scan(with markov(): transition_fn())  # pseudocode

logic. I have revised it to

    for i in markov(range(unroll_steps + 1)):
        if i < unroll_steps:
            transition_fn()
        else:
            lax.scan(transition_fn())  # pseudocode

Previously, this issue was not detected due to a typo in the test, where I computed

    expected_x_prev, expected_x_curr = enum(config_enumerate(model))()
    actual_x_prev, actual_x_curr = enum(config_enumerate(model))()  # typo: should be fun_model
LGTM. Now that `sarkka_bilmes_product` is actually being used, I guess we should probably try to move away from the `P` prefix naming scheme upstream in `sarkka_bilmes_product` to a safer interface, so no non-Markov names from user code with `P` in them get incorrectly mangled. Maybe the simplest thing to do would be to make the prefixes longer (e.g. `_PREV_` instead of `P`).
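To make the concern concrete, a tiny hypothetical illustration of the collision the single-character prefix allows (site names invented here):

```python
prefix = "P"
markov_site, user_site = "rate", "Prate"
# under the "P" scheme the time-shifted Markov name collides with a user name
assert prefix + markov_site == user_site
prefix = "_PREV_"
assert prefix + markov_site != user_site   # a longer prefix removes the ambiguity
```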
Good point!
Resolves #702.
Tasks: