Skip to content

Commit

Permalink
save all clustering models
Browse files Browse the repository at this point in the history
  • Loading branch information
csinva committed Nov 3, 2024
1 parent e3d7960 commit 3686937
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions imodels/clustering/stableclustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@ def __init__(self, k_values, n_repetitions=10, algorithm="k-means", metric="adju
def fit(self, X):
best_k = None
best_score = -1
best_model = None
self.models_ = []

for k in tqdm(self.k_values, desc="k"):
clusters = []
for i_rep in tqdm(range(self.n_repetitions), desc='Repetitions', leave=False):
if self.algorithm == "k-means":
model = KMeans(
n_clusters=k, random_state=self.random_state + i_rep, init='random')
n_clusters=k, random_state=self.random_state + i_rep) # , init='random')
labels = model.fit_predict(X)
# elif self.algorithm == "nmf":
# model = NMF(n_components=k, init='random',
Expand All @@ -53,6 +53,8 @@ def fit(self, X):
"Invalid metric: choose 'rand' or 'adjusted_rand'")
scores.append(score)

self.models_.append(deepcopy(model))

avg_score = np.mean(scores)
# Store the average score for this k
self.scores_[k] = float(avg_score)
Expand All @@ -63,7 +65,7 @@ def fit(self, X):

# Fit the final model on the whole data
self.best_k_ = best_k
self.best_model_ = best_model.fit(X)
self.best_model_ = best_model
return self

def predict(self, X):
Expand Down

0 comments on commit 3686937

Please sign in to comment.