-
Notifications
You must be signed in to change notification settings - Fork 419
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
108 additions
and
111 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,96 +1,99 @@ | ||
############################################################################## | ||
# Wishart distribution | ||
# | ||
# Wishart Distribution | ||
# following the Wikipedia parameterization | ||
# | ||
# Parameters nu and S such that E(X) = nu * S | ||
# See the rwish and dwish implementation in R's MCMCPack | ||
# This parametrization differs from Bernardo & Smith p 435 | ||
# in this way: (nu, S) = (2.0 * alpha, 0.5 * beta^-1) | ||
# | ||
############################################################################## | ||
|
||
immutable Wishart <: ContinuousMatrixDistribution | ||
nu::Float64 | ||
Schol::Cholesky{Float64} | ||
function Wishart(n::Real, Sc::Cholesky{Float64}) | ||
if n > size(Sc, 1) - 1 | ||
new(float64(n), Sc) | ||
else | ||
error("Wishart parameters must be df > p - 1") | ||
end | ||
end | ||
|
||
immutable Wishart{ST<:AbstractPDMat} <: ContinuousMatrixDistribution | ||
df::Float64 # degree of freedom | ||
S::ST # the scale matrix | ||
c0::Float64 # the logarithm of normalizing constant in pdf | ||
end | ||
|
||
Wishart(nu::Real, S::Matrix{Float64}) = Wishart(nu, cholfact(S)) | ||
#### Constructors | ||
|
||
show(io::IO, d::Wishart) = show_multline(io, d, [(:nu, d.nu), (:S, full(d.Schol))]) | ||
function Wishart{ST<:AbstractPDMat}(df::Real, S::ST) | ||
p = dim(S) | ||
df > p - 1 || error("df should be greater than dim - 1.") | ||
Wishart{ST}(df, S, _wishart_c0(df, S)) | ||
end | ||
|
||
Wishart(df::Real, S::Matrix{Float64}) = Wishart(df, PDMat(S)) | ||
|
||
dim(W::Wishart) = size(W.Schol, 1) | ||
size(W::Wishart) = size(W.Schol) | ||
Wishart(df::Real, S::Cholesky) = Wishart(df, PDMat(S)) | ||
|
||
function insupport(W::Wishart, X::Matrix{Float64}) | ||
return size(X) == size(W) && isApproxSymmmetric(X) && hasCholesky(X) | ||
end | ||
# This just checks if X could come from any Wishart | ||
function insupport(::Type{Wishart}, X::Matrix{Float64}) | ||
return size(X, 1) == size(X, 2) && isApproxSymmmetric(X) && hasCholesky(X) | ||
function _wishart_c0(df::Float64, S::AbstractPDMat) | ||
h_df = df / 2 | ||
p = dim(S) | ||
h_df * (logdet(S) + p * logtwo) + lpgamma(p, h_df) | ||
end | ||
|
||
mean(w::Wishart) = w.nu * (w.Schol[:U]' * w.Schol[:U]) | ||
|
||
function expected_logdet(W::Wishart) | ||
logd = 0. | ||
d = dim(W) | ||
#### Properties | ||
|
||
for i=1:d | ||
logd += digamma(0.5 * (W.nu + 1 - i)) | ||
end | ||
insupport(::Type{Wishart}, X::Matrix{Float64}) = isposdef(X) | ||
insupport(d::Wishart, X::Matrix{Float64}) = size(X) == size(d) && isposdef(X) | ||
|
||
logd += d * log(2) | ||
logd += logdet(W.Schol) | ||
dim(d::Wishart) = dim(d.S) | ||
size(d::Wishart) = (p = dim(d); (p, p)) | ||
|
||
return logd | ||
end | ||
|
||
function lognorm(W::Wishart) | ||
d = dim(W) | ||
return (W.nu / 2) * logdet(W.Schol) + (d * W.nu / 2) * log(2) + lpgamma(d, W.nu / 2) | ||
end | ||
#### Show | ||
|
||
show(io::IO, d::Wishart) = show_multline(io, d, [(:df, d.df), (:S, full(d.S))]) | ||
|
||
|
||
#### Statistics | ||
|
||
mean(d::Wishart) = d.df * full(d.S) | ||
|
||
function _logpdf{T<:Real}(W::Wishart, X::DenseMatrix{T}) | ||
Xchol = trycholfact(X) | ||
if size(X) == size(W) && isApproxSymmmetric(X) && isa(Xchol, Cholesky) | ||
d = dim(W) | ||
logd = -lognorm(W) | ||
logd += 0.5 * (W.nu - d - 1.0) * logdet(Xchol) | ||
logd -= 0.5 * trace(W.Schol \ X) | ||
return logd | ||
else | ||
return -Inf | ||
function meanlogdet(d::Wishart) | ||
p = dim(d) | ||
df = d.df | ||
v = logdet(d.S) + p * logtwo | ||
for i = 1:p | ||
v += digamma(0.5 * (df - (i - 1))) | ||
end | ||
return v | ||
end | ||
|
||
function rand(w::Wishart) | ||
p = size(w.Schol, 1) | ||
X = zeros(p, p) | ||
for ii in 1:p | ||
X[ii, ii] = sqrt(rand(Chisq(w.nu - ii + 1))) | ||
end | ||
if p > 1 | ||
for col in 2:p | ||
for row in 1:(col - 1) | ||
X[row, col] = randn() | ||
end | ||
end | ||
end | ||
Z = X * w.Schol[:U] | ||
return At_mul_B(Z, Z) | ||
function entropy(d::Wishart) | ||
p = dim(d) | ||
df = d.df | ||
d.c0 - 0.5 * (df - p - 1) * meanlogdet(d) + 0.5 * df * p | ||
end | ||
|
||
function entropy(W::Wishart) | ||
d = dim(W) | ||
return lognorm(W) - (W.nu - d - 1) / 2 * expected_logdet(W) + W.nu * d / 2 | ||
|
||
#### Evaluation | ||
|
||
function _logpdf(d::Wishart, X::DenseMatrix{Float64}) | ||
Xcf = cholfact(X) | ||
df = d.df | ||
p = dim(d) | ||
0.5 * ((df - (p + 1)) * logdet(Xcf) - trace(d.S \ X)) - d.c0 | ||
end | ||
|
||
var(w::Wishart) = error("Not yet implemented") | ||
|
||
#### Sampling | ||
|
||
function rand(d::Wishart) | ||
Z = unwhiten!(d.S, _wishart_genA(dim(d), d.df)) | ||
A_mul_Bt(Z, Z) | ||
end | ||
|
||
function _wishart_genA(p::Int, df::Float64) | ||
# Generate the matrix A in the Bartlett decomposition | ||
# | ||
# A is a lower triangular matrix, with | ||
# | ||
# A(i, j) ~ sqrt of Chisq(df - i + 1) when i == j | ||
# ~ Normal() when i > j | ||
# | ||
A = zeros(p, p) | ||
for i = 1:p | ||
@inbounds A[i,i] = sqrt(rand(Chisq(df - i + 1.0))) | ||
end | ||
for j = 1:p-1, i = j+1:p | ||
@inbounds A[i,j] = randn() | ||
end | ||
return A | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.