Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reimplement von Mises-Fisher Distribution #302

Merged
merged 2 commits into from
Nov 7, 2014
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/Distributions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ export
cquantile, # complementary quantile (i.e. using prob in right hand tail)
cumulant, # cumulants of distribution
complete, # turn an incomplete formulation into a complete distribution
concentration, # the concentration parameter
dim, # sample dimension of multivariate distribution
entropy, # entropy of distribution in nats
fit, # fit a distribution to data (using default method)
Expand Down Expand Up @@ -183,6 +184,7 @@ export
sqmahal, # squared Mahalanobis distance to Gaussian center
sqmahal!, # inplace evaluation of sqmahal
mean, # mean of distribution
meandir, # mean direction (of a spherical distribution)
meanform, # convert a normal distribution from canonical form to mean form
median, # median of distribution
mgf, # moment generating function
Expand Down
199 changes: 90 additions & 109 deletions src/multivariate/vonmisesfisher.jl
Original file line number Diff line number Diff line change
@@ -1,130 +1,111 @@
# Von-Mises Fisher: a multivariate distribution useful in directional statistics

# Useful notes:
# http://www.mitsuba-renderer.org/~wenzel/vmf.pdf
# Some of the code adapted from http://www.unc.edu/~sungkyu/manifolds/randvonMisesFisherm.m
# as well as the movMF R package.
# von Mises-Fisher distribution is useful for directional statistics
#
# The implementation here follows:
#
# - Wikipedia:
# http://en.wikipedia.org/wiki/Von_Mises–Fisher_distribution
#
# - R's movMF package's document:
# http://cran.r-project.org/web/packages/movMF/vignettes/movMF.pdf
#
# - Wenzel Jakob's notes:
# http://www.mitsuba-renderer.org/~wenzel/files/vmf.pdf
#

immutable VonMisesFisher <: ContinuousMultivariateDistribution
mu::Vector{Float64}
kappa::Float64
μ::Vector{Float64}
κ::Float64
logCκ::Float64

function VonMisesFisher{T <: Real}(mu::Vector{T}, kappa::Float64)
mu = mu ./ norm(mu)
if kappa < 0
throw(ArgumentError("kappa must be a nonnegative real number."))
function VonMisesFisher(μ::Vector{Float64}, κ::Float64; checknorm::Bool=true)
if checknorm
isunitvec(μ) || error("μ must be a unit vector")
end
new(float64(mu), kappa)
κ > 0 || error("κ must be positive.")
new(μ, κ, vmflck(length(μ), κ))
end
end

length(d::VonMisesFisher) = length(d.mu)
mean(d::VonMisesFisher) = d.mu
scale(d::VonMisesFisher) = d.kappa
# Convenience constructor for arbitrary real inputs: both arguments are
# converted with float64 and forwarded to the Float64 constructor.
function VonMisesFisher{T<:Real}(μ::Vector{T}, κ::Real)
    return VonMisesFisher(float64(μ), float64(κ))
end

insupport{T<:Real}(d::VonMisesFisher, x::AbstractVector{T}) = abs(sum(x) - 1.) < 1e-8
# Construct from a single vector θ: the norm of θ is used as κ and
# θ scaled by 1 / κ is used as μ.
function VonMisesFisher(θ::Vector{Float64})
    κ = vecnorm(θ)
    μ = scale(θ, 1.0 / κ)
    return VonMisesFisher(μ, κ)
end

# Variant for other real element types; θ is converted with float64 first.
function VonMisesFisher{T<:Real}(θ::Vector{T})
    return VonMisesFisher(float64(θ))
end

function _logpdf{T<:Real}(d::VonMisesFisher, x::DenseVector{T}; stable=true)
if abs(d.kappa - 0.0) < eps()
return 0.25 / pi
end
if stable
# As suggested by Wenzel Jakob: http://www.mitsuba-renderer.org/~wenzel/vmf.pdf
return d.kappa * dot(d.mu, x) - d.kappa + log(d.kappa) - log(2*pi) - log(1-exp(-2*d.kappa))
else
# As described on Wikipedia
p = length(d)
logCpk = 0.0
if p == 3
logCpk = log(d.kappa) - log(2 * pi * (exp(kappa) - exp(-kappa)))
else
logCpk = (p/2 - 1) * log(d.kappa) - (p/2) * log(2*pi) - log(besselj(p/2-1, d.kappa))
end
return d.kappa * dot(d.mu, x) + logCpk
end
end
# Display the distribution through the three-argument show, listing fields μ and κ.
function show(io::IO, d::VonMisesFisher)
    return show(io, d, (:μ, :κ))
