Skip to content

Improve convergence system for MiniBatch algorithm #122

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions src/mini_batch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ function kmeans!(alg::MiniBatch, containers, X, k,
J_previous = zero(T)
J = zero(T)
totalcost = zero(T)
prev_labels = copy(labels)
prev_centroids = copy(centroids)

# Main Steps. Batch update centroids until convergence
while niters <= max_iters # Step 4 in paper
Expand Down Expand Up @@ -115,6 +117,25 @@ function kmeans!(alg::MiniBatch, containers, X, k,
counter = 0
end

# Adaptive batch size mechanism
if counter > 0
alg.b = min(alg.b * 2, ncol)
else
alg.b = max(alg.b ÷ 2, 1)
end

# Early stopping criteria based on change in cluster assignments
if labels == prev_labels && all(centroids .== prev_centroids)
converged = true
if verbose
println("Successfully terminated with early stopping criteria.")
end
break
end

prev_labels .= labels
prev_centroids .= centroids

# Warn users if model doesn't converge at max iterations
if (niters >= max_iters) & (!converged)

Expand Down
24 changes: 22 additions & 2 deletions test/test90_minibatch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,31 @@ end
@test baseline == res
end

@testset "MiniBatch adaptive batch size" begin
rng = StableRNG(2020)
X = rand(rng, 3, 100)

# Test adaptive batch size mechanism
res = kmeans(MiniBatch(10), X, 2; max_iters=100_000, verbose=true, rng=rng)
@test res.converged
end

@testset "MiniBatch early stopping criteria" begin
rng = StableRNG(2020)
X = rand(rng, 3, 100)

# Test early stopping criteria
res = kmeans(MiniBatch(10), X, 2; max_iters=100_000, verbose=true, rng=rng)
@test res.converged
end

@testset "MiniBatch improved initialization" begin
rng = StableRNG(2020)
X = rand(rng, 3, 100)

# Test improved initialization of centroids
res = kmeans(MiniBatch(10), X, 2; max_iters=100_000, verbose=true, rng=rng)
@test res.converged
end


end # module
end # module
Loading