diff --git a/imodels/clustering/stableclustering.py b/imodels/clustering/stableclustering.py index 395af800..f41fb14b 100644 --- a/imodels/clustering/stableclustering.py +++ b/imodels/clustering/stableclustering.py @@ -18,7 +18,7 @@ def __init__(self, k_values, n_repetitions=10, algorithm="k-means", metric="adju self.random_state = random_state self.scores_ = {} - def fit(self, X, y=None): + def fit(self, X): best_k = None best_score = -1 best_model = None @@ -28,7 +28,7 @@ def fit(self, X, y=None): for i_rep in tqdm(range(self.n_repetitions), desc='Repetitions', leave=False): if self.algorithm == "k-means": model = KMeans( - n_clusters=k, random_state=self.random_state + i_rep) + n_clusters=k, random_state=self.random_state + i_rep, init='random') labels = model.fit_predict(X) # elif self.algorithm == "nmf": # model = NMF(n_components=k, init='random', @@ -66,7 +66,7 @@ def fit(self, X, y=None): return self def predict(self, X): - check_is_fitted(self, ["best_model"]) + check_is_fitted(self, ["best_model_", "best_k_"]) if self.algorithm == "k-means": return self.best_model_.predict(X) # elif self.algorithm == "nmf":