Skip to content

Commit

Permalink
Merge pull request #158 from QuantEcon/fix_mvn_sampler
Browse files Browse the repository at this point in the history
BUG: Fix MVNSampler
  • Loading branch information
sglyon authored May 1, 2017
2 parents df22e32 + 3fb9f1d commit 5ea70a3
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 11 deletions.
26 changes: 15 additions & 11 deletions src/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,41 +16,45 @@ function MVNSampler{TM<:Real,TS<:Real}(mu::Vector{TM}, Sigma::Matrix{TS})

n = length(mu)

if size(Sigma) != (n,n) # Check Sigma is n x n
throw(ArgumentError("Sigma must be 2 dimensional and square matrix of same length to mu"))
if size(Sigma) != (n, n) # Check Sigma is n x n
throw(ArgumentError(
"Sigma must be 2 dimensional and square matrix of same length to mu"
))
end

issymmetric(Sigma) || throw(ArgumentError("Sigma must be symmetric"))

A = copy(Sigma)
C = cholfact!(Symmetric(A, :L), Val{true})
C = cholfact(Symmetric(Sigma, :L), Val{true})
A = C.factors
r = C.rank
p = invperm(C.piv)

if C.rank == n # Positive definite
Q = tril!(A)[p,p]
if r == n # Positive definite
Q = tril!(A)[p, p]
return MVNSampler(mu, Sigma, Q)
end

non_PSD_msg = "Sigma must be positive semidefinite"

for i in C.rank+1:n
C[:L][i, i] >= -ATOL1 - RTOL1 * C[:L][1, 1] ||
for i in r+1:n
A[i, i] >= -ATOL1 - RTOL1 * A[1, 1] ||
throw(ArgumentError(non_PSD_msg))
end

tril!(view(A, :, 1:r))
A[:, r+1:end] = 0
Q = A[p,p]
Q = A[p, p]
isapprox(Q*Q', Sigma; rtol=RTOL2, atol=ATOL2) ||
throw(ArgumentError(non_PSD_msg))

return MVNSampler(mu, Sigma, Q)
end

# methods with the optional rng argument first
Base.rand(rng::AbstractRNG, d::MVNSampler) = d.mu + d.Q * randn(rng, length(d.mu))
Base.rand(rng::AbstractRNG, d::MVNSampler, n::Integer) = d.mu.+d.Q*randn(rng,(length(d.mu),n))
Base.rand(rng::AbstractRNG, d::MVNSampler) =
d.mu + d.Q * randn(rng, length(d.mu))
Base.rand(rng::AbstractRNG, d::MVNSampler, n::Integer) =
d.mu .+ d.Q * randn(rng, (length(d.mu), n))

# methods to draw from `MVNSampler`
Base.rand(d::MVNSampler) = rand(Base.GLOBAL_RNG, d)
Expand Down
10 changes: 10 additions & 0 deletions test/test_sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,14 @@
@test typeof(MVNSampler(mu,Sigma)) <: MVNSampler
end
end

@testset "test covariance matrices of Int and Rational" begin
n = 2
mu = zeros(2)
for T in [Int, Rational{Int}]
Sigma = eye(T, n)
@test typeof(MVNSampler(mu, Sigma)) <: MVNSampler
end
end

end # @testset

0 comments on commit 5ea70a3

Please sign in to comment.