Skip to content

Commit c0bf336

Browse files
tianyizheng02 authored and github-actions committed
Consolidate the two existing kNN implementations (TheAlgorithms#8903)
* Add type hints to k_nearest_neighbours.py
* Refactor k_nearest_neighbours.py into class
* Add documentation to k_nearest_neighbours.py
* Use heap-based priority queue for k_nearest_neighbours.py
* Delete knn_sklearn.py
* updating DIRECTORY.md
* Use optional args in k_nearest_neighbours.py for demo purposes
* Fix wrong function arg in k_nearest_neighbours.py
---------
Co-authored-by: github-actions <${GITHUB_ACTOR}@users.noreply.github.com>
1 parent 8c052c6 commit c0bf336

File tree

3 files changed

+79
-81
lines changed

3 files changed

+79
-81
lines changed

DIRECTORY.md

-1
Original file line numberDiff line numberDiff line change
@@ -507,7 +507,6 @@
507507
* [Gradient Descent](machine_learning/gradient_descent.py)
508508
* [K Means Clust](machine_learning/k_means_clust.py)
509509
* [K Nearest Neighbours](machine_learning/k_nearest_neighbours.py)
510-
* [Knn Sklearn](machine_learning/knn_sklearn.py)
511510
* [Linear Discriminant Analysis](machine_learning/linear_discriminant_analysis.py)
512511
* [Linear Regression](machine_learning/linear_regression.py)
513512
* Local Weighted Learning
+79-49
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,88 @@
1+
"""
2+
k-Nearest Neighbours (kNN) is a simple non-parametric supervised learning
3+
algorithm used for classification. Given some labelled training data, a given
4+
point is classified using its k nearest neighbours according to some distance
5+
metric. The most commonly occurring label among the neighbours becomes the label
6+
of the given point. In effect, the label of the given point is decided by a
7+
majority vote.
8+
9+
This implementation uses the commonly used Euclidean distance metric, but other
10+
distance metrics can also be used.
11+
12+
Reference: https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm
13+
"""
14+
115
from collections import Counter
16+
from heapq import nsmallest
217

318
import numpy as np
419
from sklearn import datasets
520
from sklearn.model_selection import train_test_split
621

7-
data = datasets.load_iris()
8-
9-
X = np.array(data["data"])
10-
y = np.array(data["target"])
11-
classes = data["target_names"]
12-
13-
X_train, X_test, y_train, y_test = train_test_split(X, y)
14-
15-
16-
def euclidean_distance(a, b):
17-
"""
18-
Gives the euclidean distance between two points
19-
>>> euclidean_distance([0, 0], [3, 4])
20-
5.0
21-
>>> euclidean_distance([1, 2, 3], [1, 8, 11])
22-
10.0
23-
"""
24-
return np.linalg.norm(np.array(a) - np.array(b))
25-
26-
27-
def classifier(train_data, train_target, classes, point, k=5):
28-
"""
29-
Classifies the point using the KNN algorithm
30-
k closest points are found (ranked in ascending order of euclidean distance)
31-
Params:
32-
:train_data: Set of points that are classified into two or more classes
33-
:train_target: List of classes in the order of train_data points
34-
:classes: Labels of the classes
35-
:point: The data point that needs to be classified
36-
37-
>>> X_train = [[0, 0], [1, 0], [0, 1], [0.5, 0.5], [3, 3], [2, 3], [3, 2]]
38-
>>> y_train = [0, 0, 0, 0, 1, 1, 1]
39-
>>> classes = ['A','B']; point = [1.2,1.2]
40-
>>> classifier(X_train, y_train, classes,point)
41-
'A'
42-
"""
43-
data = zip(train_data, train_target)
44-
# List of distances of all points from the point to be classified
45-
distances = []
46-
for data_point in data:
47-
distance = euclidean_distance(data_point[0], point)
48-
distances.append((distance, data_point[1]))
49-
# Choosing 'k' points with the least distances.
50-
votes = [i[1] for i in sorted(distances)[:k]]
51-
# Most commonly occurring class among them
52-
# is the class into which the point is classified
53-
result = Counter(votes).most_common(1)[0][0]
54-
return classes[result]
22+
23+
class KNN:
    """
    k-Nearest Neighbours classifier using the Euclidean distance metric.

    A point is labelled by majority vote among the k training points that are
    closest to it.
    """

    def __init__(
        self,
        train_data: np.ndarray,
        train_target: np.ndarray,
        class_labels: list[str],
    ) -> None:
        """
        Create a kNN classifier using the given training data and class labels

        :param train_data: training points; iterating it must yield one point
            at a time
        :param train_target: for each training point, the index of its class
            in ``class_labels``
        :param class_labels: class names, indexed by the values of
            ``train_target``
        """
        # Materialise the pairs: a bare zip() is a one-shot iterator, so the
        # first classify() call would exhaust it and every later call would
        # see no training data at all.
        self.data = list(zip(train_data, train_target))
        self.labels = class_labels

    @staticmethod
    def _euclidean_distance(a: np.ndarray, b: np.ndarray) -> float:
        """
        Calculate the Euclidean distance between two points

        >>> KNN._euclidean_distance(np.array([0, 0]), np.array([3, 4]))
        5.0
        >>> KNN._euclidean_distance(np.array([1, 2, 3]), np.array([1, 8, 11]))
        10.0
        """
        # Convert the NumPy scalar to a built-in float so the return value
        # matches the annotation and its repr is stable across NumPy versions.
        return float(np.linalg.norm(a - b))

    def classify(self, pred_point: np.ndarray, k: int = 5) -> str:
        """
        Classify a given point using the kNN algorithm

        :param pred_point: the point to classify
        :param k: number of nearest neighbours that take part in the vote
        :return: the entry of the class labels chosen by the majority vote
        :raises ValueError: if k is smaller than 1 or there is no training data

        >>> train_X = np.array(
        ...     [[0, 0], [1, 0], [0, 1], [0.5, 0.5], [3, 3], [2, 3], [3, 2]]
        ... )
        >>> train_y = np.array([0, 0, 0, 0, 1, 1, 1])
        >>> classes = ['A', 'B']
        >>> knn = KNN(train_X, train_y, classes)
        >>> point = np.array([1.2, 1.2])
        >>> knn.classify(point)
        'A'

        The classifier can be reused for any number of points:

        >>> knn.classify(point)
        'A'
        >>> knn.classify(np.array([2.9, 2.9]), k=3)
        'B'
        >>> knn.classify(point, k=0)
        Traceback (most recent call last):
            ...
        ValueError: k must be a positive integer
        """
        if k < 1:
            raise ValueError("k must be a positive integer")
        if not self.data:
            raise ValueError("cannot classify without training data")

        # Distances of all points from the point to be classified
        distances = (
            (self._euclidean_distance(data_point[0], pred_point), data_point[1])
            for data_point in self.data
        )

        # Choosing k points with the shortest distances
        votes = (i[1] for i in nsmallest(k, distances))

        # Most commonly occurring class is the one into which the point is classified
        result = Counter(votes).most_common(1)[0][0]
        return self.labels[result]
5572

5673

5774
if __name__ == "__main__":
    # Run the embedded doctests first, then a small demo on the iris dataset.
    from doctest import testmod

    testmod()

    iris = datasets.load_iris()
    features = np.array(iris["data"])
    targets = np.array(iris["target"])
    iris_classes = iris["target_names"]

    # Fixed random_state keeps the split, and therefore the demo, reproducible.
    # Only the training part of the split is needed here.
    X_train, _, y_train, _ = train_test_split(features, targets, random_state=0)

    classifier = KNN(X_train, y_train, iris_classes)
    iris_point = np.array([4.4, 3.1, 1.3, 1.4])
    print(classifier.classify(iris_point, k=3))

machine_learning/knn_sklearn.py

-31
This file was deleted.

0 commit comments

Comments
 (0)