Add Precondition interpretation for Gaussian TVE #553

Merged
@eb8680 merged 82 commits into master from precondition
Nov 12, 2021

Conversation

fritzo
Member

@fritzo fritzo commented Sep 24, 2021

Addresses pyro-ppl/pyro#2813

This adds a Precondition interpretation for Gaussian factor graphs.

Similar to batched MonteCarlo(..., sample_inputs), where samples depend on a discrete sample index, Precondition() returns samples that depend on a single white noise input aux : Reals[total_num_elements]. Notably, this new interpretation is deterministic: the samples are lazy functions of aux, so nothing random is drawn until a value is substituted for aux. These lazy samples can then be used for sampling by substituting either white noise (in variational inference) or an HMC-controlled vector. Specifically, this can be used to implement Pyro's AutoGaussian.get_transform() for use in NeuTraReparam.
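
For readers unfamiliar with the idea, here is a minimal sketch in plain Python rather than funsor. It assumes a 2-dimensional Gaussian given by a mean and a precision matrix; `cholesky_2x2` and `precondition` are hypothetical helper names, not part of this PR. The returned map is a deterministic function of `aux`, and substituting standard normal noise for `aux` gives a draw whose covariance is the inverse of the precision.

```python
import math
import random


def cholesky_2x2(p):
    """Lower-triangular L with L @ L.T == p, for a 2x2 symmetric positive definite p."""
    l00 = math.sqrt(p[0][0])
    l10 = p[1][0] / l00
    l11 = math.sqrt(p[1][1] - l10 * l10)
    return [[l00, 0.0], [l10, l11]]


def precondition(mean, precision):
    """Return a deterministic map aux -> sample (hypothetical helper, not funsor API)."""
    L = cholesky_2x2(precision)

    def transform(aux):
        # Solve L.T @ z = aux by back substitution, so that
        # cov(z) = inv(L @ L.T) = inv(precision) when aux is white noise.
        z1 = aux[1] / L[1][1]
        z0 = (aux[0] - L[1][0] * z1) / L[0][0]
        return [mean[0] + z0, mean[1] + z1]

    return transform


transform = precondition([1.0, -1.0], [[2.0, 0.5], [0.5, 1.0]])

# Nothing random has happened yet; randomness enters only through aux.
rng = random.Random(0)
aux = [rng.gauss(0.0, 1.0) for _ in range(2)]
sample = transform(aux)
```

The same `transform` could equally be fed a vector proposed by an HMC sampler, which is the use case the description mentions for NeuTraReparam.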

Changes to Gaussian._sample() can be seen as a first step towards making Gaussian depend on funsor.Tensors rather than backend arrays, as suggested by @eb8680 in #556.

Remaining tasks

  • implement Gaussian - Gaussian
  • implement Subs._sample()
  • implement Gaussian._sample() for subsets of real variables
  • implement ops.cat
  • implement ops.randn
  • revert Gaussian - Gaussian which turns out to be incorrect

Tested

  • test recipes.py

@fritzo fritzo added the WIP label Sep 24, 2021
# TODO Replace this with root + Constant(...) after #548 merges.
root_vars = root.input_vars | batch_vars

def adjoint(self, sum_op, bin_op, root, targets=None, *, batch_vars=set()):
Member Author

Note that batch_vars can now change during the course of adjoint, e.g. in Precondition, where the aux vars aren't known until each Approximate term is hit.

@fritzo fritzo mentioned this pull request Oct 19, 2021
1 task
@fritzo fritzo closed this in #577 Oct 19, 2021
@fritzo fritzo reopened this Oct 19, 2021
@fritzo fritzo marked this pull request as ready for review October 19, 2021 23:57
@fritzo
Member Author

fritzo commented Oct 27, 2021

@eb8680 let me know if you want a Zoom tour.

Comment on lines +149 to +154
samples = {k: v(**subs) for k, v in samples.items()}

# Compute log density at each sample, lazily dependent on aux_name.
log_prob = -log_Z
for f in factors.values():
    term = f(**samples)
Member Author

It looks like these substitutions are triggering expensive materialize() operations due to Gaussian.eager_subs().

Member

@eb8680 eb8680 left a comment

Sampling math looks right, just some clarifying questions.

@@ -883,6 +883,19 @@ def eager_finitary_stack(op, parts):
    return Tensor(raw_result, inputs, parts[0].dtype)


@eager.register(Finitary, ops.CatOp, typing.Tuple[Tensor, ...])
def eager_finitary_cat(op, parts):
Member

Out of curiosity, do relevant tests still pass if you replace this pattern with

eager.register(Finitary, ops.CatOp, typing.Tuple[Tensor, ...])(
    funsor.op_factory.eager_tensor_made_op)

which uses the generic tensor op evaluation pattern in https://github.com/pyro-ppl/funsor/blob/master/funsor/op_factory.py#L19

Member Author

I don't recall, but my guess is that since ops.cat() doesn't broadcast, we needed to explicitly set expand=True when calling align_tensors().
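
To illustrate why a non-broadcasting concatenation needs its inputs expanded first, here is a small sketch that uses nested Python lists as a stand-in for batched arrays. It does not call funsor; `expand_batch` and `cat_last` are hypothetical names, and the explicit expansion step only plays the role that expand=True is guessed to play above.

```python
def expand_batch(rows, batch_size):
    """Repeat a size-1 batch up to batch_size; leave a full-size batch as is."""
    if len(rows) == batch_size:
        return rows
    if len(rows) == 1:
        return [list(rows[0]) for _ in range(batch_size)]
    raise ValueError(f"cannot expand batch of size {len(rows)} to {batch_size}")


def cat_last(parts):
    """Concatenate batched rows along the last axis after expanding batches."""
    batch_size = max(len(p) for p in parts)
    expanded = [expand_batch(p, batch_size) for p in parts]
    result = []
    for i in range(batch_size):
        row = []
        for p in expanded:
            row.extend(p[i])
        result.append(row)
    return result


# The first part has batch size 1 and the second has batch size 2.
# Without the expansion step the rows could not be paired up.
combined = cat_last([[[1, 2]], [[3], [4]]])
```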

samples = g2.sample(all_vars, sample_inputs, rng_keys[2])
actual_mean, actual_cov = compute_moments(samples)
assert_close(actual_mean, expected_mean, atol=1e-1, rtol=1e-1)
assert_close(actual_cov, expected_cov, atol=1e-1, rtol=1e-1)
Member

Are these tests strong enough to catch bugs with these tolerances? We could also test these computations exactly by preconditioning with shared noise, right?
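
One way such a check could be made exact, sketched in plain Python under the assumption that the preconditioned sample is an affine function of aux: push zero and the unit vectors through the transform to recover the mean and covariance with no Monte Carlo error. This is a related technique rather than necessarily the shared-noise test meant above, and `make_transform` and `exact_moments` are hypothetical helpers, not funsor functions.

```python
import math


def make_transform(mean, precision):
    """Deterministic map aux -> sample for a 2x2 precision (illustrative only)."""
    l00 = math.sqrt(precision[0][0])
    l10 = precision[1][0] / l00
    l11 = math.sqrt(precision[1][1] - l10 * l10)

    def transform(aux):
        # Back substitution for L.T @ z = aux, with L the Cholesky factor.
        z1 = aux[1] / l11
        z0 = (aux[0] - l10 * z1) / l00
        return [mean[0] + z0, mean[1] + z1]

    return transform


def exact_moments(transform, dim):
    """Recover mean and covariance of an affine transform of white noise."""
    mean = transform([0.0] * dim)
    columns = []
    for i in range(dim):
        unit = [0.0] * dim
        unit[i] = 1.0
        x = transform(unit)
        columns.append([x[j] - mean[j] for j in range(dim)])
    # cov = A @ A.T, where columns[i] is the i-th column of A.
    cov = [
        [sum(col[r] * col[c] for col in columns) for c in range(dim)]
        for r in range(dim)
    ]
    return mean, cov


precision = [[2.0, 0.5], [0.5, 1.0]]
mean, cov = exact_moments(make_transform([1.0, -1.0], precision), 2)

# Closed-form inverse of the 2x2 precision, for comparison.
det = precision[0][0] * precision[1][1] - precision[0][1] * precision[1][0]
expected_cov = [
    [precision[1][1] / det, -precision[0][1] / det],
    [-precision[1][0] / det, precision[0][0] / det],
]
```

With this kind of check the comparison tolerance can be near machine precision instead of 1e-1.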

Member Author

These did indeed catch some bugs, but I'd like to eventually switch to double precision in tests (#573).

test/test_gaussian.py (outdated, resolved)
# Substitute noise for the aux value, as would happen each SVI step.
aux_numel = log_prob.inputs["aux"].num_elements
noise = Tensor(randn(num_samples, aux_numel))["particle"]
with memoize():
Member

What is this memoize for?

Member Author

This is the natural evaluation pattern for dags. Since downstream user code will use this pattern, I'd prefer to use this pattern in tests.
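
As a generic illustration of why memoization suits DAG evaluation, here is a toy evaluator in plain Python. It is not funsor's memoize() implementation; `Node` and `evaluate` are hypothetical names. A subexpression shared by several parents is computed once when a cache is supplied and once per parent otherwise.

```python
import operator


class Node:
    """A tiny expression node: an operation applied to child nodes or constants."""

    def __init__(self, op, *args):
        self.op = op
        self.args = args


def evaluate(node, cache, counter):
    """Evaluate a DAG; cache=None disables memoization. counter[0] counts op calls."""
    if not isinstance(node, Node):
        return node  # constants evaluate to themselves
    if cache is not None and id(node) in cache:
        return cache[id(node)]
    args = [evaluate(arg, cache, counter) for arg in node.args]
    counter[0] += 1
    value = node.op(*args)
    if cache is not None:
        cache[id(node)] = value
    return value


shared = Node(operator.add, 1, 2)         # used twice below
root = Node(operator.mul, shared, shared)

with_cache = [0]
memoized_value = evaluate(root, {}, with_cache)

without_cache = [0]
plain_value = evaluate(root, None, without_cache)
```

Both calls return the same value; only the number of operations performed differs, and the gap grows with the amount of sharing in the graph.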

shape += tuple(d.size for d in int_inputs.values())
shape += (dim,)
assert ops.is_numeric_array(prototype)
return ops.randn(prototype, shape, rng_key)
Member

What is the motivation for _sample_white_noise not returning a Funsor in all cases? It seems like making it consistent would slightly reduce the mental overhead of understanding our fairly complicated sample code.

Member Author

It just seemed simpler this way.

prec_sqrt = prec_sqrt_b - prec_sqrt_b @ proj_a
white_vec = self.white_vec - _vm(self.white_vec, proj_a)
result += Gaussian(white_vec, prec_sqrt, inputs)
else: # The Gaussian over xa is zero.
Member

Which Gaussian tests exercise this case? I want to verify that the size-0 trick used here works on all backends and doesn't introduce unexpected behavior in downstream code.

Member Author

I don't remember, maybe in one of the GaussianHMM tests?

@fritzo fritzo closed this Nov 12, 2021
@fritzo fritzo reopened this Nov 12, 2021
@eb8680 eb8680 merged commit 93250f9 into master Nov 12, 2021
@eb8680 eb8680 deleted the precondition branch November 12, 2021 22:44