Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Few AD fixes and improvements #70

Merged
merged 11 commits into from
May 1, 2020
1 change: 1 addition & 0 deletions src/DistributionsAD.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ export TuringScalMvNormal,
include("common.jl")
include("univariate.jl")
include("multivariate.jl")
include("mvcategorical.jl")
include("matrixvariate.jl")
include("flatten.jl")
include("arraydist.jl")
Expand Down
16 changes: 8 additions & 8 deletions src/arraydist.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
# Utils

function maporbroadcast(f, dists::AbstractArray, x::AbstractArray)
function summaporbroadcast(f, dists::AbstractArray, x::AbstractArray)
# Broadcasting here breaks Tracker for some reason
return sum(map(f, dists, x))
end
function maporbroadcast(f, dists::AbstractVector, x::AbstractMatrix)
return map(x -> maporbroadcast(f, dists, x), eachcol(x))
function summaporbroadcast(f, dists::AbstractVector, x::AbstractMatrix)
return map(x -> summaporbroadcast(f, dists, x), eachcol(x))
end
@init @require LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" begin
function maporbroadcast(f, dists::LazyArrays.BroadcastArray, x::AbstractArray)
function summaporbroadcast(f, dists::LazyArrays.BroadcastArray, x::AbstractArray)
return sum(copy(f.(dists, x)))
end
function maporbroadcast(f, dists::LazyArrays.BroadcastVector, x::AbstractMatrix)
function summaporbroadcast(f, dists::LazyArrays.BroadcastVector, x::AbstractMatrix)
return vec(sum(copy(f.(dists, x)), dims = 1))
end
lazyarray(f, x...) = LazyArrays.LazyArray(Base.broadcasted(f, x...))
Expand All @@ -27,11 +27,11 @@ function arraydist(dists::AbstractVector{<:UnivariateDistribution})
end

function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractVector{<:Real})
return maporbroadcast(logpdf, dist.v, x)
return summaporbroadcast(logpdf, dist.v, x)
end
function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractMatrix{<:Real})
# eachcol breaks Zygote, so we need an adjoint
return maporbroadcast(logpdf, dist.v, x)
return summaporbroadcast(logpdf, dist.v, x)
end
ZygoteRules.@adjoint function Distributions.logpdf(
dist::VectorOfUnivariate,
Expand All @@ -54,7 +54,7 @@ function arraydist(dists::AbstractMatrix{<:UnivariateDistribution})
return MatrixOfUnivariate(dists)
end
function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractMatrix{<:Real})
return maporbroadcast(logpdf, dist.dists, x)
return summaporbroadcast(logpdf, dist.dists, x)
end
function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractArray{<:AbstractMatrix{<:Real}})
return map(x -> logpdf(dist, x), x)
Expand Down
21 changes: 21 additions & 0 deletions src/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -180,3 +180,24 @@ function Distributions.isprobvec(p::TrackedArray{<:Real})
pdata = Tracker.data(p)
all(x -> x ≥ zero(x), pdata) && isapprox(sum(pdata), one(eltype(pdata)), atol = 1e-6)
end

# Some array functions - workaround https://github.com/FluxML/Tracker.jl/issues/4

import Base: +, -, *, /, \
import LinearAlgebra: dot

for f in (:+, :-, :*, :/, :\, :dot), (T1, T2) in [
(:TrackedArray, :AbstractArray),
(:TrackedMatrix, :AbstractMatrix),
(:TrackedMatrix, :AbstractVector),
(:TrackedVector, :AbstractMatrix),
]
@eval begin
function $f(a::$T1{T}, b::$T2{<:TrackedReal}) where {T <: Real}
return $f(convert(AbstractArray{TrackedReal{T}}, a), b)
end
function $f(a::$T2{<:TrackedReal}, b::$T1{T}) where {T <: Real}
return $f(a, convert(AbstractArray{TrackedReal{T}}, b))
end
end
end
18 changes: 15 additions & 3 deletions src/multivariate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,24 @@ struct TuringDirichlet{T, TV <: AbstractVector} <: ContinuousMultivariateDistrib
alpha0::T
lmnB::T
end
Base.length(d::TuringDirichlet) = length(d.alpha)
function check(alpha)
all(ai -> ai > 0, alpha) ||
throw(ArgumentError("Dirichlet: alpha must be a positive vector."))
end
ZygoteRules.@adjoint function check(alpha)
return check(alpha), _ -> nothing
return check(alpha), _ -> (nothing,)
end
function Distributions._rand!(rng::Random.AbstractRNG,
d::TuringDirichlet,
x::AbstractVector{<:Real})
s = 0.0
n = length(x)
α = d.alpha
for i in 1:n
@inbounds s += (x[i] = rand(rng, Gamma(α[i])))
end
Distributions.multiply!(x, inv(s)) # this returns x
end