end


### Basic properties

# sampling (TODO: make it consistent with the common API)
# The length of the distribution is the length of its μ vector.
function length(d::VonMisesFisher)
    return length(d.μ)
end

function rand(d::VonMisesFisher, n::Int)
randvonMisesFisher(n, d.kappa, d.mu)
# Mean direction: returns the stored μ vector itself (no copy is made).
function meandir(d::VonMisesFisher)
    return d.μ
end

# Concentration parameter κ.
function concentration(d::VonMisesFisher)
    return d.κ
end

# x is reported as in the support exactly when isunitvec(x) holds.
function insupport{T<:Real}(d::VonMisesFisher, x::DenseVector{T})
    return isunitvec(x)
end


### Evaluation

# Log normalizing constant for dimension p and concentration κ:
#
#   (p/2 - 1) * log(κ) - (p/2) * log(2π) - log(besseli(p/2 - 1, κ))
#
function _vmflck(p, κ)
    half = 0.5 * p
    ν = half - 1.0
    powterm = ν * log(κ)
    piterm = half * log(2π)
    besselterm = log(besseli(ν, κ))
    return powterm - piterm - besselterm
end
# Log normalizing constant used by vmflck when p == 3, written with
# log1mexp instead of a Bessel function evaluation.
function _vmflck3(κ)
    return log(κ) - log2π - κ - log1mexp(-2.0 * κ)
end
# Log normalizing constant: dispatches to _vmflck3 when p == 3 and to
# _vmflck otherwise. The result is asserted to be a Float64.
function vmflck(p, κ)
    if p == 3
        return _vmflck3(κ)::Float64
    end
    return _vmflck(p, κ)::Float64
end

function randvonMisesFisher(n, kappa, mu)
m = length(mu)
w = rW(n, kappa, m)
v = rand(MvNormal(zeros(m-1), eye(m-1)), n)

# normalize each column of v
for j = 1:n
s = 0.
vj = view(v,:,j)
for i = 1:size(v,1)
s += abs2(vj[i])
end
s = sqrt(s)
for i = 1:size(v,1)
vj[i] /= s
end
end
v = v'
# Log density at x: the cached log normalizing constant plus κ times
# the inner product of μ and x.
function _logpdf{T<:Real}(d::VonMisesFisher, x::DenseVector{T})
    proj = dot(d.μ, x)
    return d.logCκ + d.κ * proj
end


### Sampling

# Build a VonMisesFisherSampler from the distribution's μ and κ.
function sampler(d::VonMisesFisher)
    return VonMisesFisherSampler(d.μ, d.κ)
end

# In-place sampling into a vector: forwards to a freshly built sampler.
function _rand!(d::VonMisesFisher, x::DenseVector)
    return _rand!(sampler(d), x)
end

# In-place sampling into a matrix: forwards to a freshly built sampler.
function _rand!(d::VonMisesFisher, x::DenseMatrix)
    return _rand!(sampler(d), x)
end


### Estimation

r = sqrt(1.0 .- w .^ 2)
for j = 1:size(v,2) v[:,j] = v[:,j] .* r; end
x = hcat(v, w)
mu = mu / norm(mu)
return rotMat(mu)'*x'
# Fit a VonMisesFisher to the columns of X (the sum is taken along
# dimension 2 and the sample count is size(X, 2)).
#
# μ is the normalized sum of the columns, and κ is obtained from
# _vmf_estkappa using ρ = (norm of the sum) / (number of columns).
function fit_mle(::Type{VonMisesFisher}, X::Matrix{Float64})
    nsamples = size(X, 2)
    resultant = vec(sum(X, 2))
    rlen = vecnorm(resultant)
    ρ = rlen / nsamples
    # normalize in place; `resultant` is a fresh array owned by this function
    μ = scale!(resultant, 1.0 / rlen)
    κ = _vmf_estkappa(length(μ), ρ)
    return VonMisesFisher(μ, κ)
end

# Randomly sample W
function rW(n, kappa, m)
y = zeros(n)
l = kappa;
d = m - 1;
b = (- 2. * l + sqrt(4. * l * l + d * d)) / d;
x = (1. - b) / (1. + b);
c = l * x + d * log(1. - x * x);
w = 0
for i=1:n
done = false
while !done
z = rand(Beta(d / 2., d / 2.))
w = (1. - (1. + b) * z) / (1. - (1. - b) * z);
u = rand()
if l * w + d * log(1. - x * w) - c >= log(u)
done = true
end
# Fallback for other real element types: X is converted with float64 and
# passed to the Matrix{Float64} method.
function fit_mle{T<:Real}(::Type{VonMisesFisher}, X::Matrix{T})
    return fit_mle(VonMisesFisher, float64(X))
