# k_means.py
import numpy as np
from scipy.stats import multivariate_normal
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.datasets import load_iris
# Number of clusters requested from the K-means routine.
K = 3
# Upper bound on K-means iterations used by the driver script.
MAX_ITERS = 10_000
# Self-implemented K-means algorithm
def _nearest_centroid(data, centroids):
    """Return, for every row of ``data``, the index of its closest centroid."""
    distances = np.linalg.norm(data[:, np.newaxis, :] - centroids, axis=2)
    return np.argmin(distances, axis=1)


def kmeans(data, k, max_iters=100):
    """Cluster the rows of ``data`` into ``k`` groups with Lloyd's algorithm.

    Initial centroids are ``k`` distinct rows drawn with ``np.random.choice``,
    so results depend on NumPy's global random state.

    Args:
        data: 2-D array-like of shape (n_samples, n_features).
        k: Number of clusters; must satisfy ``1 <= k <= n_samples``.
        max_iters: Maximum number of centroid updates to perform.

    Returns:
        A ``(centroids, clusters)`` tuple: ``centroids`` is a float array of
        shape (k, n_features) and ``clusters`` is an integer array of shape
        (n_samples,) holding each sample's cluster index. The labels always
        correspond to the returned centroids, even when ``max_iters`` is
        reached before convergence (or is 0).

    Raises:
        ValueError: If ``data`` is not 2-D or ``k`` is outside
            ``1..n_samples``.
    """
    data = np.asarray(data)
    if data.ndim != 2:
        raise ValueError("data must be a 2-D array of shape (n_samples, n_features)")
    n = data.shape[0]
    if not 1 <= k <= n:
        raise ValueError(f"k must be between 1 and the number of samples ({n}), got {k}")

    # Randomly initialize cluster centers; work in float so means are not
    # truncated when the input has an integer dtype.
    centroids = data[np.random.choice(n, k, replace=False)].astype(float)
    clusters = _nearest_centroid(data, centroids)

    for _ in range(max_iters):
        # Update each center to the mean of its members. A cluster that lost
        # all its members keeps its previous center: averaging an empty slice
        # would yield NaN, which poisons the centroid and makes the
        # convergence test below impossible to satisfy.
        new_centroids = centroids.copy()
        for i in range(k):
            members = data[clusters == i]
            if len(members):
                new_centroids[i] = members.mean(axis=0)

        # If cluster centers no longer change, end the iteration
        if np.array_equal(new_centroids, centroids):
            break

        centroids = new_centroids
        # Re-assign against the updated centers so the returned labels and
        # centroids stay consistent with each other.
        clusters = _nearest_centroid(data, centroids)

    return centroids, clusters
# Thin driver around the K-means implementation
def run_algorithm(data, k, max_iters):
    """Run ``kmeans`` on ``data`` and return its ``(centroids, labels)`` pair.

    Args:
        data: Samples to cluster, forwarded unchanged to ``kmeans``.
        k: Number of clusters.
        max_iters: Iteration cap forwarded to ``kmeans``.

    Returns:
        The centroids and per-sample cluster indices produced by ``kmeans``.
    """
    centers, labels = kmeans(data, k, max_iters)
    return centers, labels
# Load the Iris dataset and project it onto two principal components so the
# clustering result can be drawn on a 2-D scatter plot.
iris = load_iris()
pca = PCA(n_components=2)
data = pca.fit_transform(iris.data)

# Cluster the projected samples.
centroids, kmeans_clusters = run_algorithm(data, k=K, max_iters=MAX_ITERS)

# Report the fitted centroids and the cluster index of every sample.
print("K-means Centroids:", centroids)
print("K-means Clusters:", kmeans_clusters)
# Plotting helper for the clustering result
def visualize_results(data, k, centroids, kmeans_clusters):
    """Show a scatter plot of the samples coloured by cluster.

    Args:
        data: 2-D array; columns 0 and 1 are used as x and y coordinates.
        k: Number of clusters (accepted for interface symmetry; not used).
        centroids: 2-D array of cluster centers; columns 0 and 1 are plotted.
        kmeans_clusters: Per-sample cluster indices used as point colours.
    """
    plt.figure(figsize=(6, 6))

    # Samples, coloured by the cluster they were assigned to.
    sample_x = data[:, 0]
    sample_y = data[:, 1]
    plt.scatter(sample_x, sample_y, c=kmeans_clusters, cmap='viridis')

    # Cluster centers drawn on top as red crosses.
    center_x = centroids[:, 0]
    center_y = centroids[:, 1]
    plt.scatter(center_x, center_y, c='red', marker='x', s=100, label='Centroids')

    plt.title('K-means Clustering')
    plt.legend()
    plt.show()
# Plot the clustered samples together with the fitted centroids
visualize_results(
    data=data,
    k=K,
    centroids=centroids,
    kmeans_clusters=kmeans_clusters,
)