Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
  • Loading branch information
josevalim authored and seanmor5 committed Dec 21, 2022
1 parent c77fb48 commit 1da7d56
Showing 1 changed file with 13 additions and 6 deletions.
19 changes: 13 additions & 6 deletions lib/axon/losses.ex
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,19 @@ defmodule Axon.Losses do
# both and perform this whole thing. If neither is set, we set this to
# nil and then avoid the weighted avg later on.
weights =
case {opts[:positive_weight], opts[:negative_weight]} do
{nil, nil} -> nil
{pos, nil} -> Nx.take(Nx.tensor([1.0, pos]), y_true)
{nil, neg} -> Nx.take(Nx.tensor([neg, 1.0]), y_true)
{pos, neg} -> Nx.take(Nx.tensor([neg, pos]), y_true)
end
transform({y_true, opts[:positive_weight], opts[:negative_weight]}, fn
{_, nil, nil} ->
nil

{y_true, pos, nil} ->
Nx.take(Nx.tensor([1.0, pos], backend: Nx.Defn.Expr), y_true)

{y_true, nil, neg} ->
Nx.take(Nx.tensor([neg, 1.0], backend: Nx.Defn.Expr), y_true)

{y_true, pos, neg} ->
Nx.take(Nx.tensor([neg, pos], backend: Nx.Defn.Expr), y_true)
end)

# Merge types before computing loss to prevent under/overflow. This
# can especially happen when targets are encoded as u8 tensors. We
Expand Down

0 comments on commit 1da7d56

Please sign in to comment.