Use ChangesOfVariables and InverseFunctions #212

Merged
merged 29 commits into from
Dec 15, 2021
Changes from 27 commits
Commits
31adbe8
Add ChangesOfVariables and InverseFunctions to deps
oschulz Dec 10, 2021
000c83f
Replace forward by with_logabsdet_jacobian
oschulz Dec 10, 2021
0c1bf48
Replace Base.inv with InverseFunctions.inverse
oschulz Dec 10, 2021
f6385ea
Improve deprecation scheme for forward
oschulz Dec 10, 2021
0aab88b
Improve deprecation scheme for inv
oschulz Dec 10, 2021
e6f549d
Test forward and inv deprecations
oschulz Dec 10, 2021
4c7f706
Apply suggestions from code review
oschulz Dec 11, 2021
3815505
Fixes regarding with_logabsdet_jacobian and inverse
oschulz Dec 11, 2021
2b93560
Fix with_logabsdet_jacobian for NamedComposition
oschulz Dec 11, 2021
20e50d4
Fix deprecation of inv
oschulz Dec 11, 2021
fc990bb
Use inverse instead of inv for Composed
oschulz Dec 11, 2021
625fbb7
Use with_logabsdet_jacobian instead of forward
oschulz Dec 11, 2021
8273628
Workaround for intermittent failures in Dirichlet test
oschulz Dec 11, 2021
167a26e
Use with_logabsdet_jacobian instead of forward
oschulz Dec 11, 2021
8a0c658
Use with_logabsdet_jacobian instead of forward
oschulz Dec 11, 2021
34f7fc6
Add rrules for combine with PartitionMask
oschulz Dec 12, 2021
2f1c36d
Use inv instead of inverse for numbers
oschulz Dec 12, 2021
012d90a
Apply suggestions from code review
oschulz Dec 12, 2021
3fc7a43
Whitespace fix.
oschulz Dec 12, 2021
045412f
Move combine rrule and add test
oschulz Dec 12, 2021
7e96b8d
Apply suggestions from code review
oschulz Dec 12, 2021
874b5ec
Use @test_deprecated
oschulz Dec 12, 2021
d9c8562
Use @test_deprecated
oschulz Dec 12, 2021
5cad1e4
Use inverse instead of inv
oschulz Dec 12, 2021
6ed6fa9
Use test_inverse and test_with_logabsdet_jacobian
oschulz Dec 12, 2021
4fadffc
Use inverse instead of inv
oschulz Dec 12, 2021
5f4d982
Increase version number to v0.9.12
oschulz Dec 12, 2021
fb54734
Reexport with_logabsdet_jacobian and inverse
oschulz Dec 13, 2021
4b683e5
Increase package version to v0.10.0
oschulz Dec 15, 2021
6 changes: 5 additions & 1 deletion Project.toml
@@ -1,13 +1,15 @@
name = "Bijectors"
uuid = "76274a88-744f-5084-9051-94815aaf08c4"
version = "0.9.11"
version = "0.9.12"
Member

This should probably be 0.10.0, given the magnitude of changes (and that inv is now inverse, among other things).

Collaborator Author

I don't have a strong opinion on it - @devmotion you did consider it non-breaking, right?

Member

@devmotion devmotion Dec 13, 2021

I thought about it again, and I still think it is non-breaking if we add fallback definitions for inverse and with_logabsdet_jacobian:

```julia
function inverse(b::AbstractBijector)
    Base.depwarn("`inv(b::AbstractBijector)` is deprecated, please use `inverse(b)`", :inverse)
    return inv(b)
end
function with_logabsdet_jacobian(b::AbstractBijector, x)
    Base.depwarn(
        "`forward(b::AbstractBijector, x)` is deprecated, please use `with_logabsdet_jacobian(b, x)`",
        :with_logabsdet_jacobian,
    )
    return forward(b, x)
end
```

The only breaking case I can imagine with this PR is a function that operates on bijectors through the new API (maybe even in Bijectors) while the bijector at hand only implements the old API. The fallbacks cover that case. They can lead to a `StackOverflowError`, but only if a bijector implements neither the old nor the new API, in which case the implementation is broken anyway.

Otherwise, forward and inv are deprecated and the other changes are merely replacements in the code and tests (to fix deprecation warnings). So even though the PR is quite large, the changes themselves seem small and well defined.

@oschulz can you add the fallback definitions, and ideally also test them (e.g. with a dummy bijector that only implements the old API)? Then I am convinced that the PR is non-breaking.
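
The fallback scheme and its recursion hazard can be sketched without Bijectors.jl. Everything below is a hypothetical stand-in: `ToyBijector` plays the role of `AbstractBijector`, `old_forward` of `forward`, `new_wlj` of `with_logabsdet_jacobian`, and the three structs are dummy bijectors. A real fallback would also emit a deprecation warning, which is left out here.

```julia
# Hypothetical stand-ins; none of these names exist in Bijectors.jl.
abstract type ToyBijector end

# Generic fallbacks: each API is defined in terms of the other one.
# Old API returns a named tuple, new API returns a plain tuple.
old_forward(b::ToyBijector, x) = NamedTuple{(:rv, :logabsdetjac)}(new_wlj(b, x))
new_wlj(b::ToyBijector, x) = Tuple(old_forward(b, x))

# A dummy bijector that only implements the old API (x -> exp(x), log|J| = x).
struct OldExp <: ToyBijector end
old_forward(::OldExp, x) = (rv = exp(x), logabsdetjac = x)

# A dummy bijector that only implements the new API.
struct NewExp <: ToyBijector end
new_wlj(::NewExp, x) = (exp(x), x)

# A dummy bijector that implements neither API.
struct Broken <: ToyBijector end

new_wlj(OldExp(), 0.0)      # reaches the old-API method through the fallback
old_forward(NewExp(), 0.0)  # reaches the new-API method through the fallback
# new_wlj(Broken(), 0.0) would bounce between the two fallbacks without terminating.
```

The more specific methods on `OldExp` and `NewExp` win dispatch over the generic fallbacks, so only a type that defines neither method ends up in the mutual recursion.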

Collaborator Author

> @oschulz can you add the fallback definitions, and ideally also test them

Yes, will do.

Collaborator Author

@oschulz oschulz Dec 14, 2021

> @oschulz can you add the fallback definitions, and ideally also test them

I think I found a way to do that and defend against the stack overflow, so we can return a meaningful error if neither forward nor with_logabsdet_jacobian is defined, by using a wrapper bijector. The same mechanism can also be used to allow defining Bijectors via with_logabsdet_jacobian without defining logabsdetjac.

@devmotion, I think if we implement JuliaMath/ChangesOfVariables.jl#3 we could then immediately deprecate logabsdetjac as well and still keep this non-breaking, using the same wrapper trick. Let me try something ...

Collaborator Author

Should I remove the export of inverse and with_logabsdet_jacobian?

Member

I think reexporting makes it easy to miss to which package the function actually belongs. Generally, I started to think one should be a bit more careful when it comes to reexporting since it means any breaking change of the upstream definitions seems to require a breaking release in the downstream package as well.

On the other hand, it might seem a bit strange to not export them if they are part of the API 🤷‍♂️

What do you think @torfjelde?

Member

> maybe we should do a release that's only technically breaking (the dependent package use inv and forward in a few places, but don't specialize them at all) here, and then remove the deprecations later on as part of #183? That way, the dependent package could switch from using the old to using the new API in the meantime.

I'm also in favour of this: sounds good 👍

> Should I remove the export of inverse and with_logabsdet_jacobian?

Personally, I'm in favour of exporting. It's very rare that someone does using Bijectors without the intention of also using inverse and/or with_logabsdet_jacobian, since implementations of these are essentially the point of Bijectors.jl; hence it seems a bit weird to me if they then need to qualify the usages of these functions 🤷

Collaborator Author

Ok, so we keep the export? If so, this PR should be good to go from my side.

Member

Yes, let's keep it 👍
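
For readers unfamiliar with the trade-off discussed in this thread, here is a minimal sketch of what reexporting means in Julia. The module names `Upstream` and `Downstream` and the function `inverse_demo` are hypothetical and only mimic the relationship between InverseFunctions and Bijectors.

```julia
# Hypothetical modules, for illustration only.
module Upstream
    export inverse_demo
    inverse_demo(::typeof(exp)) = log   # the function is owned by Upstream
end

module Downstream
    using ..Upstream: inverse_demo      # bring the upstream function into scope
    export inverse_demo                 # reexport it from Downstream
end

using .Downstream

inverse_demo(exp)            # usable without qualification after `using .Downstream`
parentmodule(inverse_demo)   # still reports the module that owns the function
```

Users of `Downstream` get the function unqualified, which is the convenience argued for above. The cost raised in the thread is that the exported name is now part of `Downstream`'s public surface even though `Upstream` controls its definition.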


[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
@@ -22,9 +24,11 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
[compat]
ArgCheck = "1, 2"
ChainRulesCore = "0.10.11, 1"
ChangesOfVariables = "0.1"
Compat = "3"
Distributions = "0.23.3, 0.24, 0.25"
Functors = "0.1, 0.2"
InverseFunctions = "0.1"
IrrationalConstants = "0.1"
LogExpFunctions = "0.3.3"
MappedArrays = "0.2.2, 0.3, 0.4"
78 changes: 40 additions & 38 deletions README.md
@@ -18,13 +18,13 @@ The following table lists mathematical operations for a bijector and the corresp

| Operation | Method | Automatic |
|:------------------------------------:|:-----------------:|:-----------:|
| `b ↦ b⁻¹` | `inv(b)` | ✓ |
| `b ↦ b⁻¹` | `inverse(b)` | ✓ |
| `(b₁, b₂) ↦ (b₁ ∘ b₂)` | `b₁ ∘ b₂` | ✓ |
| `(b₁, b₂) ↦ [b₁, b₂]` | `stack(b₁, b₂)` | ✓ |
| `x ↦ b(x)` | `b(x)` | × |
| `y ↦ b⁻¹(y)` | `inv(b)(y)` | × |
| `y ↦ b⁻¹(y)` | `inverse(b)(y)` | × |
| `x ↦ log|det J(b, x)|` | `logabsdetjac(b, x)` | AD |
| `x ↦ b(x), log|det J(b, x)|` | `forward(b, x)` | ✓ |
| `x ↦ b(x), log|det J(b, x)|` | `with_logabsdet_jacobian(b, x)` | ✓ |
| `p ↦ q := b_* p` | `q = transformed(p, b)` | ✓ |
| `y ∼ q` | `y = rand(q)` | ✓ |
| `p ↦ b` such that `support(b_* p) = ℝᵈ` | `bijector(p)` | ✓ |
@@ -123,7 +123,7 @@ true
What about `invlink`?

```julia
julia> b⁻¹ = inv(b)
julia> b⁻¹ = inverse(b)
Inverse{Logit{Float64},0}(Logit{Float64}(0.0, 1.0))

julia> b⁻¹(y)
@@ -133,7 +133,7 @@ julia> b⁻¹(y) ≈ invlink(dist, y)
true
```

Pretty neat, huh? `Inverse{Logit}` is also a `Bijector` where we've defined `(ib::Inverse{<:Logit})(y)` as the inverse transformation of `(b::Logit)(x)`. Note that it's not always the case that `inv(b) isa Inverse`, e.g. the inverse of `Exp` is simply `Log` so `inv(Exp()) isa Log` is true.
Pretty neat, huh? `Inverse{Logit}` is also a `Bijector` where we've defined `(ib::Inverse{<:Logit})(y)` as the inverse transformation of `(b::Logit)(x)`. Note that it's not always the case that `inverse(b) isa Inverse`, e.g. the inverse of `Exp` is simply `Log` so `inverse(Exp()) isa Log` is true.

#### Dimensionality
One more thing. See the `0` in `Inverse{Logit{Float64}, 0}`? It represents the *dimensionality* of the bijector, in the same sense as for an `AbstractArray` with the exception of `0` which means it expects 0-dim input and output, i.e. `<:Real`. This can also be accessed through `dimension(b)`:
@@ -162,7 +162,7 @@ true
And since `Composed isa Bijector`:

```julia
julia> id_x = inv(id_y)
julia> id_x = inverse(id_y)
Composed{Tuple{Inverse{Logit{Float64},0},Logit{Float64}},0}((Inverse{Logit{Float64},0}(Logit{Float64}(0.0, 1.0)), Logit{Float64}(0.0, 1.0)))

julia> id_x(x) ≈ x
@@ -199,9 +199,9 @@ julia> logpdf_forward(td, x)
-1.123311289915276
```

#### `logabsdetjac` and `forward`
#### `logabsdetjac` and `with_logabsdet_jacobian`

In the computation of both `logpdf` and `logpdf_forward` we need to compute `log(abs(det(jacobian(inv(b), y))))` and `log(abs(det(jacobian(b, x))))`, respectively. This computation is available using the `logabsdetjac` method
In the computation of both `logpdf` and `logpdf_forward` we need to compute `log(abs(det(jacobian(inverse(b), y))))` and `log(abs(det(jacobian(b, x))))`, respectively. This computation is available using the `logabsdetjac` method

```julia
julia> logabsdetjac(b⁻¹, y)
@@ -218,21 +218,21 @@ julia> logabsdetjac(b, x) ≈ -logabsdetjac(b⁻¹, y)
true
```

which is always the case for a differentiable bijection with differentiable inverse. Therefore if you want to compute `logabsdetjac(b⁻¹, y)` and we know that `logabsdetjac(b, b⁻¹(y))` is actually more efficient, we'll return `-logabsdetjac(b, b⁻¹(y))` instead. For some bijectors it might be easy to compute, say, the forward pass `b(x)`, but expensive to compute `b⁻¹(y)`. Because of this you might want to avoid doing anything "backwards", i.e. using `b⁻¹`. This is where `forward` comes to good use:
which is always the case for a differentiable bijection with differentiable inverse. Therefore if you want to compute `logabsdetjac(b⁻¹, y)` and we know that `logabsdetjac(b, b⁻¹(y))` is actually more efficient, we'll return `-logabsdetjac(b, b⁻¹(y))` instead. For some bijectors it might be easy to compute, say, the forward pass `b(x)`, but expensive to compute `b⁻¹(y)`. Because of this you might want to avoid doing anything "backwards", i.e. using `b⁻¹`. This is where `with_logabsdet_jacobian` comes to good use:

```julia
julia> forward(b, x)
(rv = -0.5369949942509267, logabsdetjac = 1.4575353795716655)
julia> with_logabsdet_jacobian(b, x)
(-0.5369949942509267, 1.4575353795716655)
```

Similarily

```julia
julia> forward(inv(b), y)
(rv = 0.3688868996596376, logabsdetjac = -1.4575353795716655)
julia> with_logabsdet_jacobian(inverse(b), y)
(0.3688868996596376, -1.4575353795716655)
```

In fact, the purpose of `forward` is to just _do the right thing_, not necessarily "forward". In this function we'll have access to both the original value `x` and the transformed value `y`, so we can compute `logabsdetjac(b, x)` in either direction. Furthermore, in a lot of cases we can re-use a lot of the computation from `b(x)` in the computation of `logabsdetjac(b, x)`, or vice-versa. `forward(b, x)` will take advantage of such opportunities (if implemented).
In fact, the purpose of `with_logabsdet_jacobian` is to just _do the right thing_, not necessarily "forward". In this function we'll have access to both the original value `x` and the transformed value `y`, so we can compute `logabsdetjac(b, x)` in either direction. Furthermore, in a lot of cases we can re-use a lot of the computation from `b(x)` in the computation of `logabsdetjac(b, x)`, or vice-versa. `with_logabsdet_jacobian(b, x)` will take advantage of such opportunities (if implemented).

#### Sampling from `TransformedDistribution`
At this point we've only shown that we can replicate the existing functionality. But we said `TransformedDistribution isa Distribution`, so we also have `rand`:
@@ -241,7 +241,7 @@ At this point we've only shown that we can replicate the existing functionality.
julia> y = rand(td) # ∈ ℝ
0.999166054552483

julia> x = inv(td.transform)(y) # transform back to interval [0, 1]
julia> x = inverse(td.transform)(y) # transform back to interval [0, 1]
0.7308945834125756
```

@@ -261,7 +261,7 @@ Beta{Float64}(α=2.0, β=2.0)
julia> b = bijector(dist) # (0, 1) → ℝ
Logit{Float64}(0.0, 1.0)

julia> b⁻¹ = inv(b) # ℝ → (0, 1)
julia> b⁻¹ = inverse(b) # ℝ → (0, 1)
Inverse{Logit{Float64},0}(Logit{Float64}(0.0, 1.0))

julia> td = transformed(Normal(), b⁻¹) # x ∼ 𝓝(0, 1) then b(x) ∈ (0, 1)
@@ -280,7 +280,7 @@ It's worth noting that `support(Beta)` is the _closed_ interval `[0, 1]`, while
```julia
td = transformed(Beta())

inv(td.transform)(rand(td))
inverse(td.transform)(rand(td))
```

will never result in `0` or `1` though any sample arbitrarily close to either `0` or `1` is possible. _Disclaimer: numerical accuracy is limited, so you might still see `0` and `1` if you're lucky._
@@ -335,7 +335,7 @@ julia> # Construct the transform
bs = bijector.(dists) # constrained-to-unconstrained bijectors for dists
(Logit{Float64}(0.0, 1.0), Log{0}(), SimplexBijector{true}())

julia> ibs = inv.(bs) # invert, so we get unconstrained-to-constrained
julia> ibs = inverse.(bs) # invert, so we get unconstrained-to-constrained
(Inverse{Logit{Float64},0}(Logit{Float64}(0.0, 1.0)), Exp{0}(), Inverse{SimplexBijector{true},1}(SimplexBijector{true}()))

julia> sb = Stacked(ibs, ranges) # => Stacked <: Bijector
@@ -411,7 +411,7 @@ Similarily to the multivariate ADVI example, we could use `Stacked` to get a _bo
```julia
julia> d = MvNormal(zeros(2), ones(2));

julia> ibs = inv.(bijector.((InverseGamma(2, 3), Beta())));
julia> ibs = inverse.(bijector.((InverseGamma(2, 3), Beta())));

julia> sb = stack(ibs...) # == Stacked(ibs) == Stacked(ibs, [i:i for i = 1:length(ibs)]
Stacked{Tuple{Exp{0},Inverse{Logit{Float64},0}},2}((Exp{0}(), Inverse{Logit{Float64},0}(Logit{Float64}(0.0, 1.0))), (1:1, 2:2))
@@ -481,7 +481,7 @@ julia> Flux.params(flow)
Params([[-1.05099; 0.502079] (tracked), [-0.216248; -0.706424] (tracked), [-4.33747] (tracked)])
```

Another useful function is the `forward(d::Distribution)` method. It is similar to `forward(b::Bijector)` in the sense that it does a forward pass of the entire process "sample then transform" and returns all the most useful quantities in process using the most efficent computation path.
Another useful function is the `forward(d::Distribution)` method. It is similar to `with_logabsdet_jacobian(b::Bijector, x)` in the sense that it does a forward pass of the entire process "sample then transform" and returns all the most useful quantities in process using the most efficent computation path.

```julia
julia> x, y, logjac, logpdf_y = forward(flow) # sample + transform and returns all the useful quantities in one pass
@@ -542,41 +542,43 @@ Logit{Float64}(0.0, 1.0)
julia> b(0.6)
0.4054651081081642

julia> inv(b)(y)
julia> inverse(b)(y)
Tracked 2-element Array{Float64,1}:
0.3078149833748082
0.72380041667891

julia> logabsdetjac(b, 0.6)
1.4271163556401458

julia> logabsdetjac(inv(b), y) # defaults to `- logabsdetjac(b, inv(b)(x))`
julia> logabsdetjac(inverse(b), y) # defaults to `- logabsdetjac(b, inverse(b)(x))`
Tracked 2-element Array{Float64,1}:
-1.546158373866469
-1.6098711387913573

julia> forward(b, 0.6) # defaults to `(rv=b(x), logabsdetjac=logabsdetjac(b, x))`
(rv = 0.4054651081081642, logabsdetjac = 1.4271163556401458)
julia> with_logabsdet_jacobian(b, 0.6) # defaults to `(b(x), logabsdetjac(b, x))`
(0.4054651081081642, 1.4271163556401458)
```

For further efficiency, one could manually implement `forward(b::Logit, x)`:
For further efficiency, one could manually implement `with_logabsdet_jacobian(b::Logit, x)`:

```julia
julia> import Bijectors: forward, Logit
julia> using Bijectors: Logit

julia> function forward(b::Logit{<:Real}, x)
julia> import Bijectors: with_logabsdet_jacobian

julia> function with_logabsdet_jacobian(b::Logit{<:Real}, x)
totally_worth_saving = @. (x - b.a) / (b.b - b.a) # spoiler: it's probably not
y = logit.(totally_worth_saving)
logjac = @. - log((b.b - x) * totally_worth_saving)
return (rv=y, logabsdetjac = logjac)
return (y, logjac)
end
forward (generic function with 16 methods)

julia> forward(b, 0.6)
(rv = 0.4054651081081642, logabsdetjac = 1.4271163556401458)
julia> with_logabsdet_jacobian(b, 0.6)
(0.4054651081081642, 1.4271163556401458)

julia> @which forward(b, 0.6)
forward(b::Logit{#s4} where #s4<:Real, x) in Main at REPL[43]:2
julia> @which with_logabsdet_jacobian(b, 0.6)
with_logabsdet_jacobian(b::Logit{#s4} where #s4<:Real, x) in Main at REPL[43]:2
```

As you can see it's a very contrived example, but you get the idea.
@@ -613,10 +615,10 @@ julia> logabsdetjac(b_ad, 0.6)
julia> y = b_ad(0.6)
0.4054651081081642

julia> inv(b_ad)(y)
julia> inverse(b_ad)(y)
0.6

julia> logabsdetjac(inv(b_ad), y)
julia> logabsdetjac(inverse(b_ad), y)
-1.4271163556401458
```

@@ -665,7 +667,7 @@ help?> Bijectors.Composed

A Bijector representing composition of bijectors. composel and composer results in a Composed for which application occurs from left-to-right and right-to-left, respectively.

Note that all the alternative ways of constructing a Composed returns a Tuple of bijectors. This ensures type-stability of implementations of all relating methdos, e.g. inv.
Note that all the alternative ways of constructing a Composed returns a Tuple of bijectors. This ensures type-stability of implementations of all relating methods, e.g. inverse.

If you want to use an Array as the container instead you can do

@@ -713,9 +715,9 @@ The distribution interface consists of:
#### Methods
The following methods are implemented by all subtypes of `Bijector`, this also includes bijectors such as `Composed`.
- `(b::Bijector)(x)`: implements the transform of the `Bijector`
- `inv(b::Bijector)`: returns the inverse of `b`, i.e. `ib::Bijector` s.t. `(ib ∘ b)(x) ≈ x`. In most cases this is `Inverse{<:Bijector}`.
- `inverse(b::Bijector)`: returns the inverse of `b`, i.e. `ib::Bijector` s.t. `(ib ∘ b)(x) ≈ x`. In most cases this is `Inverse{<:Bijector}`.
- `logabsdetjac(b::Bijector, x)`: computes log(abs(det(jacobian(b, x)))).
- `forward(b::Bijector, x)`: returns named tuple `(rv=b(x), logabsdetjac=logabsdetjac(b, x))` in the most efficient manner.
- `with_logabsdet_jacobian(b::Bijector, x)`: returns the tuple `(b(x), logabsdetjac(b, x))` in the most efficient manner.
- `∘`, `composel`, `composer`: convenient and type-safe constructors for `Composed`. `composel(bs...)` composes s.t. the resulting composition is evaluated left-to-right, while `composer(bs...)` is evaluated right-to-left. `∘` is right-to-left, as excepted from standard mathematical notation.
- `jacobian(b::Bijector, x)` [OPTIONAL]: returns the Jacobian of the transformation. In some cases the analytical Jacobian has been implemented for efficiency.
- `dimension(b::Bijector)`: returns the dimensionality of `b`.
@@ -725,7 +727,7 @@ For `TransformedDistribution`, together with default implementations for `Distri
- `bijector(d::Distribution)`: returns the default constrained-to-unconstrained bijector for `d`
- `transformed(d::Distribution)`, `transformed(d::Distribution, b::Bijector)`: constructs a `TransformedDistribution` from `d` and `b`.
- `logpdf_forward(d::Distribution, x)`, `logpdf_forward(d::Distribution, x, logjac)`: computes the `logpdf(td, td.transform(x))` using the forward pass, which is potentially faster depending on the transform at hand.
- `forward(d::Distribution)`: returns `(x = rand(dist), y = b(x), logabsdetjac = logabsdetjac(b, x), logpdf = logpdf_forward(td, x))` where `b = td.transform`. This combines sampling from base distribution and transforming into one function. The intention is that this entire process should be performed in the most efficient manner, e.g. the `logabsdetjac(b, x)` call might instead be implemented as `- logabsdetjac(inv(b), b(x))` depending on which is most efficient.
- `forward(d::Distribution)`: returns `(x = rand(dist), y = b(x), logabsdetjac = logabsdetjac(b, x), logpdf = logpdf_forward(td, x))` where `b = td.transform`. This combines sampling from base distribution and transforming into one function. The intention is that this entire process should be performed in the most efficient manner, e.g. the `logabsdetjac(b, x)` call might instead be implemented as `- logabsdetjac(inverse(b), b(x))` depending on which is most efficient.

# Bibliography
1. Rezende, D. J., & Mohamed, S. (2015). Variational Inference With Normalizing Flows. [arXiv:1505.05770](https://arxiv.org/abs/1505.05770v6).
16 changes: 13 additions & 3 deletions src/Bijectors.jl
@@ -35,6 +35,9 @@ using MappedArrays
using Base.Iterators: drop
using LinearAlgebra: AbstractTriangular

using ChangesOfVariables: ChangesOfVariables, with_logabsdet_jacobian
using InverseFunctions: InverseFunctions, inverse

import ChainRulesCore
import Functors
import IrrationalConstants
@@ -121,7 +124,7 @@ end
# Distributions

link(d::Distribution, x) = bijector(d)(x)
invlink(d::Distribution, y) = inv(bijector(d))(y)
invlink(d::Distribution, y) = inverse(bijector(d))(y)
function logpdf_with_trans(d::Distribution, x, transform::Bool)
if ispd(d)
return pd_logpdf_with_trans(d, x, transform)
@@ -188,14 +191,14 @@ function invlink(
y::AbstractVecOrMat{<:Real},
::Val{proj}=Val(true),
) where {proj}
return inv(SimplexBijector{proj}())(y)
return inverse(SimplexBijector{proj}())(y)
end
function invlink_jacobian(
d::Dirichlet,
y::AbstractVector{<:Real},
::Val{proj}=Val(true),
) where {proj}
return jacobian(inv(SimplexBijector{proj}()), y)
return jacobian(inverse(SimplexBijector{proj}()), y)
end

## Matrix
@@ -249,6 +252,13 @@ include("utils.jl")
include("interface.jl")
include("chainrules.jl")

Base.@deprecate forward(b::AbstractBijector, x) NamedTuple{(:rv,:logabsdetjac)}(with_logabsdet_jacobian(b, x))

@noinline function Base.inv(b::AbstractBijector)
Base.depwarn("`Base.inv(b::AbstractBijector)` is deprecated, use `InverseFunctions.inverse(b)` instead.", :inv)
inverse(b)
end

# Broadcasting here breaks Tracker for some reason
maporbroadcast(f, x::AbstractArray{<:Any, N}...) where {N} = map(f, x...)
maporbroadcast(f, x::AbstractArray...) = f.(x...)
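
The manual deprecation of `Base.inv` in this hunk follows a common pattern: keep the old entry point, warn, and forward to the new function. A minimal sketch of that pattern, assuming hypothetical stand-ins `toy_inv` and `toy_inverse` in place of `Base.inv` and `InverseFunctions.inverse`, and using `@test_deprecated` from the `Test` standard library as the commit list above suggests the tests do:

```julia
using Test

# Hypothetical new API: returns the inverse of a known function.
toy_inverse(::typeof(exp)) = log

# Hypothetical old API, kept only as a deprecated shim.
# `@noinline` mirrors the shim in the diff.
@noinline function toy_inv(f)
    Base.depwarn("`toy_inv(f)` is deprecated, use `toy_inverse(f)` instead.", :toy_inv)
    return toy_inverse(f)
end

# `@test_deprecated` checks for the warning when Julia runs with `--depwarn=yes`
# and returns the value of the expression.
result = @test_deprecated toy_inv(exp)
```

Whether the warning is actually printed depends on the `--depwarn` setting of the Julia session, so the shim changes no behavior for callers by default.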