Skip to content

Commit

Permalink
Merge pull request #70 from TuringLang/mt/neurips2
Browse files Browse the repository at this point in the history
Few AD fixes and improvements
  • Loading branch information
mohamed82008 authored May 1, 2020
2 parents e605971 + 434ac03 commit 340eb9f
Show file tree
Hide file tree
Showing 8 changed files with 270 additions and 22 deletions.
1 change: 1 addition & 0 deletions src/DistributionsAD.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ export TuringScalMvNormal,
include("common.jl")
include("univariate.jl")
include("multivariate.jl")
include("mvcategorical.jl")
include("matrixvariate.jl")
include("flatten.jl")
include("arraydist.jl")
Expand Down
16 changes: 8 additions & 8 deletions src/arraydist.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
# Utils

# Apply `f` to each (distribution, value) pair taken from `dists` and `x`, then sum
# the results into a single value.
# NOTE(review): this text is a rendered diff with its +/- markers stripped; the two
# consecutive signatures below look like the pre-rename (`maporbroadcast`) and
# post-rename (`summaporbroadcast`) versions of the same line — confirm against the repo.
function maporbroadcast(f, dists::AbstractArray, x::AbstractArray)
function summaporbroadcast(f, dists::AbstractArray, x::AbstractArray)
# Broadcasting here breaks Tracker for some reason
# (which is why `map` is used before `sum` instead of `f.(dists, x)`).
return sum(map(f, dists, x))
end
# Matrix input: evaluate the vector method once per column of `x` and return one
# summed value per column.
# NOTE(review): old (`maporbroadcast`) and new (`summaporbroadcast`) diff lines appear
# interleaved here because the +/- markers were lost — confirm against the repo.
function maporbroadcast(f, dists::AbstractVector, x::AbstractMatrix)
return map(x -> maporbroadcast(f, dists, x), eachcol(x))
function summaporbroadcast(f, dists::AbstractVector, x::AbstractMatrix)
return map(x -> summaporbroadcast(f, dists, x), eachcol(x))
end
@init @require LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" begin
function maporbroadcast(f, dists::LazyArrays.BroadcastArray, x::AbstractArray)
function summaporbroadcast(f, dists::LazyArrays.BroadcastArray, x::AbstractArray)
return sum(copy(f.(dists, x)))
end
function maporbroadcast(f, dists::LazyArrays.BroadcastVector, x::AbstractMatrix)
function summaporbroadcast(f, dists::LazyArrays.BroadcastVector, x::AbstractMatrix)
return vec(sum(copy(f.(dists, x)), dims = 1))
end
lazyarray(f, x...) = LazyArrays.LazyArray(Base.broadcasted(f, x...))
Expand All @@ -27,11 +27,11 @@ function arraydist(dists::AbstractVector{<:UnivariateDistribution})
end

# Log-density of a vector observation: the sum of the univariate `logpdf`s of the
# distributions in `dist.v` evaluated at the matching entries of `x`.
# NOTE(review): the two `return` lines look like the removed and added lines of the
# diff (rename `maporbroadcast` -> `summaporbroadcast`) — confirm against the repo.
function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractVector{<:Real})
return maporbroadcast(logpdf, dist.v, x)
return summaporbroadcast(logpdf, dist.v, x)
end
# Log-density of a matrix of observations; delegates to the
# (vector of distributions, matrix) helper method with `dist.v` and `x`.
# NOTE(review): the two `return` lines look like the removed and added lines of the
# diff (rename `maporbroadcast` -> `summaporbroadcast`) — confirm against the repo.
function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractMatrix{<:Real})
# eachcol breaks Zygote, so we need an adjoint
return maporbroadcast(logpdf, dist.v, x)
return summaporbroadcast(logpdf, dist.v, x)
end
ZygoteRules.@adjoint function Distributions.logpdf(
dist::VectorOfUnivariate,
Expand All @@ -54,7 +54,7 @@ function arraydist(dists::AbstractMatrix{<:UnivariateDistribution})
return MatrixOfUnivariate(dists)
end
# Log-density of a matrix observation under a matrix of univariate distributions:
# the sum of the element-wise `logpdf`s of `dist.dists` at `x`.
# NOTE(review): the two `return` lines look like the removed and added lines of the
# diff (rename `maporbroadcast` -> `summaporbroadcast`) — confirm against the repo.
function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractMatrix{<:Real})
return maporbroadcast(logpdf, dist.dists, x)
return summaporbroadcast(logpdf, dist.dists, x)
end
function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractArray{<:AbstractMatrix{<:Real}})
return map(x -> logpdf(dist, x), x)
Expand Down
21 changes: 21 additions & 0 deletions src/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -180,3 +180,24 @@ function Distributions.isprobvec(p::TrackedArray{<:Real})
pdata = Tracker.data(p)
all(x -> x ≥ zero(x), pdata) && isapprox(sum(pdata), one(eltype(pdata)), atol = 1e-6)
end

