Skip to content

Commit 4888932

Browse files
committed
fixed VI
1 parent 320af8c commit 4888932

File tree

2 files changed

+17
-13
lines changed

2 files changed

+17
-13
lines changed

src/variational/VariationalInference.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
module Variational
22

3-
import AdvancedVI
4-
import Bijectors
5-
import DistributionsAD
6-
import DynamicPPL
7-
import StatsBase
8-
import StatsFuns
3+
using AdvancedVI
4+
using Bijectors
5+
using DistributionsAD
6+
using DynamicPPL
7+
using StatsBase
8+
using StatsFuns
9+
using Distributions
910

10-
import Random
11+
using Random
1112

1213
# Reexports
1314
using AdvancedVI: vi, ADVI, ELBO, elbo, TruncatedADAGrad, DecayedADAGrad

src/variational/advi.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,16 @@ end
66

77
Bijectors.inverse(f::Vec) = Vec(Bijectors.inverse(f.b), f.size)
88

9-
function (f::Vec)(x::AbstractVector)
9+
function Bijectors.with_logabsdet_jacobian(f::Vec, x)
10+
return Bijectors.transform(f, x), Bijectors.logabsdetjac(f, x)
11+
end
12+
13+
function Bijectors.transform(f::Vec, x::AbstractVector)
1014
# Reshape into shape compatible with wrapped bijector and then `vec` again.
1115
return vec(f.b(reshape(x, f.size)))
1216
end
1317

14-
function (f::Vec)(x::AbstractMatrix)
18+
function Bijectors.transform(f::Vec, x::AbstractMatrix)
1519
# At the moment we do batching for higher-than-1-dim spaces by simply using
1620
# lists of inputs rather than `AbstractArray` with `N + 1` dimension.
1721
cols = Iterators.Stateful(eachcol(x))
@@ -68,11 +72,10 @@ function Bijectors.bijector(
6872

6973
bs = map(tuple(dists...)) do d
7074
b = Bijectors.bijector(d)
71-
72-
return if Bijectors.dimension(b) > 1
73-
Vec(b, size(d))
74-
else
75+
if d isa Distributions.UnivariateDistribution
7576
b
77+
else
78+
Vec(b, size(d))
7679
end
7780
end
7881

0 commit comments

Comments
 (0)