Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use ChangesOfVariables and InverseFunctions #212

Merged
merged 29 commits into from
Dec 15, 2021

Conversation

oschulz
Copy link
Collaborator

@oschulz oschulz commented Dec 10, 2021

This PR implements the changes discussed in #199, by adding support for JuliaMath/ChangesOfVariables and JuliaMath/InverseFunctions. Both are very lightweight, low-dependency, low-bias packages designed to enable composability of packages that provide/implement or use variable transformation capabilities.

Specifically, this PR adds support for ChangesOfVariables.with_logabsdet_jacobian(::AbstractBijector, ::Any) (a direct equivalent of - and indeed modeled after - the current Bijectors.forward(::AbstractBijector, ::Any)) and InverseFunctions.inverse (a direct equivalent of the current Base.inv(::AbstractBijector)).

The following registered packages directly depend on Bijectors, currently: DifferentialEvolutionMCMC DynamicPPL Turing TuringModels Transits ParameterHandling MeasureTheory AdvancedVI AIBECS Soss

None of those dependent packages define subtypes of AbstractBijector, specialize Bijectors.forward or seem to specialize Base.inv (hope I didn't overlook any). So it seems resonable to deprecate Bijectors.forward(::AbstractBijector, ::Any) and Base.inv(::AbstractBijector) directly and replace all use of them inside of Bijectors.jl with ChangesOfVariables.with_logabsdet_jacobian and InverseFunctions.inverse.

The return type of with_logabsdet_jacobian is slightly different from forward though, it returns a Tuple instead of a NamedTuple{(:rv, :logabsdetjac)}. It seems that the only package that uses these fields is MeasureTheory, in a single place (not at all anymore on the current master branch, it seems). It's handled in the deprecation of forward in this PR.

Closes #199.

CC @torfjelde, @devmotion, @willtebbutt, @cscherrer

@oschulz oschulz marked this pull request as draft December 10, 2021 22:12
@yebai
Copy link
Member

yebai commented Dec 10, 2021

Cc @phipsgabler

@devmotion
Copy link
Member

@oschulz it seems you forgot to push some changes?

@oschulz
Copy link
Collaborator Author

oschulz commented Dec 10, 2021

@oschulz it seems you forgot to push some changes?

Was still working on them. :-) Changes are pushed now.

@oschulz
Copy link
Collaborator Author

oschulz commented Dec 10, 2021

Don't run workflow yet, still fixing tests locally.

Copy link
Member

@devmotion devmotion left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't check the PR in detail since it's marked as WIP. However, I noticed that the deprecation of forward seems to be incorrect since it introduces a breaking change. IMO it would be nice if it is possible to not break the external API in this PR:

  • It would make the PR much simpler since it would not have to update the tests (which also more clearly indicates if/how breaking the PR is)
  • It would not break downstream packages

src/Bijectors.jl Outdated Show resolved Hide resolved
src/Bijectors.jl Outdated Show resolved Hide resolved
@oschulz
Copy link
Collaborator Author

oschulz commented Dec 10, 2021

Ready for review and CI (could you trigger the workflow, @devmotion ?)

I have one local test failure with Julia v1.7 in "test/transform.jl:151" (section with a comment "This should fail at the minute") but I get the same test failure with the current master branch, so it seems unrelated.

@oschulz oschulz marked this pull request as ready for review December 10, 2021 23:18
@oschulz oschulz changed the title [WIP] Use ChangesOfVariables and InverseFunctions Use ChangesOfVariables and InverseFunctions Dec 10, 2021
README.md Outdated Show resolved Hide resolved
README.md Outdated Show resolved Hide resolved
README.md Outdated Show resolved Hide resolved
README.md Outdated Show resolved Hide resolved
src/Bijectors.jl Show resolved Hide resolved
src/transformed_distribution.jl Outdated Show resolved Hide resolved
src/transformed_distribution.jl Outdated Show resolved Hide resolved
src/transformed_distribution.jl Outdated Show resolved Hide resolved
src/transformed_distribution.jl Outdated Show resolved Hide resolved
src/transformed_distribution.jl Outdated Show resolved Hide resolved
oschulz and others added 2 commits December 11, 2021 01:42
Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
@oschulz
Copy link
Collaborator Author

oschulz commented Dec 11, 2021

Sorry I missed to many forwards initially, @devmotion !

I fixed a few things, should be ready for another CI run now.

@oschulz
Copy link
Collaborator Author

oschulz commented Dec 11, 2021

Should we also deprecate logabsdetjac? Currently, the default implementation of with_logabsdet_jacobian falls back on logabsdetjac, while ChangesOfVariables has with_logabsdet_jacobian as the primary function. JuliaMath/ChangesOfVariables.jl#3 (still undecided) would add a logabsdet_jacobian as an analog of logabsdetjac - but even if added, things would work the other way round, logabsdet_jacobian would fall back on with_logabsdet_jacobian.

