forked from johnmyleswhite/kNN.jl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
classifier.jl
39 lines (35 loc) · 1.13 KB
/
classifier.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
"""
    kNNClassifier{T}

Fitted k-nearest-neighbor classifier, as returned by `knn`.

`t` is the neighbor tree built over the training matrix, and `y` holds the
training labels, indexed by the neighbor indices that `nearest(t, x, k)` returns.

The label vector is parameterized by its element type `T` so that the `y` field
has a concrete type. An unparameterized `Vector` field is abstract and forces
dynamic dispatch on every access to `model.y`.
"""
# NOTE(review): `immutable` is pre-Julia-0.6 syntax (spelled `struct` from 0.6
# on); kept to match the rest of this file.
immutable kNNClassifier{T}
    # Neighbor index over the training points (built by `knn`).
    t::NaiveNeighborTree
    # Training labels; `y[inds]` picks the labels of the neighbors in `inds`.
    y::Vector{T}
end
"""
    knn(X::Matrix, y::Vector; metric::Metric = Euclidean())

Build a `kNNClassifier` from training data `X` and labels `y`.

Neighbor lookups use a `NaiveNeighborTree(X, metric)`. `y[i]` is the label of
training column `i`. This presumably follows the same one-point-per-column layout
that `predict!` uses for query matrices. The number of neighbors `k` is not fixed
here; it is passed to each `predict` call.

Throws `DimensionMismatch` if `length(y) != size(X, 2)`.
"""
function knn(X::Matrix,
             y::Vector;
             metric::Metric = Euclidean())
    # Labels are fetched as `y[inds]` using indices from the neighbor tree, so
    # there must be exactly one label per training column. Failing here gives a
    # clear error instead of a BoundsError (or silently ignored labels) later.
    if size(X, 2) != length(y)
        throw(DimensionMismatch("X has $(size(X, 2)) columns but y has " *
                                "$(length(y)) labels"))
    end
    return kNNClassifier(NaiveNeighborTree(X, metric), y)
end
"""
    predict(model::kNNClassifier, x::Vector, k::Integer = 1)

Predict a label for the single query point `x`.

Asks `nearest(model.t, x, k)` for the `k` neighbors of `x`. Returns
`majority_vote` over those neighbors' labels in `model.y`.
"""
function StatsBase.predict(model::kNNClassifier,
                           x::Vector,
                           k::Integer = 1)
    # Only the neighbor indices are needed; the distances are discarded.
    neighbor_idx, _ = nearest(model.t, x, k)
    # TODO: Don't construct copy of model.y just to extract majority vote
    neighbor_labels = model.y[neighbor_idx]
    return majority_vote(neighbor_labels)
end
"""
    predict!(predictions::Vector, model::kNNClassifier, X::Matrix, k::Integer = 1)

Classify each column of `X` in place and return `predictions`.

Column `i` is classified with `predict(model, X[:, i], k)`, and the result is
stored in `predictions[i]`. Assignment converts each label to
`eltype(predictions)`, so that element type must accept the labels in `model.y`.

Throws `DimensionMismatch` if `length(predictions) != size(X, 2)`.
"""
function StatsBase.predict!(predictions::Vector,
                            model::kNNClassifier,
                            X::Matrix,
                            k::Integer = 1)
    n = size(X, 2)
    # Validate before writing anything. Otherwise a short `predictions` fails
    # with a BoundsError only after earlier entries were overwritten, and a
    # long one silently keeps stale trailing entries.
    if length(predictions) != n
        throw(DimensionMismatch("predictions has length $(length(predictions)) " *
                                "but X has $n columns"))
    end
    for i in 1:n
        # `X[:, i]` copies the column. The single-point `predict` method in
        # this file is declared for `x::Vector`, so a view would not match it.
        predictions[i] = predict(model, X[:, i], k)
    end
    return predictions
end
"""
    predict(model::kNNClassifier, X::Matrix, k::Integer = 1)

Classify each column of `X` and return a new vector with one label per
column. The vector has the same element type as `model.y`.

This delegates to `predict!` on a freshly allocated output vector.
"""
function StatsBase.predict(model::kNNClassifier,
                           X::Matrix,
                           k::Integer = 1)
    # `similar` allocates an uninitialized `Vector{eltype(model.y)}` of the
    # requested length. Unlike the `Array(T, n)` constructor (deprecated in
    # Julia 0.6, removed in 1.0), it works across Julia versions.
    # Every slot is then filled by `predict!`.
    predictions = similar(model.y, size(X, 2))
    predict!(predictions, model, X, k)
    return predictions
end