-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathKNearestNeighborsClassifier.swift
74 lines (57 loc) · 1.95 KB
/
KNearestNeighborsClassifier.swift
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
//
// KNearestNeighborsClassifier.swift
//
//
// Created by Tomas Adam on 26/5/17.
// Copyright © 2019 Kosice. All rights reserved.
//
import Darwin
import Foundation
/// A k-nearest-neighbours classifier over `Double` feature vectors with `Int` class labels.
///
/// The training samples and their labels are stored at initialisation. Each prediction ranks
/// every training sample by squared Euclidean distance to the query and returns the most
/// frequent label among the `nNeighbors` closest samples.
public class KNearestNeighborsClassifier {
    /// Training feature vectors; `data[i]` is labelled by `labels[i]`.
    private let data: [[Double]]
    /// Class label of each training sample, parallel to `data`.
    private let labels: [Int]
    /// Number of closest training samples that vote on each prediction.
    private let nNeighbors: Int

    /// Creates a classifier from a labelled training set.
    ///
    /// Invalid arguments are treated as programmer errors and stop execution via `fatalError`.
    ///
    /// - Parameters:
    ///   - data: Training feature vectors.
    ///   - labels: One label per training vector; must have the same count as `data`.
    ///   - nNeighbors: Number of neighbours that vote; must satisfy `1 <= nNeighbors <= data.count`.
    public init(data: [[Double]], labels: [Int], nNeighbors: Int = 3) {
        // Without this check a non-positive value would only fail later, inside `predict`.
        guard nNeighbors >= 1 else {
            fatalError("Expected `nNeighbors` (\(nNeighbors)) >= 1")
        }
        guard nNeighbors <= data.count else {
            fatalError("Expected `nNeighbors` (\(nNeighbors)) <= `data.count` (\(data.count))")
        }
        guard data.count == labels.count else {
            fatalError("Expected `data.count` (\(data.count)) == `labels.count` (\(labels.count))")
        }
        self.data = data
        self.labels = labels
        self.nNeighbors = nNeighbors
    }

    /// Predicts a label for every query vector.
    ///
    /// - Parameter xTests: Query feature vectors. Each must have at least as many components
    ///   as the training vectors; components beyond a training vector's length are ignored.
    /// - Returns: The predicted labels, in the same order as `xTests`.
    public func predict(_ xTests: [[Double]]) -> [Int] {
        return xTests.map { xTest in
            kNearestNeighborsMajority(kNearestNeighbors(xTest))
        }
    }

    /// Returns the squared Euclidean distance between a training vector and a query vector.
    ///
    /// The square root is omitted because only the relative ordering of distances is used.
    private func distance(_ xTrain: [Double], _ xTest: [Double]) -> Double {
        guard xTest.count >= xTrain.count else {
            fatalError("Expected `xTest.count` (\(xTest.count)) >= `xTrain.count` (\(xTrain.count))")
        }
        // `zip` stops at the end of `xTrain`, so surplus query components are ignored.
        return zip(xTrain, xTest).reduce(0.0) { sum, pair in
            let delta = pair.0 - pair.1
            return sum + delta * delta
        }
    }

    /// Returns the `nNeighbors` training samples closest to `xTest`, nearest first.
    ///
    /// Neighbours are kept in an array rather than a dictionary keyed by distance, so
    /// equidistant training samples each keep their own vote. Equal distances are ordered
    /// by training index to make the result deterministic.
    private func kNearestNeighbors(_ xTest: [Double]) -> [(distance: Double, label: Int)] {
        let candidates: [(distance: Double, index: Int)] = data.enumerated().map { index, xTrain in
            (distance: distance(xTrain, xTest), index: index)
        }
        let ranked = candidates.sorted { lhs, rhs in
            if lhs.distance < rhs.distance { return true }
            if rhs.distance < lhs.distance { return false }
            return lhs.index < rhs.index
        }
        return ranked.prefix(nNeighbors).map { (distance: $0.distance, label: labels[$0.index]) }
    }

    /// Returns the most frequent label among `knn`.
    ///
    /// `knn` is expected nearest-first; when several labels share the highest vote count,
    /// the label of the closest such neighbour wins, so ties are resolved deterministically.
    private func kNearestNeighborsMajority(_ knn: [(distance: Double, label: Int)]) -> Int {
        var votes = [Int: Int]()
        for neighbor in knn {
            votes[neighbor.label, default: 0] += 1
        }
        // Both lookups fail only for an empty `knn`, which `init` rules out.
        guard let topCount = votes.values.max(),
              let winner = knn.first(where: { votes[$0.label] == topCount }) else {
            fatalError("Cannot find the majority.")
        }
        return winner.label
    }
}