Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support scan markov with history > 1 #848

Merged
merged 11 commits into from
Dec 23, 2020
Merged

Support scan markov with history > 1 #848

merged 11 commits into from
Dec 23, 2020

Conversation

fehiepsi
Copy link
Member

@fehiepsi fehiepsi commented Dec 19, 2020

Resolves #702.

Tasks:

  • Revise log_density implementation to incorporate this feature.
  • Add hmm example with history > 1
  • Add test to verify the consistency of density, output carry, ys with unrolled scan

@fehiepsi
Copy link
Member Author

@eb8680 Could you help me review this PR? Most of the changes are generalized from history=1 -> history>1. There is a complicated issue to return the correct shape for the last carry: assume history=2, then we run step 0, step 1 outside of lax.scan, then use lax.scan for steps 2 -> (T - 1). Under enumeration, we have 3 different shapes for carry, but we can only interpret those shapes using the input carry and output carry at step 2. So one of them is missing.

We can run step 0, step 1, step 2, then using lax.scan for steps 3 -> (T - 1). This way we can interpret all carry shapes (using infos at steps 2 and 3) but need to unroll history * 2 steps (in general). So unless it is really needed, I would like to unroll the number of steps as small as possible. :) Anyway, we already added a note previously that: there should not have any site outside of scan depend on the first output of scan (the last carry value)

return wrapped_carry, (PytreeTrace({}), ys)
wrapped_carry, (pytree_trace, ys) = lax.scan(body_fn, wrapped_carry, xs_, length - 1, reverse)
y0s = []
for i in range(history):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For clarification, after changing this to for i in range(2 * history) what other changes would be necessary to get the correct final carry_shape?

# XXX: unless we unroll more steps, we only know the correct shapes starting from
# the `history`-th iteration; that means we only know carry shapes when
# i == history - 1 or i == history, which corresponds to input carry or output carry
# at the `history`-th iteration.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fehiepsi I think I understand this issue, thanks for explaining. How much compilation overhead does unrolling 2 * history steps add in the HMM examples? Unrolling history steps in exchange for incorrect output carry shapes is a very subtle optimization that only produces constant-time savings in program length, and may cause maintenance difficulties down the line, so it's important to be sure that it's worth it.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thinking about it again, I guess it won't take much compiling time because history is usually small (1 or 2). Let me address the issue in this PR.

@eb8680 eb8680 requested a review from neerajprad December 21, 2020 05:57
eb8680
eb8680 previously approved these changes Dec 21, 2020
Copy link
Member

@eb8680 eb8680 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM after latest changes.

@fehiepsi
Copy link
Member Author

Thanks for reviewing, @eb8680! Looking like there is a bug when I unrolled more steps. Looking into it...

# + msg: fn.batch_shape = (3,), value.shape = (2, 3) + fn.event_shape
# process_message(msg): promote fn so that fn.batch_shape = (1, 3).
def process_message(self, msg):
def postprocess_message(self, msg):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we only want to do this right before getting the trace.

trace = {}
else:
with handlers.block(), packed_trace() as trace, promote_shapes(), enum(), markov():
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The usage of enum(), markov() here still gives correct log density but does not recycle the dimensions at this step. To resolve this, I have moved block right before the generator markov(unroll_steps + 1, history) and remove markov handler here.

# Note that `funsor.sum_product.sarkka_bilmes_product` does support history > 1.
# amount number of steps to unroll
history = min(history, length)
unroll_steps = min(2 * history - 1, length)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the minimum number of unroll_steps to get a correct last carry shape. But it is safe to use any number greater than this number.

@fehiepsi
Copy link
Member Author

@eb8680 Though the previous implementation gives correct log density, the carry shapes are not recycled with

with markov(history=history):
    for i in range(unroll_steps):
        transition_fn()
    lax.scan(with markov(): transition_fn())

logic. I have revised it to be

for i in markov(range(unroll_steps + 1)):
    if i < unroll_steps:
        transition_fn()
    else:
        lax.scan(transition_fn())

Previously, this issue is not detected due to a typo in the test, where I computed expected and actual values on the same model:

    expected_x_prev, expected_x_curr = enum(config_enumerate(model))()
    actual_x_prev, actual_x_curr = enum(config_enumerate(model))()  # typo: should be fun_model

Copy link
Member

@eb8680 eb8680 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Now that sarkka_bilmes_product is actually being used, I guess we should probably try to move away from the P prefix naming scheme upstream in sarkka_bilmes_product to a safer interface so no non-Markov names from user code with P in them get incorrectly mangled. Maybe the simplest thing to do would be to make the prefixes longer (e.g. _PREV_ instead of P).

@fehiepsi
Copy link
Member Author

no non-Markov names from user code with P in them get incorrectly mangled

Good point! _PREV_ should work well here.

@fehiepsi fehiepsi merged commit 4ac5d6f into pyro-ppl:master Dec 23, 2020
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Support markov with history size greater than 1
2 participants