function TuringDirichlet(alpha::AbstractVector)
Expand All @@ -31,11 +43,11 @@ function TuringDirichlet(d::Integer, alpha::Real)
TuringDirichlet{T, TV}(_alpha, alpha0, lmnB)
end
function TuringDirichlet(alpha::AbstractVector{T}) where {T <: Integer}
Tf = float(T)
TuringDirichlet(convert(AbstractVector{Tf}, alpha))
TuringDirichlet(float.(alpha))
end
TuringDirichlet(d::Integer, alpha::Integer) = TuringDirichlet(d, Float64(alpha))

Distributions.Dirichlet(alpha::AbstractVector) = TuringDirichlet(alpha)
Distributions.Dirichlet(alpha::TrackedVector) = TuringDirichlet(alpha)
Distributions.Dirichlet(d::Integer, alpha::TrackedReal) = TuringDirichlet(d, alpha)

Expand Down
170 changes: 170 additions & 0 deletions src/mvcategorical.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
using Distributions: @check_args

"""
MvDiscreteNonParametric(xs, ps)

A *multivariate Discrete nonparametric distribution* explicitly defines an arbitrary
probability mass function in terms of a list of real support values and their
corresponding probabilities

```julia
d = MvDiscreteNonParametric(xs, ps)

params(d) # Get the parameters, i.e. (xs, ps)
support(d) # Get a sorted AbstractVector describing the support (xs) of the distribution
probs(d) # Get a Matrix of the probabilities (ps) associated with the support
```

External links

* [Probability mass function on Wikipedia](http://en.wikipedia.org/wiki/Probability_mass_function)
"""
struct MvDiscreteNonParametric{T<:Real,P<:Real,Ts<:AbstractVector{T},Ps<:AbstractMatrix{P}} <: DiscreteMultivariateDistribution
support::Ts
p::Ps

function MvDiscreteNonParametric{T,P,Ts,Ps}(vs::Ts, ps::Ps; check_args=true) where {
T<:Real,P<:Real,Ts<:AbstractVector{T},Ps<:AbstractMatrix{P}}
check_args || return new{T,P,Ts,Ps}(vs, ps)
@check_args(MvDiscreteNonParametric, length(vs) == size(ps,1))
@check_args(MvDiscreteNonParametric, all(isprobvec, eachcol(ps)))
@check_args(MvDiscreteNonParametric, allunique(vs))
sort_order = sortperm(vs)
new{T,P,Ts,Ps}(vs[sort_order], ps[sort_order,:])
end
end

MvDiscreteNonParametric(vs::Ts, ps::Ps; check_args=true) where {
T<:Real,P<:Real,Ts<:AbstractVector{T},Ps<:AbstractMatrix{P}} =
MvDiscreteNonParametric{T,P,Ts,Ps}(vs, ps, check_args=check_args)

Base.eltype(::Type{<:MvDiscreteNonParametric{T}}) where T = T

# Conversion
Base.convert(::Type{MvDiscreteNonParametric{T,P,Ts,Ps}}, d::MvDiscreteNonParametric) where {T,P,Ts,Ps} =
MvDiscreteNonParametric{T,P,Ts,Ps}(Ts(support(d)), Ps(probs(d)), check_args=false)

# Accessors
Distributions.params(d::MvDiscreteNonParametric) = (d.support, d.p)

"""
support(d::MvDiscreteNonParametric)

Get a sorted AbstractVector defining the support of `d`.
"""
Distributions.support(d::MvDiscreteNonParametric) = d.support

"""
probs(d::MvDiscreteNonParametric)

Get the vector of probabilities associated with the support of `d`.
"""
Distributions.probs(d::MvDiscreteNonParametric) = d.p

import Base: ==
==(c1::D, c2::D) where D<:MvDiscreteNonParametric =
(support(c1) == support(c2) || all(support(c1) .== support(c2))) &&
(probs(c1) == probs(c2) || all(probs(c1) .== probs(c2)))

Base.isapprox(c1::D, c2::D) where D<:MvDiscreteNonParametric =
(support(c1) ≈ support(c2) || all(support(c1) .≈ support(c2))) &&
(probs(c1) ≈ probs(c2) || all(probs(c1) .≈ probs(c2)))

# Sampling

function Base.rand(rng::AbstractRNG, d::MvDiscreteNonParametric{T,P}) where {T,P}
x = support(d)
p = probs(d)
n, k = size(p)
map(1:k) do j
draw = rand(rng, P)
cp = zero(P)
i = 0
while cp < draw && i < n
cp += p[i +=1, j]
end
x[max(i,1)]
end
end

Base.rand(d::MvDiscreteNonParametric) = rand(GLOBAL_RNG, d)

# Override the method in testutils.jl since it assumes
# an evenly-spaced integer support
Distributions.get_evalsamples(d::MvDiscreteNonParametric, ::Float64) = support(d)

# Evaluation

Distributions.pdf(d::MvDiscreteNonParametric) = copy(probs(d))

