diff --git a/src/functions/indSimplex.jl b/src/functions/indSimplex.jl index 451423d..88400bb 100644 --- a/src/functions/indSimplex.jl +++ b/src/functions/indSimplex.jl @@ -1,6 +1,6 @@ # indicator of a simplex -export IndSimplex +export IndSimplex, IndUnitSimplex """ IndSimplex(a=1.0) @@ -107,3 +107,49 @@ function prox_naive(f::IndSimplex, x, gamma) end return v, R(0) end + +""" + IndUnitSimplex(a=1.0) + +Return the indicator of the unit simplex +```math +S = \\left\\{ x : x \\geq 0, \\sum_i x_i \\leq a \\right\\}. +``` + +By default `a=1`, therefore ``S`` is the probability simplex of dimension n+1. +""" +struct IndUnitSimplex{R} + a::R + function IndUnitSimplex{R}(a::R) where R + if a <= 0 + error("parameter a must be positive") + else + new(a) + end + end +end + +is_convex(f::Type{<:IndUnitSimplex}) = true +is_set(f::Type{<:IndUnitSimplex}) = true + +IndUnitSimplex(a::R=1) where R = IndUnitSimplex{R}(a) + +function (f::IndUnitSimplex)(x) + R = eltype(x) + if all(x .>= 0) && sum(x) <= f.a + eps(f.a) + return R(0) + end + return R(Inf) +end + +function prox!(y, f::IndUnitSimplex{R}, x, gamma) where {R} + fx = zero(R) + for i in eachindex(x) + y[i] = max(x[i], zero(R)) + fx += y[i] + end + if fx > f.a + simplex_proj_condat!(y, f.a, x) + end + return eltype(x)(0) +end diff --git a/test/test_calls.jl b/test/test_calls.jl index ee24121..c952fee 100644 --- a/test/test_calls.jl +++ b/test/test_calls.jl @@ -31,6 +31,8 @@ test_cases_spec = [ "right" => [ ( (IndSimplex(),), randn(Float32, 10) ), ( (IndSimplex(),), randn(Float64, 10) ), + ( (IndUnitSimplex(),), randn(Float32, 10) ), + ( (IndUnitSimplex(),), randn(Float64, 10) ), ( (IndNonnegative(), rand()), randn(Float64, 10) ), ( (IndZero(),), randn(Float64, 10) ), ( (IndBox(-1, 1),), randn(Float32, 10) ), diff --git a/test/test_equivalences.jl b/test/test_equivalences.jl index 2aa79c6..f2ca656 100644 --- a/test/test_equivalences.jl +++ b/test/test_equivalences.jl @@ -22,12 +22,22 @@ for i = 1:N r = 5*rand() f = IndSimplex(r) g = IndBallL1(r) + h = IndUnitSimplex(r) y1, fy1 = prox(f, abs.(x)) - y1 = sign.(x).*y1 + y1_l1ball = sign.(x).*y1 y2, gy2 = prox(g, x) + y3, hy3 = prox(h, abs.(x)) - @test y1 ≈ y2 + @test y1_l1ball ≈ y2 + @test y1 ≈ y3 + + x2 = abs.(x) * 0.5 * r + x2 ./= sum(x2) + + y_probsimplex = prox(f, x2) + y_unit_simplex = prox(h, x2) + @test norm(y_probsimplex) >= norm(y_unit_simplex) end # projecting onto the simplex