Skip to content

Commit

Permalink
Simplify tests (#98)
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion authored Jun 5, 2021
1 parent a9ad75d commit 24ecf5d
Show file tree
Hide file tree
Showing 7 changed files with 29 additions and 42 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@ julia = "1"

[extras]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
HCubature = "19dc6840-f33b-545b-b366-655c7e3ffd49"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
PythonOT = "3c485715-4278-42b2-9b5f-8f00e43c12ef"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tulip = "6dd1b50a-3aae-11e9-10b5-ef983d2400fa"
HCubature = "19dc6840-f33b-545b-b366-655c7e3ffd49"

[targets]
test = ["ForwardDiff", "Pkg", "PythonOT", "Random", "SafeTestsets", "Test", "Tulip", "HCubature"]
16 changes: 6 additions & 10 deletions examples/basic/script.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ sinkhorn2(μ, ν, C, ε)
# ```
# One property of the quadratically regularised optimal transport problem is that the
# resulting transport plan $\gamma$ is *sparse*. We take advantage of this and represent it as
# a sparse matrix.
# a sparse matrix.

quadreg(μ, ν, C, ε; maxiter=500);

Expand Down Expand Up @@ -120,7 +120,7 @@ norm(γ - γ_pot, Inf)
γpot = POT.sinkhorn(μ, ν, C, ε; method="sinkhorn_epsilon_scaling", numItermax=5000)
norm- γpot, Inf)

# ## Unbalanced optimal transport
# ## Unbalanced optimal transport
#
# [Unbalanced optimal transport](https://doi.org/10.1090/mcom/3303) deals with general
# positive measures which do not necessarily have the same total mass. For unbalanced
Expand Down Expand Up @@ -166,10 +166,8 @@ norm(γ - γpot, Inf)

μsupport = νsupport = range(-2, 2; length=100)
C = pairwise(SqEuclidean(), μsupport', νsupport'; dims=2)
μ = exp.(-μsupport .^ 2 ./ 0.5^2)
μ ./= sum(μ)
ν = νsupport .^ 2 .* exp.(-νsupport .^ 2 ./ 0.5^2)
ν ./= sum(ν)
μ = normalize!(exp.(-μsupport .^ 2 ./ 0.5^2), 1)
ν = normalize!(νsupport .^ 2 .* exp.(-νsupport .^ 2 ./ 0.5^2), 1)