end

function _vmf_estkappa(p::Int, ρ::Float64)
# Using the fixed-point iteration algorithm in the following paper:
#
# Akihiro Tanabe, Kenji Fukumizu, Shigeyuki Oba, Takashi Takenouchi, and Shin Ishii.
# Parameter estimation for von Mises-Fisher distributions.
# Computational Statistics, 2007, Vol. 22:145-157.
#

const maxiter = 200
half_p = 0.5 * p

ρ2 = abs2(ρ)
κ = ρ * (p - ρ2) / (1 - ρ2)
i = 0
while i < maxiter
i += 1
κ_prev = κ
a = (ρ / _vmfA(half_p, κ))
# println("i = $i, a = $a, abs(a - 1) = $(abs(a - 1))")
κ *= a
if abs(a - 1.0) < 1.0e-12
break
end
y[i] = w
end
return y
return κ
end

# Rotation helper function
function rotMat(b)
d = length(b)
b= b/norm(b)
a = [zeros(d-1,1); 1]
alpha = acos(a'*b)[1]
c = b - a * (a'*b); c = c / norm(c)
A = a*c' - c*a'
return eye(d) + sin(alpha)*A + (cos(alpha) - 1)*(a*a' +c*c')
end
# Ratio of modified Bessel functions of the first kind:
# besseli(half_p, κ) / besseli(half_p - 1, κ).
function _vmfA(half_p::Float64, κ::Float64)
    numer = besseli(half_p, κ)
    denom = besseli(half_p - 1.0, κ)
    return numer / denom
end


# Each row of x assumed to be ~ VonMisesFisher(mu, kappa)
# MLE notes from: http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.186.1887&rep=rep1&type=pdf

function fit_mle(::Type{VonMisesFisher}, x::Matrix{Float64})
(n,p) = size(x)
sx = sum(x, 1)
mu = sx[:] / norm(sx)
rbar = norm(sx) / n
kappa0 = rbar * (p-rbar^2) / (1-rbar^2) # Eqn. 4
# TODO: Include a few Newton steps to get a better approximation.
# A(p,kappa) = besselj(p/2, kappa) / besselj(p/2-1, kappa)
# apk0 = A(p,kappa0)
# kappa1 = kappa0 + (apk0 - rbar) / (1 - apk0^2 - (p-1)*apk0/kappa0)
# apk1 = A(p,kappa1)
# kappa2 = kappa1 + (apk1 - rbar) / (1 - apk1^2 - (p-1)*apk1/kappa1)
return VonMisesFisher(mu, kappa0)#, kappa1, kappa2)
end
3 changes: 2 additions & 1 deletion src/samplers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ for fname in ["categorical.jl",
"exponential.jl",
"gamma.jl",
"multinomial.jl",
"vonmises.jl"]
"vonmises.jl",
"vonmisesfisher.jl"]

include(joinpath("samplers", fname))
end
120 changes: 120 additions & 0 deletions src/samplers/vonmisesfisher.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
# Sampler for von Mises-Fisher

# Holds the dimension and concentration together with the quantities that the
# outer constructor VonMisesFisherSampler(μ, κ) derives from them.
immutable VonMisesFisherSampler
p::Int # the dimension (length of μ)
κ::Float64 # the concentration parameter
b::Float64 # value of _vmf_bval(p, κ)
x0::Float64 # (1 - b) / (1 + b)
c::Float64 # κ * x0 + (p - 1) * log1p(-x0^2)
Q::Matrix{Float64} # result of _vmf_rotmat(μ); multiplied onto each generated vector
end

# Build a sampler for mean direction μ and concentration κ by computing
# the constants b, x0 and c and the matrix Q from _vmf_rotmat(μ).
function VonMisesFisherSampler(μ::Vector{Float64}, κ::Float64)
    dim = length(μ)
    bval = _vmf_bval(dim, κ)
    xzero = (1.0 - bval) / (1.0 + bval)
    cval = κ * xzero + (dim - 1) * log1p(-abs2(xzero))
    rot = _vmf_rotmat(μ)
    return VonMisesFisherSampler(dim, κ, bval, xzero, cval, rot)
end

# Draw one sample into x, using t as scratch storage (indices 1:p are written).
#
# t[1] receives the value w from _vmf_genw; t[2:p] receive standard normal
# draws rescaled so that their squared sum equals 1 - w^2. Finally
# x is overwritten with spl.Q * t.
function _rand!(spl::VonMisesFisherSampler, x::DenseVector, t::DenseVector)
    dim = spl.p
    w = _vmf_genw(spl)
    t[1] = w

    # fill the remaining entries with Gaussian draws, tracking their squared norm
    sqsum = 0.0
    for k = 2:dim
        g = randn()
        t[k] = g
        sqsum += abs2(g)
    end

    # rescale t[2:dim] so that the whole of t has unit norm
    scl = sqrt((1.0 - abs2(w)) / sqsum)
    for k = 2:dim
        t[k] *= scl
    end

    # apply Q
    A_mul_B!(x, spl.Q, t)
    return x
end

# Draw one sample into x, allocating a scratch vector of the same length.
function _rand!(spl::VonMisesFisherSampler, x::DenseVector)
    scratch = Array(Float64, length(x))
    return _rand!(spl, x, scratch)
end

# Fill every column of x with one sample; a single scratch vector is
# allocated once and shared by all columns.
function _rand!(spl::VonMisesFisherSampler, x::DenseMatrix)
    nrows, ncols = size(x)
    scratch = Array(Float64, nrows)
    for col = 1:ncols
        _rand!(spl, view(x, :, col), scratch)
    end
    return x
end


### Core computation

# The constant b used by the sampler:
#
#   b = (p - 1) / (2κ + sqrt(4κ^2 + (p - 1)^2))
#
function _vmf_bval(p::Int, κ::Real)
    d = p - 1
    return d / (2.0κ + sqrt(4 * abs2(κ) + abs2(d)))
end

# Generate the W value -- the key step in simulating vMF
# (following movMF's document).
#
# Rejection loop: z is drawn from Beta((p-1)/2, (p-1)/2), mapped to w, and
# w is returned as soon as
#     κ * w + (p - 1) * log(1 - x0 * w) - c < log(rand())
# does not hold.
function _vmf_genw(p, b, x0, c, κ)
    dm1 = p - 1
    halfd = dm1 / 2.0
    betad = Beta(halfd, halfd)
    while true
        z = rand(betad)
        w = (1.0 - (1.0 + b) * z) / (1.0 - (1.0 - b) * z)
        if !(κ * w + dm1 * log(1 - x0 * w) - c < log(rand()))
            return w::Float64
        end
    end
end

# Convenience method: unpack the sampler's stored constants.
function _vmf_genw(s::VonMisesFisherSampler)
    return _vmf_genw(s.p, s.b, s.x0, s.c, s.κ)
end

function _vmf_rotmat(u::Vector{Float64})
    # construct an orthogonal matrix Q
    # s.t. Q * [1,0,...,0]^T --> u
    #
    # Strategy: construct a full-rank matrix
    # with first column being u, and then
    # perform QR factorization
    #
    # u is expected to be a unit vector; this is not checked here.
    #

    p = length(u)
    A = zeros(p, p)
    copy!(view(A,:,1), u)

    # let k be the index of the entry of u with the largest absolute value
    k = 1
    amax = abs(u[1])
    for i = 2:p
        @inbounds ai = abs(u[i])
        if ai > amax
            k = i
            amax = ai
        end
    end

    # other columns of A will be filled with
    # indicator vectors, except the one
    # that activates the k-th entry
    #
    # `row` must advance after every column: otherwise all of the columns
    # 2:p would receive the same indicator vector and A would be
    # rank-deficient for p >= 3.
    row = 1
    for j = 2:p
        if row == k
            row += 1
        end
        A[row, j] = 1.0
        row += 1
    end

    # perform QR factorization
    Q = full(qrfact!(A)[:Q])
    if dot(view(Q,:,1), u) < 0.0  # the first column was negated
        for i = 1:p
            @inbounds Q[i,1] = -Q[i,1]
        end
    end
    return Q
end

2 changes: 2 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ convert{T}(::Type{Vector{T}}, v::ZeroVector{T}) = full(v)

type NoArgCheck end

# Test whether v has unit Euclidean norm, up to an absolute tolerance of 1.0e-12.
# The deviation is compared in absolute value, so vectors that are too short
# (including the zero vector) are rejected as well as vectors that are too long.
isunitvec(v::AbstractVector) = abs(norm(v) - 1.0) < 1.0e-12

function allfinite{T<:Real}(x::Array{T})
for i = 1 : length(x)
if !(isfinite(x[i]))
Expand Down
Loading