# Some array functions - workaround https://github.com/FluxML/Tracker.jl/issues/4

import Base: +, -, *, /, \
import LinearAlgebra: dot

# For every listed operator and every (tracked array type, plain array type) pair,
# define the two mixed-argument methods. In each method the tracked array is converted
# to an `AbstractArray{TrackedReal{T}}` before the operator is re-applied.
for op in (:+, :-, :*, :/, :\, :dot)
    for (Tracked, Plain) in (
        (:TrackedArray, :AbstractArray),
        (:TrackedMatrix, :AbstractMatrix),
        (:TrackedMatrix, :AbstractVector),
        (:TrackedVector, :AbstractMatrix),
    )
        # Tracked array on the left, array of `TrackedReal`s on the right.
        @eval function $op(a::$Tracked{T}, b::$Plain{<:TrackedReal}) where {T <: Real}
            return $op(convert(AbstractArray{TrackedReal{T}}, a), b)
        end
        # Array of `TrackedReal`s on the left, tracked array on the right.
        @eval function $op(a::$Plain{<:TrackedReal}, b::$Tracked{T}) where {T <: Real}
            return $op(a, convert(AbstractArray{TrackedReal{T}}, b))
        end
    end
end
18 changes: 15 additions & 3 deletions src/multivariate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,24 @@ struct TuringDirichlet{T, TV <: AbstractVector} <: ContinuousMultivariateDistrib
alpha0::T
lmnB::T
end
# The dimension of a `TuringDirichlet` is the length of its `alpha` field.
function Base.length(d::TuringDirichlet)
    return length(d.alpha)
end
# Validate Dirichlet concentration parameters.
# Returns `true` when every entry of `alpha` is strictly positive and throws an
# `ArgumentError` otherwise.
function check(alpha)
    if !all(ai -> ai > 0, alpha)
        throw(ArgumentError("Dirichlet: alpha must be a positive vector."))
    end
    return true
end
# Custom Zygote adjoint for `check`: the forward value is `check(alpha)` and the
# pullback ignores its input and contributes no gradient.
# NOTE(review): the two consecutive `return` lines look like the removed line
# (pullback returning `nothing`) and the added line (pullback returning the 1-tuple
# `(nothing,)`) of a diff whose +/- markers were lost — confirm against the repo.
ZygoteRules.@adjoint function check(alpha)
return check(alpha), _ -> nothing
return check(alpha), _ -> (nothing,)
end
# In-place sampler: fill `x` with one draw from `d`.
# Each entry is first drawn from `Gamma(d.alpha[i])`; the whole vector is then scaled
# by the inverse of the sum of the draws.
function Distributions._rand!(rng::Random.AbstractRNG,
                              d::TuringDirichlet,
                              x::AbstractVector{<:Real})
    shapes = d.alpha
    total = 0.0
    for i in 1:length(x)
        @inbounds begin
            draw = rand(rng, Gamma(shapes[i]))
            x[i] = draw
        end
        # Accumulate the raw draw (not the value stored in `x`).
        total += draw
    end
    return Distributions.multiply!(x, inv(total)) # this returns x
end