The only package that currently seems add methods to logabsdetjac is Transits.jl, in a single place, to define the LADJ of Kipping13Transform. That could easily be changed to with_logabsdet_jacobian, this would also add support for ChangesOfVariables.jl to Transits.jl. @mileslucas, would that be Ok from your side? Soss.jl also defines a method of with_logabsdet_jacobian in a single place in the current release, but that seems gone on the master branch (@cscherrer?).

I'm not sure we can cleanly support both ways in Bijectors.jl (users defining either with_logabsdet_jacobian or logabsdetjac, and the other function then using a default method it if not specialized as well), at least not without ugly trickery. Removing the "primiary" status from logabsdetjac would be breaking - but until we do, users can't code a Bijector the "ChangesOfVariables way", I think.

@devmotion
Copy link
Member

Maybe leave this for a separate PR as it seems to be a more fundamental change of Bijectors and, I assume, has to be benchmarked carefully? Maybe it would also benefit from an upstream definition of logabsdet_jacobian which seems to be another reason not to rush it.

In principle, however, I think one could use something like

logabsdetjac(b::AbstractBijector, x) = last(with_logabsdet_jacobian(b, x))
with_logabsdet_jacobian(b::AbstractBijector, x) = (b(x), logabsdetjac(b, x))

(potentially with some deprecations) to support implementations in both Bijectors and ChangesOfVariables style. But I guess it would be cleaner to do in a separate PR.

@oschulz
Copy link
Collaborator Author

oschulz commented Dec 11, 2021

In principle, however, I think one could use something like

logabsdetjac(b::AbstractBijector, x) = last(with_logabsdet_jacobian(b, x))
with_logabsdet_jacobian(b::AbstractBijector, x) = (b(x), logabsdetjac(b, x))

Won't that result in a stack overflow if neither is defined?

Maybe leave this for a separate PR

Sound good - I think @torfjelde was planning a deeper overhaul anyway?

@devmotion
Copy link
Member

Won't that result in a stack overflow if neither is defined?

Yes, it does. But I assumed this would be fine - for a user or developer it's an indication that you should define (at least) one of these methods.

@cscherrer
Copy link

Soss.jl also defines a method of with_logabsdet_jacobian in a single place in the current release, but that seems gone on the master branch (@cscherrer?).

IIRC this had been an optional dependency, but it wasn't working out because of exported Distributions, and it didn't allow manifolds to be represented an embeddings from lower-dimensional spaces. I think I'll need to wait for #183 before I can use it.

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
@oschulz
Copy link
Collaborator Author

oschulz commented Dec 12, 2021

fix the deprecation warnings that show up in the tests

On it.

@torfjelde
Copy link
Member

Sorry for being a bit awol; past week has been busy, preparing to go home for Christmas.
But just had a quick look and this looks great! Thank you @oschulz !

It seems like @devmotion has already done a proper review of this, so tbh I don't have any comments (beyond his latest on deprecation tests and bumping the verison number) 👍

So feel free to approve it when you're happy @devmotion . I'll be travelling tomorrow, so won't be able to have a look at this again until Tues at the earliest.

@oschulz
Copy link
Collaborator Author

oschulz commented Dec 12, 2021

But just had a quick look and this looks great! Thank you @oschulz !

Thanks! It did get a log bigger than I had expected, initially. :-)

In the meantime, can you update the version number and fix the deprecation warnings

Version number is up and I think I finally eliminated the last remaining deprecation warnings. Let's see if the tests go through clean this time.

@oschulz
Copy link
Collaborator Author

oschulz commented Dec 12, 2021

Ok, looks clean.

Copy link
Member

@devmotion devmotion left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great, looks very good! Just one minor question: I think we should re-export inverse and with_logabsdet_jacobian since they are part of the API now - or is there a specific reason for not exporting them?

test/runtests.jl Outdated Show resolved Hide resolved
@oschulz
Copy link
Collaborator Author

oschulz commented Dec 13, 2021

Ok, inverse and with_logabsdet_jacobian are re-exported.

Project.toml Outdated
@@ -1,13 +1,15 @@
name = "Bijectors"
uuid = "76274a88-744f-5084-9051-94815aaf08c4"
version = "0.9.11"
version = "0.9.12"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should probably be 0.10.0, given the magnitude of changes (and that inv is now inverse, among other things).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't have a strong opinion in it - @devmotion you did consider it non-breaking, right?

Copy link
Member

@devmotion devmotion Dec 13, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought about it again, and I still think it is non-breaking if we add fallback definitions for inverse and with_logabsdet_jacobian:

function inverse(b::AbstractBijector)
    Base.depwarn("`inv(b::AbstractBijector)` is deprecated, please use `inverse(b)`", :inverse)
    return inv(b)
end
function with_logabsdet_jacobian(b::AbstractBijector, x)
    Base.depwarn(
        "`forward(b::AbstractBijector, x)` is deprecated, please use `with_logabsdet_jacobian(b, x)`", 
        :with_logabsdet_jacobian,
    )
    return forward(b, x)
end

