diff --git a/Project.toml b/Project.toml index a21c83a..d8894ac 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "PythonOT" uuid = "3c485715-4278-42b2-9b5f-8f00e43c12ef" authors = ["David Widmann"] -version = "0.1.5" +version = "0.1.6" [deps] PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" diff --git a/docs/src/api.md b/docs/src/api.md index 2730efe..00409b1 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -35,4 +35,5 @@ PythonOT.Smooth.smooth_ot_dual sinkhorn_unbalanced sinkhorn_unbalanced2 barycenter_unbalanced +mm_unbalanced ``` diff --git a/src/PythonOT.jl b/src/PythonOT.jl index 1f67513..d801f3d 100644 --- a/src/PythonOT.jl +++ b/src/PythonOT.jl @@ -12,7 +12,8 @@ export emd, barycenter_unbalanced, sinkhorn_unbalanced, sinkhorn_unbalanced2, - empirical_sinkhorn_divergence + empirical_sinkhorn_divergence, + mm_unbalanced const pot = PyCall.PyNULL() diff --git a/src/lib.jl b/src/lib.jl index c13db7a..07bb9af 100644 --- a/src/lib.jl +++ b/src/lib.jl @@ -312,11 +312,11 @@ julia> ν = [0.0, 1.0]; julia> C = [0.0 1.0; 2.0 0.0; 0.5 1.5]; -julia> sinkhorn_unbalanced(μ, ν, C, 0.01, 1_000) +julia> round.(sinkhorn_unbalanced(μ, ν, C, 0.01, 1_000); sigdigits=4) 3×2 Matrix{Float64}: - 0.0 0.499964 - 0.0 0.200188 - 0.0 0.29983 + 0.0 0.5 + 0.0 0.2002 + 0.0 0.2998 ``` It is possible to provide multiple target marginals as columns of a matrix. In this case the @@ -325,10 +325,10 @@ optimal transport costs are returned: ```jldoctest sinkhorn_unbalanced julia> ν = [0.0 0.5; 1.0 0.5]; -julia> round.(sinkhorn_unbalanced(μ, ν, C, 0.01, 1_000); sigdigits=6) +julia> round.(sinkhorn_unbalanced(μ, ν, C, 0.01, 1_000); sigdigits=4) 2-element Vector{Float64}: - 0.949709 - 0.449411 + 0.9497 + 0.4494 ``` See also: [`sinkhorn_unbalanced2`](@ref) @@ -371,9 +371,8 @@ julia> ν = [0.0, 1.0]; julia> C = [0.0 1.0; 2.0 0.0; 0.5 1.5]; -julia> round.(sinkhorn_unbalanced2(μ, ν, C, 0.01, 1_000); sigdigits=6) -1-element Vector{Float64}: - 0.949709 +julia> round.(sinkhorn_unbalanced2(μ, ν, C, 0.01, 1_000); sigdigits=4) +0.9497 ``` It is possible to provide multiple target marginals as columns of a matrix: @@ -381,10 +380,10 @@ It is possible to provide multiple target marginals as columns of a matrix: ```jldoctest sinkhorn_unbalanced2 julia> ν = [0.0 0.5; 1.0 0.5]; -julia> round.(sinkhorn_unbalanced2(μ, ν, C, 0.01, 1_000); sigdigits=6) +julia> round.(sinkhorn_unbalanced2(μ, ν, C, 0.01, 1_000); sigdigits=4) 2-element Vector{Float64}: - 0.949709 - 0.449411 + 0.9497 + 0.4494 ``` See also: [`sinkhorn_unbalanced`](@ref) @@ -516,3 +515,54 @@ Python function. function entropic_gromov_wasserstein(μ, ν, Cμ, Cν, ε, loss="square_loss"; kwargs...) return pot.gromov.entropic_gromov_wasserstein(Cμ, Cν, μ, ν, loss, ε; kwargs...) end + +""" + mm_unbalanced(a, b, M, reg_m; reg=0, c=a*b', kwargs...) + +Solve the unbalanced optimal transport problem and return the OT plan. +The function solves the following optimization problem: + +```math +W = \\min_{\\gamma \\geq 0} \\langle \\gamma, M \\rangle_F + + \\mathrm{reg_{m1}} \\cdot \\operatorname{div}(\\gamma \\mathbf{1}, a) + + \\mathrm{reg_{m2}} \\cdot \\operatorname{div}(\\gamma^\\mathsf{T} \\mathbf{1}, b) + + \\mathrm{reg} \\cdot \\operatorname{div}(\\gamma, c) +``` + +where + +- `M` is the metric cost matrix, +- `a` and `b` are source and target unbalanced distributions, +- `c` is a reference distribution for the regularization, +- `reg_m` is the marginal relaxation term (if it is a scalar or an indexable object of length 1, then the same term is applied to both marginal relaxations), and +- `reg` is a regularization term. + +This function is a wrapper of the function +[`mm_unbalanced`](https://pythonot.github.io/gen_modules/ot.unbalanced.html#ot.unbalanced.mm_unbalanced) in the +Python Optimal Transport package. Keyword arguments are listed in the documentation of the +Python function. + +# Examples + +```jldoctest +julia> a=[.5, .5]; + +julia> b=[.5, .5]; + +julia> M=[1. 36.; 9. 4.]; + +julia> round.(mm_unbalanced(a, b, M, 5, div="kl"), digits=2) +2×2 Matrix{Float64}: + 0.45 0.0 + 0.0 0.34 + +julia> round.(mm_unbalanced(a, b, M, 5, div="l2"), digits=2) +2×2 Matrix{Float64}: + 0.4 0.0 + 0.0 0.1 +``` + +""" +function mm_unbalanced(a, b, M, reg_m; kwargs...) + return pot.unbalanced.mm_unbalanced(a, b, M, reg_m; kwargs...) +end