From 90a433500d607880d1d506608db2f601af9ce505 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Mon, 26 Sep 2022 19:40:41 -0400
Subject: [PATCH 1/4] make loss(f,x,y) work

---
 docs/src/models/losses.md | 14 +++++++---
 src/losses/Losses.jl      | 55 ++++++++++++++++++++++++++++++++++++---
 test/losses.jl            |  9 +++++++
 3 files changed, 71 insertions(+), 7 deletions(-)

diff --git a/docs/src/models/losses.md b/docs/src/models/losses.md
index ae0efac186..8b0af5dd8e 100644
--- a/docs/src/models/losses.md
+++ b/docs/src/models/losses.md
@@ -4,21 +4,27 @@ Flux provides a large number of common loss functions used for training machine
 They are grouped together in the `Flux.Losses` module.
 
 Loss functions for supervised learning typically expect as inputs a target `y`, and a prediction `ŷ` from your model.
-In Flux's convention, the order of the arguments is the following
+In Flux's convention, the target is the last argument:
 
 ```julia
 loss(ŷ, y)
 ```
 
+All loss functions have a method which takes the model as the first argument, and calculates the prediction `ŷ = model(x)`.
+This is convenient for [`train!`](@ref Flux.train!)`(loss, model, [(x,y), (x2,y2), ...], opt)`:
+
+```julia
+loss(model, x, y) = loss(model(x), y)
+```
+
 Most loss functions in Flux have an optional argument `agg`, denoting the type of aggregation performed over the
 batch:
 
 ```julia
 loss(ŷ, y)                         # defaults to `mean`
-loss(ŷ, y, agg=sum)                # use `sum` for reduction
-loss(ŷ, y, agg=x->sum(x, dims=2))  # partial reduction
+loss(ŷ, y, agg=sum)                # use `sum` instead
 loss(ŷ, y, agg=x->mean(w .* x))    # weighted mean
-loss(ŷ, y, agg=identity)           # no aggregation.
+loss(ŷ, y, agg=x->sum(x, dims=2))  # partial reduction, returns an array
 ```
 
 ### Function listing
diff --git a/src/losses/Losses.jl b/src/losses/Losses.jl
index 3d8f6f8149..e5d29c7e24 100644
--- a/src/losses/Losses.jl
+++ b/src/losses/Losses.jl
@@ -1,3 +1,28 @@
+"""
+    Flux.Losses
+
+This sub-module contains many loss functions, all of which accept two arguments,
+with the model output as the first argument: `loss(model(x), y)`.
+It also contains a few related utilities, such as `label_smoothing`.
+The complete list of exports is:
+
+    label_smoothing,
+    mse, mae, msle,
+    crossentropy,
+    logitcrossentropy,
+    binarycrossentropy,
+    logitbinarycrossentropy,
+    kldivergence,
+    huber_loss,
+    tversky_loss,
+    dice_coeff_loss,
+    poisson_loss,
+    hinge_loss,
+    squared_hinge_loss,
+    binary_focal_loss,
+    focal_loss,
+    siamese_contrastive_loss
+"""
 module Losses
 
 using Statistics
@@ -9,8 +34,8 @@ using CUDA
 using NNlib: logsoftmax, logσ, ctc_loss, ctc_alpha, ∇ctc_loss
 import Base.Broadcast: broadcasted
 
-export mse, mae, msle,
-       label_smoothing,
+export label_smoothing,
+       mse, mae, msle,
        crossentropy, logitcrossentropy,
        binarycrossentropy, logitbinarycrossentropy,
        kldivergence,
@@ -19,9 +44,33 @@ export mse, mae, msle,
        dice_coeff_loss,
        poisson_loss,
        hinge_loss, squared_hinge_loss,
-       binary_focal_loss, focal_loss, siamese_contrastive_loss
+       binary_focal_loss, focal_loss,
+       siamese_contrastive_loss
 
 include("utils.jl")
 include("functions.jl")
 
+for loss in Symbol.([
+    mse, mae, msle,
+    crossentropy, logitcrossentropy,
+    binarycrossentropy, logitbinarycrossentropy,
+    kldivergence,
+    huber_loss,
+    tversky_loss,
+    dice_coeff_loss,
+    poisson_loss,
+    hinge_loss, squared_hinge_loss,
+    binary_focal_loss, focal_loss,
+    siamese_contrastive_loss,
+  ])
+  @eval begin
+    """
+        $($loss)(model, x, y)
+
+    This method calculates `ŷ = model(x)`, and accepts the same keyword arguments as the two-argument method.
+    """
+    $loss(f, x::AbstractArray, y::AbstractArray; kw...) = $loss(f(x), y; kw...)
+  end
+end
+
 end #module
diff --git a/test/losses.jl b/test/losses.jl
index 2ca697a657..09a00ae3b5 100644
--- a/test/losses.jl
+++ b/test/losses.jl
@@ -248,3 +248,12 @@ end
   @test_throws DomainError(-0.5, "Margin must be non-negative") Flux.siamese_contrastive_loss(ŷ1, y1, margin = -0.5)
   @test_throws DomainError(-1, "Margin must be non-negative") Flux.siamese_contrastive_loss(ŷ, y, margin = -1)
 end
+
+@testset "3-arg methods" begin
+  @testset for loss in ALL_LOSSES
+    fun(x) = x[1:2]
+    x = rand(3)
+    y = rand(2)
+    @test loss(fun, x, y) == loss(fun(x), y)
+  end
+end
\ No newline at end of file

From 891709920a50148e038658d71a9b8b0129959517 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Tue, 18 Oct 2022 13:03:33 -0400
Subject: [PATCH 2/4] rm module docstring

---
 src/losses/Losses.jl | 25 -------------------------
 1 file changed, 25 deletions(-)

diff --git a/src/losses/Losses.jl b/src/losses/Losses.jl
index e5d29c7e24..8b82ed7012 100644
--- a/src/losses/Losses.jl
+++ b/src/losses/Losses.jl
@@ -1,28 +1,3 @@
-"""
-    Flux.Losses
-
-This sub-module contains many loss functions, all of which accept two arguments,
-with the model output as the first argument: `loss(model(x), y)`.
-It also contains a few related utilities, such as `label_smoothing`.
-The complete list of exports is:
-
-    label_smoothing,
-    mse, mae, msle,
-    crossentropy,
-    logitcrossentropy,
-    binarycrossentropy,
-    logitbinarycrossentropy,
-    kldivergence,
-    huber_loss,
-    tversky_loss,
-    dice_coeff_loss,
-    poisson_loss,
-    hinge_loss,
-    squared_hinge_loss,
-    binary_focal_loss,
-    focal_loss,
-    siamese_contrastive_loss
-"""
 module Losses
 
 using Statistics

From 39055aebf47c7d74e53378ebee75f074fb9703be Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Tue, 18 Oct 2022 13:08:30 -0400
Subject: [PATCH 3/4] tweak

---
 docs/src/models/losses.md | 13 ++++++-------
 1 file changed, 6 insertions(+), 7 deletions(-)

diff --git a/docs/src/models/losses.md b/docs/src/models/losses.md
index 8b0af5dd8e..3fdc8465f4 100644
--- a/docs/src/models/losses.md
+++ b/docs/src/models/losses.md
@@ -10,21 +10,20 @@ In Flux's convention, the target is the last argument:
 loss(ŷ, y)
 ```
 
-All loss functions have a method which takes the model as the first argument, and calculates the prediction `ŷ = model(x)`.
+All loss functions in Flux have a method which takes the model as the first argument, and calculates the prediction `ŷ = model(x)`.
 This is convenient for [`train!`](@ref Flux.train!)`(loss, model, [(x,y), (x2,y2), ...], opt)`:
 
 ```julia
 loss(model, x, y) = loss(model(x), y)
 ```
 
-Most loss functions in Flux have an optional argument `agg`, denoting the type of aggregation performed over the
-batch:
+Most loss functions in Flux have an optional keyword argument `agg`, which is the aggregation function used over the batch:
 
 ```julia
-loss(ŷ, y)                         # defaults to `mean`
-loss(ŷ, y, agg=sum)                # use `sum` instead
-loss(ŷ, y, agg=x->mean(w .* x))    # weighted mean
-loss(ŷ, y, agg=x->sum(x, dims=2))  # partial reduction, returns an array
+loss(ŷ, y)                          # defaults to `Statistics.mean`
+loss(ŷ, y; agg = sum)               # use `sum` instead
+loss(ŷ, y; agg = x->mean(w .* x))   # weighted mean
+loss(ŷ, y; agg = x->sum(x, dims=2)) # partial reduction, returns an array
 ```
 
 ### Function listing

From 3980a1c8b916299fcb92dee76e3eb0242a503043 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Thu, 20 Oct 2022 10:01:56 -0400
Subject: [PATCH 4/4] news, doc example

---
 NEWS.md                   |  1 +
 docs/src/models/losses.md | 20 +++++++++++---------
 2 files changed, 12 insertions(+), 9 deletions(-)

diff --git a/NEWS.md b/NEWS.md
index d83e76d62a..f8ce90d3a6 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -2,6 +2,7 @@
 
 ## v0.13.7
 * Added [`@autosize` macro](https://github.com/FluxML/Flux.jl/pull/2078)
+* All loss functions now [accept 3 arguments](https://github.com/FluxML/Flux.jl/pull/2090), `loss(model, x, y) == loss(model(x), y)`.
 
 ## v0.13.4
 * Added [`PairwiseFusion` layer](https://github.com/FluxML/Flux.jl/pull/1983)
diff --git a/docs/src/models/losses.md b/docs/src/models/losses.md
index 3fdc8465f4..b98e313c87 100644
--- a/docs/src/models/losses.md
+++ b/docs/src/models/losses.md
@@ -4,26 +4,28 @@ Flux provides a large number of common loss functions used for training machine
 They are grouped together in the `Flux.Losses` module.
 
 Loss functions for supervised learning typically expect as inputs a target `y`, and a prediction `ŷ` from your model.
-In Flux's convention, the target is the last argument:
+In Flux's convention, the target is the last argument, so a new loss function could be defined:
 
 ```julia
-loss(ŷ, y)
+newloss(ŷ, y) = sum(abs2, ŷ .- y)  # total squared error
 ```
 
 All loss functions in Flux have a method which takes the model as the first argument, and calculates the prediction `ŷ = model(x)`.
-This is convenient for [`train!`](@ref Flux.train!)`(loss, model, [(x,y), (x2,y2), ...], opt)`:
+This is convenient for [`train!`](@ref Flux.train!)`(loss, model, [(x,y), (x2,y2), ...], opt)`.
+For our example it could be defined:
 
 ```julia
-loss(model, x, y) = loss(model(x), y)
+newloss(model, x, y) = newloss(model(x), y)
 ```
 
-Most loss functions in Flux have an optional keyword argument `agg`, which is the aggregation function used over the batch:
+Most loss functions in Flux have an optional keyword argument `agg`, which is the aggregation function used over the batch.
+Thus you may call, for example:
 
 ```julia
-loss(ŷ, y)                          # defaults to `Statistics.mean`
-loss(ŷ, y; agg = sum)               # use `sum` instead
-loss(ŷ, y; agg = x->mean(w .* x))   # weighted mean
-loss(ŷ, y; agg = x->sum(x, dims=2)) # partial reduction, returns an array
+crossentropy(ŷ, y)                          # defaults to `Statistics.mean`
+crossentropy(ŷ, y; agg = sum)               # use `sum` instead
+crossentropy(ŷ, y; agg = x->mean(w .* x))   # weighted mean
+crossentropy(ŷ, y; agg = x->sum(x, dims=2)) # partial reduction, returns an array
 ```
 
 ### Function listing
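For reference, a short usage sketch of what this series enables. It is not part of the patches; the model and data here are illustrative, but `Dense`, `mse`, and the `agg` keyword are existing Flux names:

```julia
using Flux
using Flux.Losses: mse

model = Dense(3 => 2)        # any callable object can stand in for `f`
x = rand(Float32, 3, 16)     # a batch of 16 inputs
y = rand(Float32, 2, 16)     # matching targets

mse(model(x), y)             # existing 2-argument form
mse(model, x, y)             # new 3-argument form, computes model(x) internally
mse(model, x, y; agg = sum)  # keyword arguments pass through unchanged
```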
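The `for loss in Symbol.([...])` loop in patch 1 generates all of these methods with `@eval`, so no per-function code is needed. A minimal standalone sketch of that pattern, using a hypothetical toy module rather than Flux itself:

```julia
module ToyLosses  # hypothetical module, only to illustrate the @eval pattern

mse(ŷ, y; agg = sum) = agg(abs2.(ŷ .- y))
mae(ŷ, y; agg = sum) = agg(abs.(ŷ .- y))

# For each loss, define `loss(f, x, y) = loss(f(x), y)`, forwarding keywords:
for loss in (:mse, :mae)
  @eval $loss(f, x::AbstractArray, y::AbstractArray; kw...) = $loss(f(x), y; kw...)
end

end # module

ToyLosses.mse(x -> 2 .* x, [1.0, 2.0], [2.0, 4.0])  # returns 0.0
```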