diff --git a/NEWS.md b/NEWS.md
index d83e76d62a..f8ce90d3a6 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -2,6 +2,7 @@
 
 ## v0.13.7
 * Added [`@autosize` macro](https://github.com/FluxML/Flux.jl/pull/2078)
+* All loss functions now [accept 3 arguments](https://github.com/FluxML/Flux.jl/pull/2090), `loss(model, x, y) == loss(model(x), y)`.
 
 ## v0.13.4
 * Added [`PairwiseFusion` layer](https://github.com/FluxML/Flux.jl/pull/1983)
diff --git a/docs/src/models/losses.md b/docs/src/models/losses.md
index ae0efac186..b98e313c87 100644
--- a/docs/src/models/losses.md
+++ b/docs/src/models/losses.md
@@ -4,21 +4,28 @@ Flux provides a large number of common loss functions used for training machine
 They are grouped together in the `Flux.Losses` module.
 
 Loss functions for supervised learning typically expect as inputs a target `y`, and a prediction `ŷ` from your model.
-In Flux's convention, the order of the arguments is the following
+In Flux's convention, the target is the last argument, so a new loss function could be defined:
 
 ```julia
-loss(ŷ, y)
+newloss(ŷ, y) = sum(abs2, ŷ .- y)  # total squared error
 ```
 
-Most loss functions in Flux have an optional argument `agg`, denoting the type of aggregation performed over the
-batch:
+All loss functions in Flux have a method which takes the model as the first argument, and calculates the prediction `ŷ = model(x)`.
+This is convenient for [`train!`](@ref Flux.train!)`(loss, model, [(x,y), (x2,y2), ...], opt)`.
+For our example it could be defined as:
 
 ```julia
-loss(ŷ, y)                         # defaults to `mean`
-loss(ŷ, y, agg=sum)                # use `sum` for reduction
-loss(ŷ, y, agg=x->sum(x, dims=2))  # partial reduction
-loss(ŷ, y, agg=x->mean(w .* x))    # weighted mean
-loss(ŷ, y, agg=identity)           # no aggregation.
+newloss(model, x, y) = newloss(model(x), y)
+```
+
+Most loss functions in Flux have an optional keyword argument `agg`, which is the aggregation function used over the batch.
+Thus you may call, for example:
+
+```julia
+crossentropy(ŷ, y)                           # defaults to `Statistics.mean`
+crossentropy(ŷ, y; agg = sum)                # use `sum` instead
+crossentropy(ŷ, y; agg = x->mean(w .* x))    # weighted mean
+crossentropy(ŷ, y; agg = x->sum(x, dims=2))  # partial reduction, returns an array
 ```
 
 ### Function listing
diff --git a/src/losses/Losses.jl b/src/losses/Losses.jl
index 3d8f6f8149..8b82ed7012 100644
--- a/src/losses/Losses.jl
+++ b/src/losses/Losses.jl
@@ -9,8 +9,8 @@ using CUDA
 using NNlib: logsoftmax, logσ, ctc_loss, ctc_alpha, ∇ctc_loss
 import Base.Broadcast: broadcasted
 
-export mse, mae, msle,
-       label_smoothing,
+export label_smoothing,
+       mse, mae, msle,
        crossentropy, logitcrossentropy,
        binarycrossentropy, logitbinarycrossentropy,
        kldivergence,
@@ -19,9 +19,33 @@ export mse, mae, msle,
        dice_coeff_loss,
        poisson_loss,
        hinge_loss, squared_hinge_loss,
-       binary_focal_loss, focal_loss, siamese_contrastive_loss
+       binary_focal_loss, focal_loss,
+       siamese_contrastive_loss
 
 include("utils.jl")
 include("functions.jl")
 
+for loss in Symbol.([
+        mse, mae, msle,
+        crossentropy, logitcrossentropy,
+        binarycrossentropy, logitbinarycrossentropy,
+        kldivergence,
+        huber_loss,
+        tversky_loss,
+        dice_coeff_loss,
+        poisson_loss,
+        hinge_loss, squared_hinge_loss,
+        binary_focal_loss, focal_loss,
+        siamese_contrastive_loss,
+    ])
+    @eval begin
+        """
+            $($loss)(model, x, y)
+
+        This method calculates `ŷ = model(x)`. Accepts the same keyword arguments.
+        """
+        $loss(f, x::AbstractArray, y::AbstractArray; kw...) = $loss(f(x), y; kw...)
+    end
+end
+
 end #module
diff --git a/test/losses.jl b/test/losses.jl
index 2ca697a657..09a00ae3b5 100644
--- a/test/losses.jl
+++ b/test/losses.jl
@@ -248,3 +248,12 @@ end
   @test_throws DomainError(-0.5, "Margin must be non-negative") Flux.siamese_contrastive_loss(ŷ1, y1, margin = -0.5)
   @test_throws DomainError(-1, "Margin must be non-negative") Flux.siamese_contrastive_loss(ŷ, y, margin = -1)
 end
+
+@testset "3-arg methods" begin
+  @testset for loss in ALL_LOSSES
+    fun(x) = x[1:2]
+    x = rand(3)
+    y = rand(2)
+    @test loss(fun, x, y) == loss(fun(x), y)
+  end
+end
\ No newline at end of file
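For reference, a minimal usage sketch of the new 3-argument methods (not part of the patch above; the `Dense(3 => 2)` model and random data are made up for illustration, and assume the patch is applied):

```julia
using Flux
using Flux.Losses: mse

model = Dense(3 => 2)                      # hypothetical toy model
x, y = rand(Float32, 3), rand(Float32, 2)  # made-up data

# The 3-argument method computes ŷ = model(x) internally and
# forwards keyword arguments such as `agg` unchanged:
mse(model, x, y) == mse(model(x), y)                        # true
mse(model, x, y; agg = sum) == mse(model(x), y; agg = sum)  # true

# This also lets the loss be differentiated directly with respect to the model:
grads = Flux.gradient(m -> mse(m, x, y), model)
```

The same pattern should hold for every loss in the exported list, since the `@eval` loop defines an identical forwarding method for each of them.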