plot(μsupport, μ; label=raw"$\mu$", size=(600, 400))
plot!(νsupport, ν; label=raw"$\nu$")
Expand Down Expand Up @@ -216,10 +214,8 @@ heatmap(
# $\lambda_1 \in \{0.25, 0.5, 0.75\}$.

support = range(-1, 1; length=250)
mu1 = exp.(-(support .+ 0.5) .^ 2 ./ 0.1^2)
mu1 ./= sum(mu1)
mu2 = exp.(-(support .- 0.5) .^ 2 ./ 0.1^2)
mu2 ./= sum(mu2)
mu1 = normalize!(exp.(-(support .+ 0.5) .^ 2 ./ 0.1^2), 1)
mu2 = normalize!(exp.(-(support .- 0.5) .^ 2 ./ 0.1^2), 1)

plt = plot(; size=(800, 400), legend=:outertopright)
plot!(plt, support, mu1; label=raw"$\mu_1$")
Expand Down
7 changes: 3 additions & 4 deletions test/entropic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using ForwardDiff
using LogExpFunctions
using PythonOT: PythonOT

using LinearAlgebra
using Random
using Test

Expand Down Expand Up @@ -219,10 +220,8 @@ Random.seed!(100)
@testset "example" begin
# set up support
support = range(-1; stop=1, length=250)
μ1 = exp.(-(support .+ 0.5) .^ 2 ./ 0.1^2)
μ1 ./= sum(μ1)
μ2 = exp.(-(support .- 0.5) .^ 2 ./ 0.1^2)
μ2 ./= sum(μ2)
μ1 = normalize!(exp.(-(support .+ 0.5) .^ 2 ./ 0.1^2), 1)
μ2 = normalize!(exp.(-(support .- 0.5) .^ 2 ./ 0.1^2), 1)
μ_all = hcat(μ1, μ2)
# create cost matrix
C = pairwise(SqEuclidean(), support'; dims=2)
Expand Down
23 changes: 8 additions & 15 deletions test/exact.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,8 @@ Random.seed!(100)
@testset "Earth-Movers Distance" begin
M = 200
N = 250
μ = rand(M)
ν = rand(N)
μ ./= sum(μ)
ν ./= sum(ν)
μ = normalize!(rand(M), 1)
ν = normalize!(rand(N), 1)

@testset "example" begin
# create random cost matrix
Expand Down Expand Up @@ -87,8 +85,7 @@ Random.seed!(100)

@testset "semidiscrete case" begin
μ = Normal(randn(), rand())
νprobs = rand(30)
νprobs ./= sum(νprobs)
νprobs = normalize!(rand(30), 1)
ν = Categorical(νprobs)

# compute OT plan
Expand All @@ -113,14 +110,12 @@ Random.seed!(100)
@testset "discrete case" begin
# random source and target marginal
m = 30
μprobs = rand(m)
μprobs ./= sum(μprobs)
μprobs = normalize!(rand(m), 1)
μsupport = randn(m)
μ = DiscreteNonParametric(μsupport, μprobs)

n = 50
νprobs = rand(n)
νprobs ./= sum(νprobs)
νprobs = normalize!(rand(n), 1)
νsupport = randn(n)
ν = DiscreteNonParametric(νsupport, νprobs)

Expand Down Expand Up @@ -200,11 +195,9 @@ Random.seed!(100)
end

# Create discretized distribution
μprobs = pdf(μ, μsupp')
μprobs = μprobs ./ sum(μprobs)
νprobs = pdf(ν, νsupp')
νprobs = νprobs ./ sum(νprobs)
C = pairwise(SqEuclidean(), μsupp', νsupp')
μprobs = normalize!(pdf(μ, μsupp'), 1)
νprobs = normalize!(pdf(ν, νsupp'), 1)
C = pairwise(SqEuclidean(), μsupp', νsupp'; dims=2)
@test emd2(μprobs, νprobs, C, Tulip.Optimizer()) ot_cost(SqEuclidean(), μ, ν) rtol =
1e-3

Expand Down
1 change: 1 addition & 0 deletions test/gpu/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OptimalTransport = "7e02d93a-ae51-4f58-b602-d97af76e3b33"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

Expand Down
14 changes: 6 additions & 8 deletions test/gpu/simple_gpu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ using OptimalTransport
using CUDA
using Distances

using LinearAlgebra
using Random
using Test

Expand All @@ -18,14 +19,12 @@ Random.seed!(100)
@testset "sinkhorn" begin
# source histogram
m = 200
μ = rand(Float32, m)
μ ./= sum(μ)
μ = normalize!(rand(Float32, m), 1)
cu_μ = cu(μ)

# target histogram
n = 250
ν = rand(Float32, n)
ν ./= sum(ν)
ν = normalize!(rand(Float32, n), 1)
cu_ν = cu(ν)

# random cost matrix
Expand Down Expand Up @@ -71,13 +70,12 @@ Random.seed!(100)
@testset "sinkhorn_unbalanced" begin
# source histogram
m = 200
μ = rand(Float32, m)
μ ./= 1.5f0 * sum(μ)
μ = normalize!(rand(Float32, m), 1)
μ .*= 1.5f0

# target histogram
n = 250
ν = rand(Float32, n)
ν ./= sum(ν)
ν = normalize!(rand(Float32, n), 1)

# random cost matrix
C = pairwise(SqEuclidean(), randn(Float32, 1, m), randn(Float32, 1, n); dims=2)
Expand Down
8 changes: 4 additions & 4 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,10 @@ Random.seed!(100)
@testset "checkbalanced" begin
mass = rand()

x1 = rand(20)
x1 .*= mass / sum(x1)
y1 = rand(30)
y1 .*= mass / sum(y1)
x1 = normalize!(rand(20), 1)
x1 .*= mass
y1 = normalize!(rand(30), 1)
y1 .*= mass
@test OptimalTransport.checkbalanced(x1, y1) === nothing
@test OptimalTransport.checkbalanced(y1, x1) === nothing
@test_throws ArgumentError OptimalTransport.checkbalanced(rand() .* x1, y1)
Expand Down

0 comments on commit 24ecf5d

Please sign in to comment.