diff --git a/src/sampler.jl b/src/sampler.jl index 8e3adaaa..dafae8ed 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -16,32 +16,34 @@ function MVNSampler{TM<:Real,TS<:Real}(mu::Vector{TM}, Sigma::Matrix{TS}) n = length(mu) - if size(Sigma) != (n,n) # Check Sigma is n x n - throw(ArgumentError("Sigma must be 2 dimensional and square matrix of same length to mu")) + if size(Sigma) != (n, n) # Check Sigma is n x n + throw(ArgumentError( + "Sigma must be 2 dimensional and square matrix of same length to mu" + )) end issymmetric(Sigma) || throw(ArgumentError("Sigma must be symmetric")) - A = copy(Sigma) - C = cholfact!(Symmetric(A, :L), Val{true}) + C = cholfact(Symmetric(Sigma, :L), Val{true}) + A = C.factors r = C.rank p = invperm(C.piv) - if C.rank == n # Positive definite - Q = tril!(A)[p,p] + if r == n # Positive definite + Q = tril!(A)[p, p] return MVNSampler(mu, Sigma, Q) end non_PSD_msg = "Sigma must be positive semidefinite" - for i in C.rank+1:n - C[:L][i, i] >= -ATOL1 - RTOL1 * C[:L][1, 1] || + for i in r+1:n + A[i, i] >= -ATOL1 - RTOL1 * A[1, 1] || throw(ArgumentError(non_PSD_msg)) end tril!(view(A, :, 1:r)) A[:, r+1:end] = 0 - Q = A[p,p] + Q = A[p, p] isapprox(Q*Q', Sigma; rtol=RTOL2, atol=ATOL2) || throw(ArgumentError(non_PSD_msg)) @@ -49,8 +51,10 @@ function MVNSampler{TM<:Real,TS<:Real}(mu::Vector{TM}, Sigma::Matrix{TS}) end # methods with the optional rng argument first -Base.rand(rng::AbstractRNG, d::MVNSampler) = d.mu + d.Q * randn(rng, length(d.mu)) -Base.rand(rng::AbstractRNG, d::MVNSampler, n::Integer) = d.mu.+d.Q*randn(rng,(length(d.mu),n)) +Base.rand(rng::AbstractRNG, d::MVNSampler) = + d.mu + d.Q * randn(rng, length(d.mu)) +Base.rand(rng::AbstractRNG, d::MVNSampler, n::Integer) = + d.mu .+ d.Q * randn(rng, (length(d.mu), n)) # methods to draw from `MVNSampler` Base.rand(d::MVNSampler) = rand(Base.GLOBAL_RNG, d) diff --git a/test/test_sampler.jl b/test/test_sampler.jl index f74014a6..6381d212 100644 --- a/test/test_sampler.jl +++ b/test/test_sampler.jl @@ -47,4 +47,14 @@ @test typeof(MVNSampler(mu,Sigma)) <: MVNSampler end end + + @testset "test covariance matrices of Int and Rational" begin + n = 2 + mu = zeros(2) + for T in [Int, Rational{Int}] + Sigma = eye(T, n) + @test typeof(MVNSampler(mu, Sigma)) <: MVNSampler + end + end + end # @testset