This is the only breaking change I can imagine with this PR: If a function that operates with bijectors is defined with the new API (maybe even in Bijectors) but the bijector at hand only implements the old API. This can lead to a StackOverflow error - but only if for a bijector neither the old nor the new API is implemented, and hence the implementation is broken anyway.

Otherwise, forward and inv are deprecated and the other changes are merely replacements in the code and tests (to fix deprecation warnings). So even though the PR is quite large the changes itself seem small and well defined.

@oschulz can you add the fallback definitions, and ideally also test them (e.g. with a dummy bijector that only implements the old API)? Then I am convinced that the PR is non-breaking.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@oschulz can you add the fallback definitions, and ideally also test them

Yes, will do.

Copy link
Collaborator Author

@oschulz oschulz Dec 14, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@oschulz can you add the fallback definitions, and ideally also test them

I think I found a way to do that and defend against the stack overflow, so we can return a meaningful error if neither forward or with_logabsdet_jacobian is defined, by using a wrapper bijector. The same mechanism can also be used to allow defining Bijectors via with_logabsdet_jacobian without defining logabsdetjac.

@devmotion, I think we we implement JuliaMath/ChangesOfVariables.jl#3 we could then immediately deprecate logabsdetjac as well and still keep this non-breaking, using the same wrapper trick. Let me try something ...

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should I remove the export of inverse and with_logabsdet_jacobian?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think reexporting makes it easy to miss to which package the function actually belongs. Generally, I started to think one should be a bit more careful when it comes to reexporting since it means any breaking change of the upstream definitions seems to require a breaking release in the downstream package as well.

On the other hand, it might seem a bit strange to not export them if they are part of the API 🤷‍♂️

What do you think @torfjelde?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe we should do a release that's only technically breaking (the dependent package use inv and forward in a few places, but don't specialize them at all) here, and then remove the deprecations later on as part of #183? That way, the dependent package could switch from the using old to using the new API in the mean time.

I'm also in favour of this: sounds good 👍

Should I remove the export of inverse and with_logabsdet_jacobian?

Personally, I'm in favour of exporting. It's very rare someone does using Bijectors without the intention of also using inverse and/or with_logabsdet_jacobian since implementations of these is essentially the point of Bijectors.jl, hence it seems a bit weird to me if they then need to qualify the usages of these functions 🤷

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, so we keep the export? If so, this PR should be good to go from my side.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, let's keep it 👍

Copy link
Member

@devmotion devmotion left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great, thanks a lot @oschulz!

@devmotion devmotion merged commit b204712 into TuringLang:master Dec 15, 2021
@oschulz
Copy link
Collaborator Author

oschulz commented Dec 15, 2021

Thanks for all the comments and suggestions!

@torfjelde
Copy link
Member

Indeed, thank you so much @oschulz ! Great stuff:)

@oschulz
Copy link
Collaborator Author

oschulz commented Dec 15, 2021

@devmotion if it's fine with you in principle, I'd draft a PR for distributions to "lift" TransformedDistribution from Bijectors into Distributions. I think it will be a very valuable tool to have, and with InverseFunctions and ChangesOf Variables (and support for them in Bijectors) in place we can make it available without a dependency on Bijectors (I think, haven't drafted the code yet).

@devmotion
Copy link
Member

Sounds reasonable - currently we handle only the special case of affine transformations (limited but hopefully soon in a bit more general way) so I think it would be a valuable addition.

@oschulz
Copy link
Collaborator Author

oschulz commented Dec 15, 2021

Sounds reasonable - currently we handle only the special case of affine transformations

I do have an idea how we can (hopefully) support arbitrary variate types (I definitely want ValueShapes.NamedTupleDist to work with this). I will probably need to include defining Random.gentype for distributions (I think that was already under discussion?) and I may need to at least define new VariateForm for struct types. @devmotion if that's fine with you at least in principle, I would make a concrete draft PR as a basis for discussing details.

@devmotion
Copy link
Member

It would be nice to keep changes as minimal as possible (but e.g. definitions so general that they allow such use cases without breaking changes later on), such that the PR does not become too large and it is less likely that discussions diverge and/or not focus on the main changes.

@devmotion
Copy link
Member

In particular the Random stuff is a very sensitive area where people tend to have strong opinions 😄

@oschulz
Copy link
Collaborator Author

oschulz commented Dec 15, 2021

Understood - I'll try to make something compact.

@devmotion
Copy link
Member

BTW the current design proposal for eltype (but on purpose not including gentype etc.) is JuliaStats/Distributions.jl#1433. But maybe it is sufficient to just call regular rand of the untransformed distribution for sampling?

@oschulz
Copy link
Collaborator Author

oschulz commented Dec 15, 2021

But maybe it is sufficient to just call regular rand of the untransformed distribution for sampling

Yes. The tricky bit will be inferring the VariateForm of the transformed distribution. Base._return_type(f, (Random.gentype(orig_dist),)) will allow us to infer that in many cases, falling back on running the trafo on a single rand value and determining the VariateForm from the result (in the spirit of what Broadcast does to infer the return type and what it does when it can't).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Support InverseFunctions.jl and ChangesOfVariables.jl
6 participants