diff --git a/src/functions/indSimplex.jl b/src/functions/indSimplex.jl
index 451423d..88400bb 100644
--- a/src/functions/indSimplex.jl
+++ b/src/functions/indSimplex.jl
@@ -1,6 +1,6 @@
 # indicator of a simplex
 
-export IndSimplex
+export IndSimplex, IndUnitSimplex
 
 """
     IndSimplex(a=1.0)
@@ -107,3 +107,49 @@ function prox_naive(f::IndSimplex, x, gamma)
     end
     return v, R(0)
 end
+
+"""
+    IndUnitSimplex(a=1.0)
+
+Return the indicator of the unit simplex
+```math
+S = \\left\\{ x : x \\geq 0, \\sum_i x_i \\leq a \\right\\}.
+```
+
+By default `a=1`, therefore ``S`` is the probability simplex of dimension n+1.
+"""
+struct IndUnitSimplex{R}
+    a::R
+    function IndUnitSimplex{R}(a::R) where R
+        if a <= 0
+            error("parameter a must be positive")
+        else
+            new(a)
+        end
+    end
+end
+
+is_convex(f::Type{<:IndUnitSimplex}) = true
+is_set(f::Type{<:IndUnitSimplex}) = true
+
+IndUnitSimplex(a::R=1) where R = IndUnitSimplex{R}(a)
+
+function (f::IndUnitSimplex)(x)
+    R = eltype(x)
+    if all(x .>= 0) && sum(x) <= f.a + eps(f.a)
+        return R(0)
+    end
+    return R(Inf)
+end
+
+function prox!(y, f::IndUnitSimplex{R}, x, gamma) where {R}
+    fx = zero(R)
+    for i in eachindex(x)
+        y[i] = max(x[i], zero(R))
+        fx += y[i]
+    end
+    if fx > f.a
+        simplex_proj_condat!(y, f.a, x)
+    end
+    return eltype(x)(0)
+end
diff --git a/test/test_calls.jl b/test/test_calls.jl
index ee24121..c952fee 100644
--- a/test/test_calls.jl
+++ b/test/test_calls.jl
@@ -31,6 +31,8 @@ test_cases_spec = [
         "right" => [
             ( (IndSimplex(),), randn(Float32, 10) ),
             ( (IndSimplex(),), randn(Float64, 10) ),
+            ( (IndUnitSimplex(),), randn(Float32, 10) ),
+            ( (IndUnitSimplex(),), randn(Float64, 10) ),
             ( (IndNonnegative(), rand()), randn(Float64, 10) ),
             ( (IndZero(),), randn(Float64, 10) ),
             ( (IndBox(-1, 1),), randn(Float32, 10) ),
diff --git a/test/test_equivalences.jl b/test/test_equivalences.jl
index 2aa79c6..f2ca656 100644
--- a/test/test_equivalences.jl
+++ b/test/test_equivalences.jl
@@ -22,12 +22,22 @@ for i = 1:N
   r = 5*rand()
   f = IndSimplex(r)
   g = IndBallL1(r)
+  h = IndUnitSimplex(r)
 
   y1, fy1 = prox(f, abs.(x))
-  y1 = sign.(x).*y1
+  y1_l1ball = sign.(x).*y1
   y2, gy2 = prox(g, x)
+  y3, hy3 = prox(h, abs.(x))
 
-  @test y1 ≈ y2
+  @test y1_l1ball ≈ y2
+  @test y1 ≈ y3
+
+  x2 = abs.(x) * 0.5 * r
+  x2 ./= sum(x2)
+
+  y_probsimplex = prox(f, x2)
+  y_unit_simplex = prox(h, x2)
+  @test norm(y_probsimplex) >= norm(y_unit_simplex)
 end
 
 # projecting onto the simplex