Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use ChangesOfVariables and InverseFunctions #212

Merged
merged 29 commits into from
Dec 15, 2021
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
31adbe8
Add ChangesOfVariables and InverseFunctions to deps
oschulz Dec 10, 2021
000c83f
Replace forward by with_logabsdet_jacobian
oschulz Dec 10, 2021
0c1bf48
Replace Base.inv with InverseFunctions.inverse
oschulz Dec 10, 2021
f6385ea
Improve deprecation scheme for forward
oschulz Dec 10, 2021
0aab88b
Improve deprecation scheme for inv
oschulz Dec 10, 2021
e6f549d
Test forward and inv deprecations
oschulz Dec 10, 2021
4c7f706
Apply suggestions from code review
oschulz Dec 11, 2021
3815505
Fixes regarding with_logabsdet_jacobian and inverse
oschulz Dec 11, 2021
2b93560
Fix with_logabsdet_jacobian for NamedComposition
oschulz Dec 11, 2021
20e50d4
Fix deprecation of inv
oschulz Dec 11, 2021
fc990bb
Use inverse instead of inv for Composed
oschulz Dec 11, 2021
625fbb7
Use with_logabsdet_jacobian instead of forward
oschulz Dec 11, 2021
8273628
Workaround for intermittent failures in Dirichlet test
oschulz Dec 11, 2021
167a26e
Use with_logabsdet_jacobian instead of forward
oschulz Dec 11, 2021
8a0c658
Use with_logabsdet_jacobian instead of forward
oschulz Dec 11, 2021
34f7fc6
Add rrules for combine with PartitionMask
oschulz Dec 12, 2021
2f1c36d
Use inv instead of inverse for numbers
oschulz Dec 12, 2021
012d90a
Apply suggestions from code review
oschulz Dec 12, 2021
3fc7a43
Whitespace fix.
oschulz Dec 12, 2021
045412f
Move combine rrule and add test
oschulz Dec 12, 2021
7e96b8d
Apply suggestions from code review
oschulz Dec 12, 2021
874b5ec
Use @test_deprecated
oschulz Dec 12, 2021
d9c8562
Use @test_deprecated
oschulz Dec 12, 2021
5cad1e4
Use inverse instead of inv
oschulz Dec 12, 2021
6ed6fa9
Use test_inverse and test_with_logabsdet_jacobian
oschulz Dec 12, 2021
4fadffc
Use inverse instead of inv
oschulz Dec 12, 2021
5f4d982
Increase version number to v0.9.12
oschulz Dec 12, 2021
fb54734
Reexport with_logabsdet_jacobian and inverse
oschulz Dec 13, 2021
4b683e5
Increase package version to v0.10.0
oschulz Dec 15, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@ version = "0.9.11"
[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
Expand All @@ -22,9 +24,11 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
[compat]
ArgCheck = "1, 2"
ChainRulesCore = "0.10.11, 1"
ChangesOfVariables = "0.1"
Compat = "3"
Distributions = "0.23.3, 0.24, 0.25"
Functors = "0.1, 0.2"
InverseFunctions = "0.1"
IrrationalConstants = "0.1"
LogExpFunctions = "0.3.3"
MappedArrays = "0.2.2, 0.3, 0.4"
Expand Down
71 changes: 36 additions & 35 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ The following table lists mathematical operations for a bijector and the corresp

| Operation | Method | Automatic |
|:------------------------------------:|:-----------------:|:-----------:|
| `b ↦ b⁻¹` | `inv(b)` | ✓ |
| `b ↦ b⁻¹` | `inverse(b)` | ✓ |
| `(b₁, b₂) ↦ (b₁ ∘ b₂)` | `b₁ ∘ b₂` | ✓ |
| `(b₁, b₂) ↦ [b₁, b₂]` | `stack(b₁, b₂)` | ✓ |
| `x ↦ b(x)` | `b(x)` | × |
| `y ↦ b⁻¹(y)` | `inv(b)(y)` | × |
| `y ↦ b⁻¹(y)` | `inverse(b)(y)` | × |
| `x ↦ log|det J(b, x)|` | `logabsdetjac(b, x)` | AD |
| `x ↦ b(x), log|det J(b, x)|` | `forward(b, x)` | ✓ |
| `x ↦ b(x), log|det J(b, x)|` | `with_logabsdet_jacobian(b, x)` | ✓ |
| `p ↦ q := b_* p` | `q = transformed(p, b)` | ✓ |
| `y ∼ q` | `y = rand(q)` | ✓ |
| `p ↦ b` such that `support(b_* p) = ℝᵈ` | `bijector(p)` | ✓ |
Expand Down Expand Up @@ -123,7 +123,7 @@ true
What about `invlink`?

```julia
julia> b⁻¹ = inv(b)
julia> b⁻¹ = inverse(b)
Inverse{Logit{Float64},0}(Logit{Float64}(0.0, 1.0))

julia> b⁻¹(y)
Expand All @@ -133,7 +133,7 @@ julia> b⁻¹(y) ≈ invlink(dist, y)
true
```

Pretty neat, huh? `Inverse{Logit}` is also a `Bijector` where we've defined `(ib::Inverse{<:Logit})(y)` as the inverse transformation of `(b::Logit)(x)`. Note that it's not always the case that `inv(b) isa Inverse`, e.g. the inverse of `Exp` is simply `Log` so `inv(Exp()) isa Log` is true.
Pretty neat, huh? `Inverse{Logit}` is also a `Bijector` where we've defined `(ib::Inverse{<:Logit})(y)` as the inverse transformation of `(b::Logit)(x)`. Note that it's not always the case that `inverse(b) isa Inverse`, e.g. the inverse of `Exp` is simply `Log` so `inverse(Exp()) isa Log` is true.

#### Dimensionality
One more thing. See the `0` in `Inverse{Logit{Float64}, 0}`? It represents the *dimensionality* of the bijector, in the same sense as for an `AbstractArray` with the exception of `0` which means it expects 0-dim input and output, i.e. `<:Real`. This can also be accessed through `dimension(b)`:
Expand Down Expand Up @@ -162,7 +162,7 @@ true
And since `Composed isa Bijector`:

```julia
julia> id_x = inv(id_y)
julia> id_x = inverse(id_y)
Composed{Tuple{Inverse{Logit{Float64},0},Logit{Float64}},0}((Inverse{Logit{Float64},0}(Logit{Float64}(0.0, 1.0)), Logit{Float64}(0.0, 1.0)))

julia> id_x(x) ≈ x
Expand Down Expand Up @@ -201,7 +201,7 @@ julia> logpdf_forward(td, x)

#### `logabsdetjac` and `forward`
oschulz marked this conversation as resolved.
Show resolved Hide resolved

In the computation of both `logpdf` and `logpdf_forward` we need to compute `log(abs(det(jacobian(inv(b), y))))` and `log(abs(det(jacobian(b, x))))`, respectively. This computation is available using the `logabsdetjac` method
In the computation of both `logpdf` and `logpdf_forward` we need to compute `log(abs(det(jacobian(inverse(b), y))))` and `log(abs(det(jacobian(b, x))))`, respectively. This computation is available using the `logabsdetjac` method

```julia
julia> logabsdetjac(b⁻¹, y)
Expand All @@ -221,18 +221,18 @@ true
which is always the case for a differentiable bijection with differentiable inverse. Therefore if you want to compute `logabsdetjac(b⁻¹, y)` and we know that `logabsdetjac(b, b⁻¹(y))` is actually more efficient, we'll return `-logabsdetjac(b, b⁻¹(y))` instead. For some bijectors it might be easy to compute, say, the forward pass `b(x)`, but expensive to compute `b⁻¹(y)`. Because of this you might want to avoid doing anything "backwards", i.e. using `b⁻¹`. This is where `forward` comes to good use:
oschulz marked this conversation as resolved.
Show resolved Hide resolved

```julia
julia> forward(b, x)
(rv = -0.5369949942509267, logabsdetjac = 1.4575353795716655)
julia> with_logabsdet_jacobian(b, x)
(-0.5369949942509267, 1.4575353795716655)
```

Similarily

```julia
julia> forward(inv(b), y)
(rv = 0.3688868996596376, logabsdetjac = -1.4575353795716655)
julia> forward(inverse(b), y)
oschulz marked this conversation as resolved.
Show resolved Hide resolved
(0.3688868996596376, -1.4575353795716655)
```

In fact, the purpose of `forward` is to just _do the right thing_, not necessarily "forward". In this function we'll have access to both the original value `x` and the transformed value `y`, so we can compute `logabsdetjac(b, x)` in either direction. Furthermore, in a lot of cases we can re-use a lot of the computation from `b(x)` in the computation of `logabsdetjac(b, x)`, or vice-versa. `forward(b, x)` will take advantage of such opportunities (if implemented).
In fact, the purpose of `with_logabsdet_jacobian` is to just _do the right thing_, not necessarily "forward". In this function we'll have access to both the original value `x` and the transformed value `y`, so we can compute `logabsdetjac(b, x)` in either direction. Furthermore, in a lot of cases we can re-use a lot of the computation from `b(x)` in the computation of `logabsdetjac(b, x)`, or vice-versa. `with_logabsdet_jacobian(b, x)` will take advantage of such opportunities (if implemented).

#### Sampling from `TransformedDistribution`
At this point we've only shown that we can replicate the existing functionality. But we said `TransformedDistribution isa Distribution`, so we also have `rand`:
Expand All @@ -241,7 +241,7 @@ At this point we've only shown that we can replicate the existing functionality.
julia> y = rand(td) # ∈ ℝ
0.999166054552483

julia> x = inv(td.transform)(y) # transform back to interval [0, 1]
julia> x = inverse(td.transform)(y) # transform back to interval [0, 1]
0.7308945834125756
```

Expand All @@ -261,7 +261,7 @@ Beta{Float64}(α=2.0, β=2.0)
julia> b = bijector(dist) # (0, 1) → ℝ
Logit{Float64}(0.0, 1.0)

julia> b⁻¹ = inv(b) # ℝ → (0, 1)
julia> b⁻¹ = inverse(b) # ℝ → (0, 1)
Inverse{Logit{Float64},0}(Logit{Float64}(0.0, 1.0))

julia> td = transformed(Normal(), b⁻¹) # x ∼ 𝓝(0, 1) then b(x) ∈ (0, 1)
Expand All @@ -280,7 +280,7 @@ It's worth noting that `support(Beta)` is the _closed_ interval `[0, 1]`, while
```julia
td = transformed(Beta())

inv(td.transform)(rand(td))
inverse(td.transform)(rand(td))
```

will never result in `0` or `1` though any sample arbitrarily close to either `0` or `1` is possible. _Disclaimer: numerical accuracy is limited, so you might still see `0` and `1` if you're lucky._
Expand Down Expand Up @@ -335,7 +335,7 @@ julia> # Construct the transform
bs = bijector.(dists) # constrained-to-unconstrained bijectors for dists
(Logit{Float64}(0.0, 1.0), Log{0}(), SimplexBijector{true}())

julia> ibs = inv.(bs) # invert, so we get unconstrained-to-constrained
julia> ibs = inverse.(bs) # invert, so we get unconstrained-to-constrained
(Inverse{Logit{Float64},0}(Logit{Float64}(0.0, 1.0)), Exp{0}(), Inverse{SimplexBijector{true},1}(SimplexBijector{true}()))

julia> sb = Stacked(ibs, ranges) # => Stacked <: Bijector
Expand Down Expand Up @@ -411,7 +411,7 @@ Similarily to the multivariate ADVI example, we could use `Stacked` to get a _bo
```julia
julia> d = MvNormal(zeros(2), ones(2));

julia> ibs = inv.(bijector.((InverseGamma(2, 3), Beta())));
julia> ibs = inverse.(bijector.((InverseGamma(2, 3), Beta())));

julia> sb = stack(ibs...) # == Stacked(ibs) == Stacked(ibs, [i:i for i = 1:length(ibs)]
Stacked{Tuple{Exp{0},Inverse{Logit{Float64},0}},2}((Exp{0}(), Inverse{Logit{Float64},0}(Logit{Float64}(0.0, 1.0))), (1:1, 2:2))
Expand Down Expand Up @@ -481,7 +481,7 @@ julia> Flux.params(flow)
Params([[-1.05099; 0.502079] (tracked), [-0.216248; -0.706424] (tracked), [-4.33747] (tracked)])
```

Another useful function is the `forward(d::Distribution)` method. It is similar to `forward(b::Bijector)` in the sense that it does a forward pass of the entire process "sample then transform" and returns all the most useful quantities in process using the most efficent computation path.
Another useful function is the `forward(d::Distribution)` method. It is similar to `with_logabsdet_jacobian(b::Bijector, x)` in the sense that it does a forward pass of the entire process "sample then transform" and returns all the most useful quantities in process using the most efficent computation path.

```julia
julia> x, y, logjac, logpdf_y = forward(flow) # sample + transform and returns all the useful quantities in one pass
Expand Down Expand Up @@ -542,41 +542,42 @@ Logit{Float64}(0.0, 1.0)
julia> b(0.6)
0.4054651081081642

julia> inv(b)(y)
julia> inverse(b)(y)
Tracked 2-element Array{Float64,1}:
0.3078149833748082
0.72380041667891

julia> logabsdetjac(b, 0.6)
1.4271163556401458

julia> logabsdetjac(inv(b), y) # defaults to `- logabsdetjac(b, inv(b)(x))`
julia> logabsdetjac(inverse(b), y) # defaults to `- logabsdetjac(b, inverse(b)(x))`
Tracked 2-element Array{Float64,1}:
-1.546158373866469
-1.6098711387913573

julia> forward(b, 0.6) # defaults to `(rv=b(x), logabsdetjac=logabsdetjac(b, x))`
(rv = 0.4054651081081642, logabsdetjac = 1.4271163556401458)
julia> with_logabsdet_jacobian(b, 0.6) # defaults to `(b(x), logabsdetjac(b, x))`
(0.4054651081081642, 1.4271163556401458)
```

For further efficiency, one could manually implement `forward(b::Logit, x)`:
For further efficiency, one could manually implement `with_logabsdet_jacobian(b::Logit, x)`:

```julia
julia> import Bijectors: forward, Logit
julia> import ChangesOfVariables: with_logabsdet_jacobian
oschulz marked this conversation as resolved.
Show resolved Hide resolved

julia> function forward(b::Logit{<:Real}, x)
julia> function with_logabsdet_jacobian(b::Logit{<:Real}, x)
totally_worth_saving = @. (x - b.a) / (b.b - b.a) # spoiler: it's probably not
y = logit.(totally_worth_saving)
logjac = @. - log((b.b - x) * totally_worth_saving)
return (rv=y, logabsdetjac = logjac)
return (y, logjac)
end
forward (generic function with 16 methods)

julia> forward(b, 0.6)
(rv = 0.4054651081081642, logabsdetjac = 1.4271163556401458)
julia> with_logabsdet_jacobian(b, 0.6)
(0.4054651081081642, 1.4271163556401458)

julia> @which forward(b, 0.6)
forward(b::Logit{#s4} where #s4<:Real, x) in Main at REPL[43]:2
julia> @which with_logabsdet_jacobian(b, 0.6)
with_logabsdet_jacobian(b::Logit{#s4} where #s4<:Real, x) in Main at REPL[43]:2
```

As you can see it's a very contrived example, but you get the idea.
Expand Down Expand Up @@ -613,10 +614,10 @@ julia> logabsdetjac(b_ad, 0.6)
julia> y = b_ad(0.6)
0.4054651081081642

julia> inv(b_ad)(y)
julia> inverse(b_ad)(y)
0.6

julia> logabsdetjac(inv(b_ad), y)
julia> logabsdetjac(inverse(b_ad), y)
-1.4271163556401458
```

Expand Down Expand Up @@ -665,7 +666,7 @@ help?> Bijectors.Composed

A Bijector representing composition of bijectors. composel and composer results in a Composed for which application occurs from left-to-right and right-to-left, respectively.

Note that all the alternative ways of constructing a Composed returns a Tuple of bijectors. This ensures type-stability of implementations of all relating methdos, e.g. inv.
Note that all the alternative ways of constructing a Composed returns a Tuple of bijectors. This ensures type-stability of implementations of all relating methods, e.g. inverse.

If you want to use an Array as the container instead you can do

Expand Down Expand Up @@ -713,9 +714,9 @@ The distribution interface consists of:
#### Methods
The following methods are implemented by all subtypes of `Bijector`, this also includes bijectors such as `Composed`.
- `(b::Bijector)(x)`: implements the transform of the `Bijector`
- `inv(b::Bijector)`: returns the inverse of `b`, i.e. `ib::Bijector` s.t. `(ib ∘ b)(x) ≈ x`. In most cases this is `Inverse{<:Bijector}`.
- `inverse(b::Bijector)`: returns the inverse of `b`, i.e. `ib::Bijector` s.t. `(ib ∘ b)(x) ≈ x`. In most cases this is `Inverse{<:Bijector}`.
- `logabsdetjac(b::Bijector, x)`: computes log(abs(det(jacobian(b, x)))).
- `forward(b::Bijector, x)`: returns named tuple `(rv=b(x), logabsdetjac=logabsdetjac(b, x))` in the most efficient manner.
- `with_logabsdet_jacobian(b::Bijector, x)`: returns named tuple `(b(x), logabsdetjac(b, x))` in the most efficient manner.
oschulz marked this conversation as resolved.
Show resolved Hide resolved
- `∘`, `composel`, `composer`: convenient and type-safe constructors for `Composed`. `composel(bs...)` composes s.t. the resulting composition is evaluated left-to-right, while `composer(bs...)` is evaluated right-to-left. `∘` is right-to-left, as excepted from standard mathematical notation.
- `jacobian(b::Bijector, x)` [OPTIONAL]: returns the Jacobian of the transformation. In some cases the analytical Jacobian has been implemented for efficiency.
- `dimension(b::Bijector)`: returns the dimensionality of `b`.
Expand All @@ -725,7 +726,7 @@ For `TransformedDistribution`, together with default implementations for `Distri
- `bijector(d::Distribution)`: returns the default constrained-to-unconstrained bijector for `d`
- `transformed(d::Distribution)`, `transformed(d::Distribution, b::Bijector)`: constructs a `TransformedDistribution` from `d` and `b`.
- `logpdf_forward(d::Distribution, x)`, `logpdf_forward(d::Distribution, x, logjac)`: computes the `logpdf(td, td.transform(x))` using the forward pass, which is potentially faster depending on the transform at hand.
- `forward(d::Distribution)`: returns `(x = rand(dist), y = b(x), logabsdetjac = logabsdetjac(b, x), logpdf = logpdf_forward(td, x))` where `b = td.transform`. This combines sampling from base distribution and transforming into one function. The intention is that this entire process should be performed in the most efficient manner, e.g. the `logabsdetjac(b, x)` call might instead be implemented as `- logabsdetjac(inv(b), b(x))` depending on which is most efficient.
- `forward(d::Distribution)`: returns `(x = rand(dist), y = b(x), logabsdetjac = logabsdetjac(b, x), logpdf = logpdf_forward(td, x))` where `b = td.transform`. This combines sampling from base distribution and transforming into one function. The intention is that this entire process should be performed in the most efficient manner, e.g. the `logabsdetjac(b, x)` call might instead be implemented as `- logabsdetjac(inverse(b), b(x))` depending on which is most efficient.

# Bibliography
1. Rezende, D. J., & Mohamed, S. (2015). Variational Inference With Normalizing Flows. [arXiv:1505.05770](https://arxiv.org/abs/1505.05770v6).
Expand Down
16 changes: 13 additions & 3 deletions src/Bijectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ using MappedArrays
using Base.Iterators: drop
using LinearAlgebra: AbstractTriangular

import ChangesOfVariables: with_logabsdet_jacobian
import InverseFunctions: inverse
oschulz marked this conversation as resolved.
Show resolved Hide resolved

import ChainRulesCore
import Functors
import IrrationalConstants
Expand Down Expand Up @@ -121,7 +124,7 @@ end
# Distributions

link(d::Distribution, x) = bijector(d)(x)
invlink(d::Distribution, y) = inv(bijector(d))(y)
invlink(d::Distribution, y) = inverse(bijector(d))(y)
function logpdf_with_trans(d::Distribution, x, transform::Bool)
if ispd(d)
return pd_logpdf_with_trans(d, x, transform)
Expand Down Expand Up @@ -188,14 +191,14 @@ function invlink(
y::AbstractVecOrMat{<:Real},
::Val{proj}=Val(true),
) where {proj}
return inv(SimplexBijector{proj}())(y)
return inverse(SimplexBijector{proj}())(y)
end
function invlink_jacobian(
d::Dirichlet,
y::AbstractVector{<:Real},
::Val{proj}=Val(true),
) where {proj}
return jacobian(inv(SimplexBijector{proj}()), y)
return jacobian(inverse(SimplexBijector{proj}()), y)
end

## Matrix
Expand Down Expand Up @@ -249,6 +252,13 @@ include("utils.jl")
include("interface.jl")
include("chainrules.jl")

Base.@deprecate forward(b::AbstractBijector, x) NamedTuple{(:rv,:logabsdetjac)}(with_logabsdet_jacobian(b, x))

@noinline function Base.inv(b::AbstractBijector)
Base.depwarn("`Base.inv(b::AbstractBijector)` is deprecated, use `InverseFunctions.inverse(b)` instead.", :(Base.inv))
inverse(b)
end

# Broadcasting here breaks Tracker for some reason
maporbroadcast(f, x::AbstractArray{<:Any, N}...) where {N} = map(f, x...)
maporbroadcast(f, x::AbstractArray...) = f.(x...)
Expand Down
26 changes: 13 additions & 13 deletions src/bijectors/composed.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ A `Bijector` representing composition of bijectors. `composel` and `composer` re
`Composed` for which application occurs from left-to-right and right-to-left, respectively.

Note that all the alternative ways of constructing a `Composed` returns a `Tuple` of bijectors.
This ensures type-stability of implementations of all relating methdos, e.g. `inv`.
This ensures type-stability of implementations of all relating methdos, e.g. `inverse`.

If you want to use an `Array` as the container instead you can do

Expand All @@ -41,7 +41,7 @@ Composed{Tuple{Exp{0},Exp{0}},0}((Exp{0}(), Exp{0}()))
julia> (b ∘ b)(1.0) == exp(exp(1.0)) # evaluation
true

julia> inv(b ∘ b)(exp(exp(1.0))) == 1.0 # inversion
julia> inverse(b ∘ b)(exp(exp(1.0))) == 1.0 # inversion
true

julia> logabsdetjac(b ∘ b, 1.0) # determinant of jacobian
Expand Down Expand Up @@ -153,7 +153,7 @@ end
∘(::Identity{N}, b::Bijector{N}) where {N} = b
∘(b::Bijector{N}, ::Identity{N}) where {N} = b

inv(ct::Composed) = Composed(reverse(map(inv, ct.ts)))
inverse(ct::Composed) = Composed(reverse(map(inv, ct.ts)))

# # TODO: should arrays also be using recursive implementation instead?
function (cb::Composed{<:AbstractArray{<:Bijector}})(x)
Expand All @@ -179,8 +179,8 @@ function logabsdetjac(cb::Composed, x)
y, logjac = forward(cb.ts[1], x)
oschulz marked this conversation as resolved.
Show resolved Hide resolved
for i = 2:length(cb.ts)
res = forward(cb.ts[i], y)
oschulz marked this conversation as resolved.
Show resolved Hide resolved
y = res.rv
logjac += res.logabsdetjac
y = res[1]
logjac += res[2]
oschulz marked this conversation as resolved.
Show resolved Hide resolved
end

return logjac
Expand All @@ -195,8 +195,8 @@ end
for i = 2:N - 1
temp = gensym(:res)
push!(expr.args, :($temp = forward(cb.ts[$i], y)))
oschulz marked this conversation as resolved.
Show resolved Hide resolved
push!(expr.args, :(y = $temp.rv))
push!(expr.args, :(logjac += $temp.logabsdetjac))
push!(expr.args, :(y = $temp[1]))
push!(expr.args, :(logjac += $temp[2]))
oschulz marked this conversation as resolved.
Show resolved Hide resolved
end
# don't need to evaluate the last bijector, only it's `logabsdetjac`
push!(expr.args, :(logjac += logabsdetjac(cb.ts[$N], y)))
Expand All @@ -212,10 +212,10 @@ function forward(cb::Composed, x)

for t in cb.ts[2:end]
res = forward(t, rv)
oschulz marked this conversation as resolved.
Show resolved Hide resolved
rv = res.rv
logjac = res.logabsdetjac + logjac
rv = res[1]
logjac = res[2] + logjac
oschulz marked this conversation as resolved.
Show resolved Hide resolved
end
return (rv=rv, logabsdetjac=logjac)
return (rv, logjac)
end


Expand All @@ -225,10 +225,10 @@ end
for i = 2:length(T.parameters)
temp = gensym(:temp)
push!(expr.args, :($temp = forward(cb.ts[$i], y)))
oschulz marked this conversation as resolved.
Show resolved Hide resolved
push!(expr.args, :(y = $temp.rv))
push!(expr.args, :(logjac += $temp.logabsdetjac))
push!(expr.args, :(y = $temp[1]))
push!(expr.args, :(logjac += $temp[2]))
oschulz marked this conversation as resolved.
Show resolved Hide resolved
end
push!(expr.args, :(return (rv = y, logabsdetjac = logjac)))
push!(expr.args, :(return (y, logjac)))

return expr
end
Loading