Skip to content

Commit

Permalink
Finite Discrete Measure (#95)
Browse files Browse the repository at this point in the history
Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
  • Loading branch information
davibarreira and devmotion authored Jun 7, 2021
1 parent 3674ceb commit 93c6077
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 0 deletions.
52 changes: 52 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,55 @@ function checkbalanced(x::AbstractVecOrMat, y::AbstractVecOrMat)
throw(ArgumentError("source and target marginals are not balanced"))
return nothing
end

struct FiniteDiscreteMeasure{X<:AbstractVector,P<:AbstractVector}
support::X
p::P

function FiniteDiscreteMeasure{X,P}(support::X, p::P) where {X,P}
length(support) == length(p) || error("length of `support` and `p` must be equal")
isprobvec(p) || error("`p` must be a probability vector")
return new{X,P}(support, p)
end
end

"""
discretemeasure(
support::AbstractVector,
probs::AbstractVector{<:Real}=fill(inv(length(support)), length(support))
)
Construct a finite discrete probability measure with `support` and corresponding
`probabilities`. If the probability vector argument is not passed, then
equal probability is assigned to each entry in the support.
# Examples
```julia
using KernelFunctions
# rows correspond to samples
μ = discretemeasure(RowVecs(rand(7,3)), normalize!(rand(10),1))
# columns correspond to samples, each with equal probability
ν = discretemeasure(ColVecs(rand(3,12)))
```
!!! note
If `support` is a 1D vector, the constructed measure will be sorted,
e.g. for `mu = discretemeasure([3, 1, 2],[0.5, 0.2, 0.3])`, then
`mu.support` will be `[1, 2, 3]` and `mu.p` will be `[0.2, 0.3, 0.5]`.
"""
function discretemeasure(
support::AbstractVector{<:Real},
probs::AbstractVector{<:Real}=fill(inv(length(support)), length(support)),
)
return DiscreteNonParametric(support, probs)
end
function discretemeasure(
support::AbstractVector,
probs::AbstractVector{<:Real}=fill(inv(length(support)), length(support)),
)
return FiniteDiscreteMeasure{typeof(support),typeof(probs)}(support, probs)
end

Distributions.support(d::FiniteDiscreteMeasure) = d.support
Distributions.probs(d::FiniteDiscreteMeasure) = d.p
49 changes: 49 additions & 0 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ using OptimalTransport
using LinearAlgebra
using Random
using Test
using Distributions

Random.seed!(100)

Expand Down Expand Up @@ -95,4 +96,52 @@ Random.seed!(100)
x2, y2 .* hcat(rand(), ones(1, size(y2, 2) - 1))
)
end

@testset "FiniteDiscreteMeasure" begin
@testset "Univariate Finite Discrete Measure" begin
n = 100
m = 80
μsupp = rand(n)
νsupp = rand(m)
μprobs = normalize!(rand(n), 1)

μ = OptimalTransport.discretemeasure(μsupp, μprobs)
ν = OptimalTransport.discretemeasure(νsupp)
# check if it vectors are indeed probabilities
@test isprobvec.p)
@test isprobvec(probs(μ))
@test ν.p == ones(m) ./ m
@test probs(ν) == ones(m) ./ m

# check if it assigns to DiscreteNonParametric when Vector/Matrix is 1D
@test μ isa DiscreteNonParametric
@test ν isa DiscreteNonParametric

# check if support is correctly assinged
@test sort(μsupp) == μ.support
@test sort(μsupp) == support(μ)
@test sort(vec(νsupp)) == ν.support
@test sort(vec(νsupp)) == support(ν)
end
@testset "Multivariate Finite Discrete Measure" begin
n = 10
m = 3
μsupp = [rand(m) for i in 1:n]
νsupp = [rand(m) for i in 1:n]
μprobs = normalize!(rand(n), 1)
μ = OptimalTransport.discretemeasure(μsupp, μprobs)
ν = OptimalTransport.discretemeasure(νsupp)
# check if it vectors are indeed probabilities
@test isprobvec.p)
@test isprobvec(probs(μ))
@test ν.p == ones(n) ./ n
@test probs(ν) == ones(n) ./ n

# check if support is correctly assinged
@test μsupp == μ.support
@test μsupp == support(μ)
@test νsupp == ν.support
@test νsupp == support(ν)
end
end
end

0 comments on commit 93c6077

Please sign in to comment.