
Restrict BroadcastTo lifting of RandomVariables #71

Closed

Conversation

brandonwillard (Member) commented Oct 12, 2021

This PR restricts the cases in which BroadcastTo Ops will be lifted through RandomVariable Ops via naive_bcast_rv_lift.

Simply put, an expression like at.broadcast_to(at.random.normal(0, 1), (10,)) should not be lifted, since it would result in ten independent random variates, instead of a single variate that's broadcasted to the shape (10,). Lifting should only happen when the shape of the RandomVariable already matches the broadcasted shape (e.g. via size and/or one of the distribution parameters), or when the broadcasting only introduces broadcastable dimensions (i.e. adds extra dimensions of length one).
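The distinction can be illustrated outside of Aesara with plain NumPy (a sketch, not aeppl code): broadcasting a single draw repeats one variate, while drawing with a size argument yields independent variates.

```python
import numpy as np

rng = np.random.default_rng(0)

# One scalar variate repeated to shape (10,) -- analogous to
# at.broadcast_to(at.random.normal(0, 1), (10,))
repeated = np.broadcast_to(rng.normal(0.0, 1.0), (10,))

# Ten independent variates -- analogous to at.random.normal(0, 1, size=(10,))
independent = rng.normal(0.0, 1.0, size=(10,))

# The broadcasted array holds a single distinct value; the other (almost surely) doesn't
assert np.all(repeated == repeated[0])
assert len(np.unique(independent)) > 1
```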

This PR is—in part—an answer to an issue mentioned in Part 2 of this comment: #51 (reply in thread).

Currently, this draft PR only introduces some tests for naive_bcast_rv_lift that state what needs to be changed/implemented.

  • Actually put the lifting restriction in place.

@brandonwillard brandonwillard marked this pull request as draft October 12, 2021 05:40
@brandonwillard brandonwillard added bug Something isn't working important This label is used to indicate priority over things not given this label graph rewriting Involves the implementation of rewrites to Aesara graphs labels Oct 12, 2021
brandonwillard (Member, Author) commented Oct 16, 2021

While looking into this, I realized that we can keep using this naive broadcast lifting as we always have, but we should replace the underlying RandomVariables with a duplicate type that doesn't allow sampling and doesn't have an associated RNG object.

The idea is that we're only using these RandomVariables for log-probability derivations, so, while at.broadcast_to(at.random.normal(0, 1), (10,)) isn't equivalent to at.random.normal(at.broadcast_to(0, (10,)), at.broadcast_to(1, (10,))) or at.random.normal(0, 1, size=(10,)), it is equivalent to those latter forms when the scalar value variable is also broadcasted and only a log-probability is constructed from those terms.

In other words, we have equivalence "modulo" a log-probability function (and an accompanying value variable transform).
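A NumPy/SciPy sketch of this equivalence (illustrative only; aeppl operates on Aesara graphs): broadcasting the scalar value variable and evaluating the log-density elementwise gives the same array as broadcasting the scalar log-density itself.

```python
import numpy as np
from scipy import stats

y = 0.7  # a scalar value variable for Y ~ Normal(0, 1)

# Log-probability of the broadcasted value variable...
lp_broadcast_value = stats.norm(0, 1).logpdf(np.broadcast_to(y, (10,)))

# ...equals the broadcasted scalar log-probability
lp_broadcast_logp = np.broadcast_to(stats.norm(0, 1).logpdf(y), (10,))

assert np.allclose(lp_broadcast_value, lp_broadcast_logp)
```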

To be responsible about such a rewrite, we shouldn't return RandomVariables, though. Also, by stripping the RNG objects from RandomVariables we might also be able to get rid of some of those annoying and unused shared variables (e.g. the ones introduced by Scan).

Actually, it might make sense to include the value variables as inputs to this new RandomVariable-like type. For example, we would replace the RandomVariable generated by at.random.normal(0, 1) and its value variable y_vv with ValuedRV(NormalRV)(y_vv, 0, 1). This could be performed as a first pass through a graph and serve as a more natural replacement for PreserveRVMappings.
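A minimal, purely illustrative sketch of such a pairing (the name ValuedRV comes from the comment above; nothing here is aeppl's actual API, and a frozen scipy distribution stands in for a RandomVariable Op with no RNG attached):

```python
from dataclasses import dataclass

import numpy as np
from scipy import stats


@dataclass(frozen=True)
class ValuedRV:
    """Pairs a sampling-free distribution 'op' with its value variable."""

    dist: object        # stand-in for, e.g., NormalRV (no RNG attached)
    value: np.ndarray   # the value variable, e.g. y_vv

    def logprob(self) -> np.ndarray:
        # Only log-probability derivations are supported; no sampling
        return self.dist.logpdf(self.value)


# Stand-in for replacing at.random.normal(0, 1) and y_vv with ValuedRV(NormalRV)(y_vv, 0, 1)
node = ValuedRV(stats.norm(0, 1), np.array(0.5))
```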

This idea is very similar to the old ObservedRV approach; however, we might be able to avoid some of the major problems ObservedRV had—like its inability to re-use existing rewrites.

cc: @ricardoV94 @kc611

ricardoV94 (Contributor) commented Oct 16, 2021

> it is equivalent to those latter forms when the scalar value variable is also broadcasted and only a log-probability is constructed from those terms.

It is equivalent to the final logprob returned by aeppl, but that final logprob is not consistent with the original graph. It would double count terms.

Maybe we need an explicit MeasurableBroadcast that undoes the broadcast of the value variable / variable parameters before calling the base logprob?

rv = at.broadcast_to(at.random.normal(0, 1), (10,))
vv = rv.clone()

# rv logp should be
at.switch(
    # check that all values are equal (elementwise at.eq, not Python ==)
    at.all(at.eq(vv, vv[0])),
    normal_logprob(vv[0], 0, 1),
    -np.inf,
)

But that looks really awkward.
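The degenerate density sketched above can be written out with NumPy/SciPy (a hypothetical helper, not an aeppl rewrite): the log-probability is that of a single normal draw when all entries agree, and -inf otherwise.

```python
import numpy as np
from scipy import stats


def broadcast_normal_logp(vv):
    """Hypothetical logp of broadcast_to(normal(0, 1), vv.shape)."""
    all_equal = np.all(vv == vv[0])
    return np.where(all_equal, stats.norm(0, 1).logpdf(vv[0]), -np.inf)


consistent = np.full(10, 0.3)   # a value a sampler could, in principle, propose
inconsistent = np.arange(10.0)  # unequal entries: zero probability

assert np.isfinite(broadcast_normal_logp(consistent))
assert broadcast_normal_logp(inconsistent) == -np.inf
```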

Also there is probably no sampler that could ever propose valid values, so we might be better off just rejecting them altogether.

brandonwillard (Member, Author)

> It is equivalent to the final logprob returned by aeppl, but that final logprob is not consistent with the original graph. It would double count terms.

Under our log-probability mapping, a graph representing an array containing a broadcasted random variable would map to an array of said random variable's scalar log-probability broadcasted, so we want the result to have multiple terms (i.e. be an array).

Are you referring to another issue?

Broadly, I'm talking about the implementation of a general log-probability mapping for graphs representing broadcasted random variables. More specifically, the topic is the relevant parsing and representation of these domain elements.

kc611 (Collaborator) commented Oct 16, 2021

> Actually, it might make sense to include the value variables as inputs to this new RandomVariable-like type. For example, we would replace the RandomVariable generated by at.random.normal(0, 1) and its value variable y_vv with ValuedRV(NormalRV)(y_vv, 0, 1). This could be performed as a first pass through a graph and serve as a more natural replacement for PreserveRVMappings.

+1 for this idea (or at least the general direction of it). Always having a value variable attached to this newly built RV derivative will ensure that such a replacement only takes place when the RV is used for log-likelihood graph generation and not for sampling later on. Here I'm assuming that this should work for values provided by the users rather than sampled from the RV itself; if it's the latter, there might be some inconsistencies depending upon how we handle them.

> It is equivalent to the final logprob returned by aeppl, but that final logprob is not consistent with the original graph. It would double count terms.

I think what @ricardoV94 is referring to here is the same case, i.e. the broadcasting of the value variable itself in cases when the values aren't explicitly provided but are sampled randomly from the RV.

ricardoV94 (Contributor) commented Oct 16, 2021

> Under our log-probability mapping, a graph representing an array containing a broadcasted random variable would map to an array of said random variable's scalar log-probability broadcasted, so we want the result to have multiple terms (i.e. be an array).

If I understand you correctly it means we will treat the two RVs as the same:

rv1 = at.broadcast_to(at.random.normal(0, 1), (10,))
rv2 = at.random.normal(0, 1, size=10)

vv1 = rv1.clone()
vv2 = rv2.clone()

That simplifies things quite a lot for other derived RVs, but we should decide that explicitly.

It means we'd use the original Aesara graph as something that loosely defines a log-probability graph, not as a true generative graph that we carefully invert to obtain the corresponding probability graph.


And then your point about the ValuedRVs is just to avoid distorting RVs that stay in the final graph?

brandonwillard (Member, Author) commented Oct 16, 2021

> If I understand you correctly it means we will treat the two RVs as the same:

The actual equivalence is

logprob(broadcast_to(Y, s), broadcast_to(y, s)) == broadcast_to(logprob(Y, y), s)

for y ~ Y (i.e. y is a value variable of the measurable/random variable Y).

I'm talking about implementing that equivalence relation above.

In practice, we start with the expression broadcast_to(Y, s) and the pair (Y, y), and we need to produce the left-hand side of the relation above so that we can use the right-hand side and derive the result in terms of logprob(Y, y).

What I'm saying is that we can use an intermediate representation in the space of (Y, y) pairs, so that broadcast_to((Y, y), s) == (broadcast_to(Y, s), broadcast_to(y, s)), where == is an equivalence in logprob. The original equivalence is simply restated under this representation, but it's easier to translate the existing rewrite into something that operates on (Y, y), and it fits better into our parsing strategy.

More specifically, I'm talking about expanding (Y, y) into ((Y_op, Y_rng), y) and actually working in (Y_op, y) (i.e. the equivalence holds under the/a projection).

My second point is about modeling (Y, y) and/or (Y_op, y) pairs explicitly in the graph via an Op.

ricardoV94 (Contributor)

> logprob(broadcast_to(Y, s), broadcast_to(y, s)) == broadcast_to(logprob(Y, y), s)

I don't see what that achieves just yet. Seems like that already happens by the automatic broadcasting of the logprob terms when y is a vector of shape s.

My question was about the case where y.shape == broadcast_to(Y, s).shape. In that case we might not want to broadcast the logprob term (depending on how we interpret the role of aeppl).

I am not entirely sure if we are talking about the same thing though.


I have no idea about the explicit inclusion of value variables in the graph. I say we give it a try and see if it makes our lives easier.

brandonwillard (Member, Author) commented Oct 16, 2021

> I don't see what that achieves just yet.

Aside from clarifying/understanding what it is we're actually doing and what we intend to do, using the types that are true to the system(s) they model avoids confusion, design quagmires, over-engineering, etc. In other words, it makes it possible to continually simplify a process instead of complicating it over time (e.g. through modularity and the like).

For instance, if we sort these things out at a high level like we are, we can probably find a redesign that recasts all the operations in terms of simple(r) local and global rewrites that lie entirely within the Aesara [Global|Local]Optimization framework. Then, by making independent improvements to that framework, we can automatically get improvements in our log-probability rewrite capabilities and performance.

At a lower level, using the correct types generally implies that type-related implementation details are addressed, such as type/instance equivalence (e.g. __eq__) and inheritance (e.g. isinstance, issubclass), which can functionally resolve confusions and help guard against conceptual mistakes. This is directly relevant to the topic of this PR, because, if naive_bcast_rv_lift operated on objects of a model-theoretically correct type, there would be little room to make the mistake of applying the rewrite to RandomVariables and producing inconsistent graphs.

> Seems like that already happens by the automatic broadcasting of the logprob terms when y is a vector of shape s.

Yes, it should, and that's part of the reason why a change to the basic rewrite logic in naive_bcast_rv_lift isn't really a solution to the issue/concerns of this PR.

> My question was about the case where y.shape == broadcast_to(Y, s).shape. In that case we might not want to broadcast the logprob term (depending on how we interpret the role of aeppl).

The premise is that we want the broadcasting, because those are the elements we're handling with the naive_bcast_rv_lift rewrite. There should never be a case in which naive_bcast_rv_lift encounters a graph that it can, but shouldn't, rewrite. That would imply there's an issue with naive_bcast_rv_lift's use, not with its implementation.

brandonwillard (Member, Author) commented Oct 16, 2021

Just to follow this chain of thought a little more, if we use the old ObservedRV approach (i.e. an Op that takes the MeasurableVariable and value variable as inputs), we could replace the value variable logic in factorized_joint_logprob with a single rewrite that replaces (assumedly) measurable terms with ObservedRVs.

Most of our custom rewrites could be easily rewritten to handle the extra ObservedRV node between the graphs they presently rewrite; however, some of the existing RandomVariable rewrites (e.g. local_dimshuffle_rv_lift, local_subtensor_rv_lift) would no longer apply, again due to that extra ObservedRV node.

For example, the *Subtensor* lifting for RandomVariables wouldn't work once we replaced the RandomVariables with ObservedRVs, because any Y_rv[idx] graphs would be rewritten to observed(Y_rv, y_vv)[idx] and a rewrite would need to be created for that specific form as well.

Fortunately, with some simple stand-alone ObservedRV canonicalizations, we can probably bridge that gap quite easily. Using the previous example, we could lift *Subtensor* operations through ObservedRVs, and, after such ObservedRV-specific canonicalizations are applied, the more generic RandomVariable canonicalizations/rewrites can be applied.

We could also replace rvs_to_value_vars with a simple rewrite that substitutes ObservedRVs with their value variable inputs.
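Both steps can be sketched with a toy tuple-based graph representation (entirely illustrative; Aesara's rewrite machinery works on Apply nodes, not tuples): first lift a subtensor through a hypothetical observed node, then substitute observed pairs by their value variables.

```python
# Toy expression graphs: ("observed", rv, vv), ("subtensor", x, idx), or a leaf string
def lift_subtensor(expr):
    """Canonicalize subtensor(observed(rv, vv), idx) -> observed(rv[idx], vv[idx])."""
    if expr[0] == "subtensor" and expr[1][0] == "observed":
        _, (_, rv, vv), idx = expr
        return ("observed", ("subtensor", rv, idx), ("subtensor", vv, idx))
    return expr


def rvs_to_value_vars(expr):
    """Replace observed(rv, vv) nodes by their value variable vv."""
    if isinstance(expr, tuple):
        if expr[0] == "observed":
            return rvs_to_value_vars(expr[2])
        return tuple(rvs_to_value_vars(e) for e in expr)
    return expr


g = ("subtensor", ("observed", "Y_rv", "y_vv"), "idx")
lifted = lift_subtensor(g)
# -> ("observed", ("subtensor", "Y_rv", "idx"), ("subtensor", "y_vv", "idx"))
value_graph = rvs_to_value_vars(lifted)
# -> ("subtensor", "y_vv", "idx")
```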

ricardoV94 (Contributor)

Let's give it a try and see how it evolves?

brandonwillard (Member, Author)

> Let's give it a try and see how it evolves?

I'll repurpose this PR to that effect.

@brandonwillard brandonwillard mentioned this pull request Oct 23, 2021
@brandonwillard brandonwillard deleted the add-naive_bcast-tests branch October 23, 2021 23:36