function TuringDirichlet(alpha::AbstractVector)
Expand All @@ -31,11 +43,11 @@ function TuringDirichlet(d::Integer, alpha::Real)
TuringDirichlet{T, TV}(_alpha, alpha0, lmnB)
end
# Integer-valued `alpha`: convert the entries to floating point and forward to the
# generic `TuringDirichlet(alpha)` constructor.
# NOTE(review): the `Tf`/`convert` lines and the `float.(alpha)` line look like the
# removed and added versions of the same statement in a diff whose +/- markers were
# lost — confirm against the repo.
function TuringDirichlet(alpha::AbstractVector{T}) where {T <: Integer}
Tf = float(T)
TuringDirichlet(convert(AbstractVector{Tf}, alpha))
TuringDirichlet(float.(alpha))
end
# Integer `alpha`: convert it to `Float64` and forward to the two-argument constructor.
function TuringDirichlet(d::Integer, alpha::Integer)
    return TuringDirichlet(d, Float64(alpha))
end

# Route `Dirichlet` construction to `TuringDirichlet` for the argument types below.
function Distributions.Dirichlet(alpha::AbstractVector)
    return TuringDirichlet(alpha)
end
function Distributions.Dirichlet(alpha::TrackedVector)
    return TuringDirichlet(alpha)
end
function Distributions.Dirichlet(d::Integer, alpha::TrackedReal)
    return TuringDirichlet(d, alpha)
end

Expand Down
170 changes: 170 additions & 0 deletions src/mvcategorical.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
using Distributions: @check_args

"""
MvDiscreteNonParametric(xs, ps)
A *multivariate Discrete nonparametric distribution* explicitly defines an arbitrary
probability mass function in terms of a list of real support values and their
corresponding probabilities
```julia
d = MvDiscreteNonParametric(xs, ps)
params(d) # Get the parameters, i.e. (xs, ps)
support(d) # Get a sorted AbstractVector describing the support (xs) of the distribution
probs(d) # Get a Matrix of the probabilities (ps) associated with the support
```
External links
* [Probability mass function on Wikipedia](http://en.wikipedia.org/wiki/Probability_mass_function)
"""
struct MvDiscreteNonParametric{T<:Real,P<:Real,Ts<:AbstractVector{T},Ps<:AbstractMatrix{P}} <: DiscreteMultivariateDistribution
    # Support values; stored in sorted order when the arguments are checked.
    support::Ts
    # Probability matrix: one row per support value, each column a probability vector.
    p::Ps

    # Inner constructor. With `check_args=false` the inputs are stored untouched;
    # otherwise they are validated and rows are permuted so the support is sorted.
    function MvDiscreteNonParametric{T,P,Ts,Ps}(
        vs::Ts, ps::Ps; check_args=true
    ) where {T<:Real,P<:Real,Ts<:AbstractVector{T},Ps<:AbstractMatrix{P}}
        if check_args
            @check_args(MvDiscreteNonParametric, length(vs) == size(ps,1))
            @check_args(MvDiscreteNonParametric, all(isprobvec, eachcol(ps)))
            @check_args(MvDiscreteNonParametric, allunique(vs))
            order = sortperm(vs)
            return new{T,P,Ts,Ps}(vs[order], ps[order,:])
        end
        return new{T,P,Ts,Ps}(vs, ps)
    end
end

# Outer constructor: infer all four type parameters from the arguments and forward
# to the fully-parameterised inner constructor.
function MvDiscreteNonParametric(
    vs::Ts, ps::Ps; check_args=true
) where {T<:Real,P<:Real,Ts<:AbstractVector{T},Ps<:AbstractMatrix{P}}
    return MvDiscreteNonParametric{T,P,Ts,Ps}(vs, ps; check_args=check_args)
end

# The element type of the distribution is the element type `T` of its support.
function Base.eltype(::Type{<:MvDiscreteNonParametric{T}}) where {T}
    return T
end

# Conversion
# Rebuild `d` with the requested container types; argument checks are skipped.
function Base.convert(
    ::Type{MvDiscreteNonParametric{T,P,Ts,Ps}}, d::MvDiscreteNonParametric
) where {T,P,Ts,Ps}
    vs = Ts(support(d))
    ps = Ps(probs(d))
    return MvDiscreteNonParametric{T,P,Ts,Ps}(vs, ps; check_args=false)
end

# Accessors
# The parameters are the pair (support values, probability matrix).
function Distributions.params(d::MvDiscreteNonParametric)
    return (d.support, d.p)
end

"""
support(d::MvDiscreteNonParametric)
Get a sorted AbstractVector defining the support of `d`.
"""
function Distributions.support(d::MvDiscreteNonParametric)
    # Plain field accessor; no copy is made.
    return d.support
end

"""
probs(d::MvDiscreteNonParametric)
Get the matrix of probabilities associated with the support of `d` (one row per support value).
"""
function Distributions.probs(d::MvDiscreteNonParametric)
    # Plain field accessor; returns the stored probability matrix without copying.
    return d.p
end

import Base: ==
# Equality of two distributions of the same concrete type: both the supports and the
# probability matrices must be equal.
#
# Each comparison first tries array `==` and then falls back to an element-wise check.
# The fallback is only attempted when the shapes agree: without that guard `.==` would
# throw a `DimensionMismatch` for incompatible shapes, or silently broadcast a
# singleton dimension and report arrays of different sizes as equal.
function ==(c1::D, c2::D) where {D<:MvDiscreteNonParametric}
    xs1, xs2 = support(c1), support(c2)
    (xs1 == xs2 || (size(xs1) == size(xs2) && all(xs1 .== xs2))) || return false
    ps1, ps2 = probs(c1), probs(c2)
    return ps1 == ps2 || (size(ps1) == size(ps2) && all(ps1 .== ps2))
end

# Approximate equality of two distributions of the same concrete type: both the
# supports and the probability matrices must be approximately equal, either as whole
# arrays (`≈`) or element by element (`.≈`).
# The infix `≈` operators are written out explicitly; without them the expression
# `support(c1) support(c2)` does not parse.
function Base.isapprox(c1::D, c2::D) where {D<:MvDiscreteNonParametric}
    xs1, xs2 = support(c1), support(c2)
    (xs1 ≈ xs2 || all(xs1 .≈ xs2)) || return false
    ps1, ps2 = probs(c1), probs(c2)
    return ps1 ≈ ps2 || all(ps1 .≈ ps2)
end

# Sampling

# Draw one sample: for every column of the probability matrix, draw a random `P` and
# walk down the column accumulating probabilities until the running total reaches the
# draw (or the last row is hit). Returns one support value per column.
function Base.rand(rng::AbstractRNG, d::MvDiscreteNonParametric{T,P}) where {T,P}
    values = support(d)
    weights = probs(d)
    nrows, ncols = size(weights)
    return map(1:ncols) do col
        u = rand(rng, P)
        acc = zero(P)
        row = 0
        while acc < u && row < nrows
            row += 1
            acc += weights[row, col]
        end
        # `row` stays 0 when the loop body never ran; fall back to the first value.
        values[max(row, 1)]
    end
end

# Convenience method: sample using `GLOBAL_RNG`.
function Base.rand(d::MvDiscreteNonParametric)
    return rand(GLOBAL_RNG, d)
end

# Override the method in testutils.jl since it assumes
# an evenly-spaced integer support; the evaluation points used here are simply the
# support values themselves (the `Float64` argument is ignored).
function Distributions.get_evalsamples(d::MvDiscreteNonParametric, ::Float64)
    return support(d)
end

# Evaluation

# Without an evaluation point, return a copy of the whole probability matrix.
function Distributions.pdf(d::MvDiscreteNonParametric)
    return copy(probs(d))
end

# Helper functions for pdf and cdf required to fix ambiguous method
# error involving [pc]df(::DiscreteUnivariateDistribution, ::Int)
#
# Joint log-probability of the observation `x`: entry `x[col]` is looked up in the
# sorted support and the log of the matching probability in column `col` is added.
# If any entry of `x` is not a support value, the observation has probability zero and
# the result is `log(zero(P))` (i.e. -Inf) instead of silently skipping that entry.
function _logpdf(d::MvDiscreteNonParametric{T,P}, x::AbstractVector{T}) where {T,P}
    xs = support(d)
    ps = probs(d)
    s = zero(P)
    for (col, xi) in enumerate(x)
        idx_range = searchsorted(xs, xi)
        # Empty range <=> `xi` does not occur in the support.
        isempty(idx_range) && return log(zero(P))
        s += log(ps[first(idx_range), col])
    end
    return s
end
# Public `logpdf`/`pdf` entry points. Both `logpdf` methods convert `x` to the
# support's element type and delegate to `_logpdf`; `pdf` exponentiates `logpdf`.
function Distributions.logpdf(
    d::MvDiscreteNonParametric{T}, x::AbstractVector{<:Integer}
) where {T}
    return _logpdf(d, convert(AbstractVector{T}, x))
end
function Distributions.logpdf(
    d::MvDiscreteNonParametric{T}, x::AbstractVector{<:Real}
) where {T}
    return _logpdf(d, convert(AbstractVector{T}, x))
end
function Distributions.pdf(d::MvDiscreteNonParametric, x::AbstractVector{<:Real})
    return exp(logpdf(d, x))
end

# Smallest / largest support value: the first / last element of the support vector.
function Base.minimum(d::MvDiscreteNonParametric)
    return first(support(d))
end
function Base.maximum(d::MvDiscreteNonParametric)
    return last(support(d))
end
# `x` is in the support when every one of its entries is found in the support vector.
function Distributions.insupport(d::MvDiscreteNonParametric, x::AbstractVector{<:Real})
    xs = support(d)
    return all(xi -> !isempty(searchsorted(xs, xi)), x)
end

# Mean: the adjoint of the probability matrix times the support vector, i.e. one
# probability-weighted sum of the support values per column.
function Distributions.mean(d::MvDiscreteNonParametric)
    return adjoint(probs(d)) * support(d)
end

# Covariance of the distribution, returned as a `Diagonal` matrix whose `j`-th entry
# is the variance of column `j`: the probability-weighted squared deviation of the
# support values from that column's mean.
function Distributions.cov(d::MvDiscreteNonParametric)
    m = mean(d)
    x = support(d)
    p = probs(d)
    σ² = zero(m)
    for j in axes(p, 2), i in axes(p, 1)
        # The support is a vector shared by all columns, so it is indexed by row only;
        # `x[i, j]` would read out of bounds for every `j > 1`.
        σ²[j] += abs2(x[i] - m[j]) * p[i, j]
    end
    return Diagonal(σ²)
end

# `MvCategorical{P,Ps}` is an `MvDiscreteNonParametric` whose support is an integer
# range of type `Base.OneTo{Int}` and whose probabilities have element type `P` and
# container type `Ps`.
const MvCategorical{P,Ps} = MvDiscreteNonParametric{Int,P,Base.OneTo{Int},Ps}

# Build an `MvCategorical` from a probability matrix, inferring `P` and `Ps` from it.
function MvCategorical(p::Ps; check_args=true) where {P<:Real, Ps<:AbstractMatrix{P}}
    return MvCategorical{P,Ps}(p; check_args=check_args)
end

# Fully-parameterised constructor: optionally check that every column of `p` is a
# probability vector, then use `1:size(p, 1)` (as a `Base.OneTo`) as the support.
function MvCategorical{P,Ps}(p::Ps; check_args=true) where {P<:Real, Ps<:AbstractMatrix{P}}
    if check_args
        @check_args(MvCategorical, all(isprobvec, eachcol(p)))
    end
    categories = Base.OneTo(size(p, 1))
    return MvCategorical{P,Ps}(categories, p; check_args=check_args)
end

# Number of categories: the upper end of the `Base.OneTo` support.
Distributions.ncategories(d::MvCategorical) = support(d).stop
# Parameters of an `MvCategorical` are just the probability matrix.
# `Ps` is bounded by `AbstractMatrix{P}` to match the type's own parameter bound; with
# an `AbstractVector{P}` bound no `MvCategorical` could ever match this method, so the
# generic `(support, p)` method was used instead.
Distributions.params(d::MvCategorical{P,Ps}) where {P<:Real, Ps<:AbstractMatrix{P}} = (probs(d),)
# Parameter type: the element type of the probability matrix.
Distributions.partype(::MvCategorical{T}) where {T<:Real} = T
# Log-probability of the integer observation `x` under an `MvCategorical`.
# An observation with any entry outside the support has probability zero, so its
# log-probability is `log(0)` (-Inf in the element type of the probabilities), not 0.
function Distributions.logpdf(d::MvCategorical{T}, x::AbstractVector{<:Integer}) where {T<:Real}
    ps = probs(d)
    insupport(d, x) || return log(zero(eltype(ps)))
    return _mv_categorical_logpdf(ps, x)
