Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Propose accuracy functions #2181

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from

Conversation

skyleaworlder
Copy link
Contributor

@skyleaworlder skyleaworlder commented Feb 5, 2023

PR Checklist

  • Tests are added
  • Entry in NEWS.md
  • Documentation, if applicable

About #2171 | accuracy function

I simply define 3 accuracy function for multi-class and multi-label problem.

julia> typeof(test_x)
Matrix{Float32} (alias for Array{Float32, 2})

julia> typeof(test_y)
Vector{Int64} (alias for Array{Int64, 1})

julia> multiclass_accuracy(model, test_x, test_y);
julia> typeof(test_x)
Matrix{N0f8} (alias for Array{FixedPointNumbers.Normed{UInt8, 8}, 2})

julia> typeof(test_y)
OneHotMatrix{UInt32, Vector{UInt32}} (alias for OneHotArray{UInt32, 1, 2, Array{UInt32, 1}})

julia> multiclass_accuracy(model, test_x, test_y);

julia> multilabel_accuracy(model, test_x, test_y);

@codecov-commenter
Copy link

codecov-commenter commented Feb 5, 2023

Codecov Report

Attention: Patch coverage is 0% with 3 lines in your changes missing coverage. Please review.

Project coverage is 83.15%. Comparing base (c5a691a) to head (5220755).
Report is 270 commits behind head on master.

Files Patch % Lines
src/metrics.jl 0.00% 3 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #2181      +/-   ##
==========================================
- Coverage   86.02%   83.15%   -2.88%     
==========================================
  Files          19       20       +1     
  Lines        1460     1460              
==========================================
- Hits         1256     1214      -42     
- Misses        204      246      +42     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@CarloLucibello
Copy link
Member

I'm not sure we want to add accuracy metrics in Flux, but if we do they should take the signature metric(yhat, y) and not metric(model, x, y)

@ToucheSir
Copy link
Member

I think we should add them somewhere officially supported, since we've seen a lot of people end up rolling their own less-than-optimal versions. The question is where? Do we have the capacity to revive https://github.com/JuliaML/MLMetrics.jl?

@mcabbott
Copy link
Member

mcabbott commented Feb 12, 2023

We should think more about the right signatures.

Matching the categorical functions in Flux.Losses is one place to start, and suggests:

accuracy(yhat::AbstractMatrix, y::OneHotMatrix)         # like crossentropy
accuracy(yhat::AbstractMatrix, y::AbstractVector{Bool}) # like binarycrossentropy

If the first class are expanded to take labels instead (like #2141) then

accuracy(yhat::AbstractMatrix, y::AbstractVector, labels=1:size(yhat,1)) = accuracy(yhat, onehotbatch(y, labels))

If all loss functions accept the model as 1st argument (#2090) then perhaps you want accuracy(m, x, y) = accuracy(m(x), y). But that loss signature is mostly because it would be convenient to pass to train!, how useful is it here?

Matching the input of train! is another possibility. Having made a DataLoader with batches, this function could perhaps iterate over it for you:

accuracy(model, data) = mean(accuracy(model(x), y) for (x,y) in data)  # really allow non-equal batches

Those seem the obvious reference points if this lives within Flux. If it lives in OneHotArrays, then perhaps the binarycrossentropy-like signature is out of scope.

@skyleaworlder
Copy link
Contributor Author

I use multilabel_accuracy(model, test_x, test_y) because I think that the point of the semantics of this function signature is more like a model than test dataset. If it were data, there would be some libraries that could have already done well for it.

If https://github.com/JuliaML/MLMetrics.jl is enough, applying it in docs / tutorials is also a good choice, which can also help new users reduce work in testing model.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants