[FR] Classifier function #142

Open
KronosTheLate opened this issue Mar 16, 2022 · 0 comments

Comments

@KronosTheLate

I am pretty new to kNN, and so far I only know how to use it for classification. It would therefore be natural for this package to include a simple classifier function. It should take the indices of the nearest neighbours returned by `knn` and the training classes, and assign each set of nearest neighbours to a class. An optional threshold for the minimum number of neighbours of the same class should be available, defaulting to 1.
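To illustrate the requested pipeline (neighbour indices in, one class per query out, with a minimum-votes threshold), here is a minimal Base-Julia sketch. It is not an existing API: `brute_knn` is a hypothetical brute-force Euclidean stand-in for a package's `knn`, and `vote` is a hypothetical classifier that counts votes in a `Dict`. In this sketch `tiebreaker` must be a function.

```julia
using LinearAlgebra: norm

# Hypothetical brute-force stand-in for a package `knn` call: indices of the k
# training points (columns of `train`) closest to `x` in Euclidean distance.
function brute_knn(train::AbstractMatrix, x::AbstractVector, k::Int)
    dists = [norm(view(train, :, j) .- x) for j in axes(train, 2)]
    return sortperm(dists)[1:k]
end

# Hypothetical majority-vote classifier with a minimum-votes threshold `l`
# (default 1). Returns `missing` if no class gets at least `l` votes; ties are
# passed to `tiebreaker`, which receives the tied classes sorted ascending.
function vote(neighbor_inds::AbstractVector{<:Integer}, train_classes::AbstractVector;
              l::Int=1, tiebreaker=rand)
    counts = Dict{eltype(train_classes),Int}()
    for c in train_classes[neighbor_inds]
        counts[c] = get(counts, c, 0) + 1   # one vote per neighbour
    end
    isempty(counts) && return missing       # no neighbours at all
    best = maximum(values(counts))
    best < l && return missing              # threshold not reached
    winners = sort([c for (c, n) in counts if n == best])
    return length(winners) == 1 ? only(winners) : tiebreaker(winners)
end

train = [0.0 0.1 0.2 5.0 5.1 5.2;
         0.0 0.1 0.0 5.0 5.1 5.0]   # 2 features × 6 training points
train_classes = [1, 1, 1, 2, 2, 2]
query = [0.05, 0.05]
inds = brute_knn(train, query, 3)    # the three points near the origin
vote(inds, train_classes)            # class 1
vote(inds, train_classes; l=4)       # missing: only 3 neighbours voted
```

Counting only the classes that actually occur among the neighbours avoids scanning every training class for each query; for `l ≥ 1` it yields the same winner(s) as counting over `unique(train_classes)`, since classes with zero votes can never reach the top count.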

Possible sources of inspiration

  • kNN.jl's classifier.jl file.
  • The following function, which I ended up having to define myself:
"""
    classify(neighbor_inds::AbstractVector{<:Integer}, train_classes::AbstractVector{Int}; tiebreaker=rand, l::Int=1)
    classify(neighbor_inds::AbstractVector{<:AbstractVector{<:Integer}}, train_classes::AbstractVector{Int}; kwargs...)

Classify a set of nearest neighbours by majority vote over `train_classes[neighbor_inds]`.
The second method applies the first to each set of neighbour indices, e.g. the output of `knn`.

kwargs:
`tiebreaker` is
1) a function that takes a vector of the tied candidate classes and returns one of them (the default, `rand`, picks one at random), or
2) a value that is returned upon a tie.
`l` is the minimum number of neighbours that must belong to the winning class; if no class reaches it, `missing` is returned.
"""
function classify(neighbor_inds::AbstractVector{<:Integer}, train_classes::AbstractVector{Int}; tiebreaker=rand, l::Int=1)
    possible_classes = unique(train_classes)
    neighbor_classes = train_classes[neighbor_inds]
    # Number of neighbours belonging to each possible class
    my_counts = [count(==(psbl_cls), neighbor_classes) for psbl_cls in possible_classes]
    # Rows are (class, count), sorted by count in descending order
    sorted_counts = sortslices([possible_classes my_counts], dims=1, by=x->x[2], rev=true)
    top_count = sorted_counts[1, 2]
    if top_count < l
        return missing  # no class reaches the threshold
    end
    # All classes sharing the highest count (also safe when there is only one class)
    candidates_of_equal_count = sorted_counts[sorted_counts[:, 2] .== top_count, 1]
    if length(candidates_of_equal_count) == 1
        return candidates_of_equal_count[1]
    elseif tiebreaker isa Function
        return tiebreaker(candidates_of_equal_count)
    else
        return tiebreaker  # a fixed value returned on ties
    end
end
function classify(neighbor_inds::AbstractVector{<:AbstractVector{<:Integer}}, train_classes::AbstractVector{Int}; kwargs...)
    [classify(inds, train_classes; kwargs...) for inds in neighbor_inds]
end
nearest_neighbour_inds = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
train_classes = [1, 1, 2, 4, 5, 6, 8, 8, 9]
classify(nearest_neighbour_inds, train_classes)  # returns [1, x, 8], where x is drawn at random from the tied classes 4, 5 and 6