Skip to content

Commit

Permalink
minor stableclustering cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
csinva committed Nov 1, 2024
1 parent 0c635fa commit 7007dc0
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions imodels/clustering/stableclustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def __init__(self, k_values, n_repetitions=10, algorithm="k-means", metric="adju
self.random_state = random_state
self.scores_ = {}

def fit(self, X, y=None):
def fit(self, X):
best_k = None
best_score = -1
best_model = None
Expand All @@ -28,7 +28,7 @@ def fit(self, X, y=None):
for i_rep in tqdm(range(self.n_repetitions), desc='Repetitions', leave=False):
if self.algorithm == "k-means":
model = KMeans(
n_clusters=k, random_state=self.random_state + i_rep)
n_clusters=k, random_state=self.random_state + i_rep, init='random')
labels = model.fit_predict(X)
# elif self.algorithm == "nmf":
# model = NMF(n_components=k, init='random',
Expand Down Expand Up @@ -66,7 +66,7 @@ def fit(self, X, y=None):
return self

def predict(self, X):
check_is_fitted(self, ["best_model"])
check_is_fitted(self, ["best_model_", "best_k_"])
if self.algorithm == "k-means":
return self.best_model_.predict(X)
# elif self.algorithm == "nmf":
Expand Down

0 comments on commit 7007dc0

Please sign in to comment.