# Helper functions for pdf and cdf required to fix ambiguous method
# error involving [pc]df(::DisceteUnivariateDistribution, ::Int)
function _logpdf(d::MvDiscreteNonParametric{T,P}, x::AbstractVector{T}) where {T,P}
s = zero(P)
for col in 1:length(x)
idx_range = searchsorted(support(d), x[col])
if length(idx_range) > 0
s += log(probs(d)[first(idx_range),col])
end
end
return s
end
Distributions.logpdf(d::MvDiscreteNonParametric{T}, x::AbstractVector{<:Integer}) where T = _logpdf(d, convert(AbstractVector{T}, x))
Distributions.logpdf(d::MvDiscreteNonParametric{T}, x::AbstractVector{<:Real}) where T = _logpdf(d, convert(AbstractVector{T}, x))
Distributions.pdf(d::MvDiscreteNonParametric, x::AbstractVector{<:Real}) = exp(logpdf(d, x))

Base.minimum(d::MvDiscreteNonParametric) = first(support(d))
Base.maximum(d::MvDiscreteNonParametric) = last(support(d))
Distributions.insupport(d::MvDiscreteNonParametric, x::AbstractVector{<:Real}) =
all(x -> length(searchsorted(support(d), x)) > 0, x)

Distributions.mean(d::MvDiscreteNonParametric) = probs(d)' * support(d)

function Distributions.cov(d::MvDiscreteNonParametric)
m = mean(d)
x = support(d)
p = probs(d)
k = size(p,1)
n = size(p,2)
σ² = zero(m)
for j in 1:n
for i in 1:k
@inbounds σ²[j] += abs2(x[i,j] - m[j]) * p[i,j]
end
end
return Diagonal(σ²)
end

const MvCategorical{P,Ps} = MvDiscreteNonParametric{Int,P,Base.OneTo{Int},Ps}

MvCategorical(p::Ps; check_args=true) where {P<:Real, Ps<:AbstractMatrix{P}} =
MvCategorical{P,Ps}(p, check_args=check_args)

function MvCategorical{P,Ps}(p::Ps; check_args=true) where {P<:Real, Ps<:AbstractMatrix{P}}
check_args && @check_args(MvCategorical, all(isprobvec, eachcol(p)))
return MvCategorical{P,Ps}(Base.OneTo(size(p, 1)), p, check_args=check_args)
end

Distributions.ncategories(d::MvCategorical) = support(d).stop
Distributions.params(d::MvCategorical{P,Ps}) where {P<:Real, Ps<:AbstractVector{P}} = (probs(d),)
Distributions.partype(::MvCategorical{T}) where {T<:Real} = T
function Distributions.logpdf(d::MvCategorical{T}, x::AbstractVector{<:Integer}) where {T<:Real}
ps = probs(d)
if insupport(d, x)
_mv_categorical_logpdf(ps, x)
else
return zero(eltype(ps))
end
end
_mv_categorical_logpdf(ps, x) = sum(log, view(ps, x, :))
_mv_categorical_logpdf(ps::Tracker.TrackedMatrix, x) = Tracker.track(_mv_categorical_logpdf, ps, x)
Tracker.@grad function _mv_categorical_logpdf(ps, x)
ps_data = Tracker.data(ps)
probs = view(ps_data, x, :)
ps_grad = zero(ps_data)
sum(log, probs), Δ -> begin
ps_grad .= 0
ps_grad[x,:] .= Δ ./ probs
return (ps_grad, nothing)
end
end
18 changes: 17 additions & 1 deletion src/reversediff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ using ..DistributionsAD: DistributionsAD, _turing_chol
const TrackedVecOrMat{V,D} = Union{TrackedVector{V,D},TrackedMatrix{V,D}}

import SpecialFunctions, NaNMath
import ..DistributionsAD: turing_chol
import ..DistributionsAD: turing_chol, _mv_categorical_logpdf
import Base.Broadcast: materialize
import StatsFuns: logsumexp
import ZygoteRules
Expand Down Expand Up @@ -250,5 +250,21 @@ function isprobvec(p::TrackedArray{<:Real})
pdata = value(p)
all(x -> x ≥ zero(x), pdata) && isapprox(sum(pdata), one(eltype(pdata)), atol = 1e-6)
end
function isprobvec(p::SubArray{<:TrackedReal, 1, <:TrackedArray{<:Real}})
pdata = value(p)
all(x -> x ≥ zero(x), pdata) && isapprox(sum(pdata), one(eltype(pdata)), atol = 1e-6)
end

_mv_categorical_logpdf(ps::TrackedMatrix, x) = track(_mv_categorical_logpdf, ps, x)
@grad function _mv_categorical_logpdf(ps, x)
ps_data = value(ps)
probs = view(ps_data, x, :)
ps_grad = zero(ps_data)
sum(log, probs), Δ -> begin
ps_grad .= 0
ps_grad[x,:] .= Δ ./ probs
return (ps_grad, nothing)
end
end

end
Loading