end
# Sum of `log(ps[x[j], j])` over the columns `j` of `ps`: entry `j` of `x` selects the
# category (row) observed in column `j`.
# The entries are addressed through linear indices; `view(ps, x, :)` would instead
# select every row in `x` for every column (a `length(x) × size(ps, 2)` block).
function _mv_categorical_logpdf(ps, x)
    nrows = size(ps, 1)
    idx = [xj + (j - 1) * nrows for (j, xj) in enumerate(x)]
    return sum(log, view(ps, idx))
end
# Tracker rule for `_mv_categorical_logpdf`.
# Forward value: sum of `log(ps[x[j], j])` over the columns `j`.
# Gradient w.r.t. `ps`: `Δ / ps[x[j], j]` at each selected entry and zero elsewhere;
# `x` gets no gradient.
# The selected entries are addressed by linear index so that exactly one entry per
# column is used; indexing with `[x, :]` would pick a whole block of rows × columns.
_mv_categorical_logpdf(ps::Tracker.TrackedMatrix, x) = Tracker.track(_mv_categorical_logpdf, ps, x)
Tracker.@grad function _mv_categorical_logpdf(ps, x)
    ps_data = Tracker.data(ps)
    nrows = size(ps_data, 1)
    idx = [xj + (j - 1) * nrows for (j, xj) in enumerate(x)]
    selected = ps_data[idx]
    ps_grad = zero(ps_data)
    return sum(log, selected), Δ -> begin
        ps_grad .= 0
        ps_grad[idx] .= Δ ./ selected
        return (ps_grad, nothing)
    end
end
18 changes: 17 additions & 1 deletion src/reversediff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ using ..DistributionsAD: DistributionsAD, _turing_chol
const TrackedVecOrMat{V,D} = Union{TrackedVector{V,D},TrackedMatrix{V,D}}

import SpecialFunctions, NaNMath
import ..DistributionsAD: turing_chol
import ..DistributionsAD: turing_chol, _mv_categorical_logpdf
import Base.Broadcast: materialize
import StatsFuns: logsumexp
import ZygoteRules
Expand Down Expand Up @@ -250,5 +250,21 @@ function isprobvec(p::TrackedArray{<:Real})
pdata = value(p)
all(x -> x ≥ zero(x), pdata) && isapprox(sum(pdata), one(eltype(pdata)), atol = 1e-6)
end
# `isprobvec` for a one-dimensional view into a tracked array: the check is done on
# the underlying values. All entries must be non-negative and sum to one within an
# absolute tolerance of 1e-6.
# The comparison operator `≥` is written out explicitly; `x zero(x)` does not parse.
function isprobvec(p::SubArray{<:TrackedReal, 1, <:TrackedArray{<:Real}})
    pdata = value(p)
    return all(x -> x ≥ zero(x), pdata) &&
        isapprox(sum(pdata), one(eltype(pdata)), atol = 1e-6)
end

# ReverseDiff rule for `_mv_categorical_logpdf`.
# Forward value: sum of `log(ps[x[j], j])` over the columns `j`.
# Gradient w.r.t. `ps`: `Δ / ps[x[j], j]` at each selected entry and zero elsewhere;
# `x` gets no gradient.
# The selected entries are addressed by linear index so that exactly one entry per
# column is used; indexing with `[x, :]` would pick a whole block of rows × columns.
_mv_categorical_logpdf(ps::TrackedMatrix, x) = track(_mv_categorical_logpdf, ps, x)
@grad function _mv_categorical_logpdf(ps, x)
    ps_data = value(ps)
    nrows = size(ps_data, 1)
    idx = [xj + (j - 1) * nrows for (j, xj) in enumerate(x)]
    selected = ps_data[idx]
    ps_grad = zero(ps_data)
    return sum(log, selected), Δ -> begin
        ps_grad .= 0
        ps_grad[idx] .= Δ ./ selected
        return (ps_grad, nothing)
    end
end

end
Loading

0 comments on commit 340eb9f

Please sign in to comment.