Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CTC loss to new Losses module #1287

Merged
merged 36 commits into from
Jan 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
9e31e53
Add CTC loss and tests
maetshju Jul 19, 2020
37efaa0
Add ctc to Losses module
maetshju Jul 20, 2020
f471337
General updates
maetshju Oct 13, 2020
b19b88c
Reverting bad merge
maetshju Oct 13, 2020
da3564b
Revert "General updates"
maetshju Oct 13, 2020
d8242c0
General ctc updates
maetshju Oct 13, 2020
5bf2635
Get test cases working
maetshju Oct 15, 2020
1707255
Merge branch 'master' into ctc
maetshju Oct 17, 2020
e4123ab
Update NEWS.md
maetshju Oct 17, 2020
3045ed2
Re-pull from main Flux repo
maetshju Dec 7, 2020
46898ed
Change ctc to ctc_loss
maetshju Dec 14, 2020
110a608
Fix typo
maetshju Dec 14, 2020
5145222
Remove camel-casing
maetshju Dec 14, 2020
e002027
Remove some whitespace from functions
maetshju Dec 14, 2020
d0bd3bd
Adding info to comply with Apache license
maetshju Dec 14, 2020
5dafa05
Use logsumexp in CPU CTC
maetshju Dec 18, 2020
3950464
Change logsum to logsumexp
maetshju Dec 18, 2020
282cb23
Re-add logaddexp function to CPU ctc
maetshju Dec 19, 2020
50fc561
Merge branch 'master' into ctc
maetshju Dec 19, 2020
d9caac4
Regnerate Manifest.toml
maetshju Dec 19, 2020
c043855
Remove time indexing from ctc gradchecks
maetshju Dec 20, 2020
3eb9e51
Revert "Remove time indexing from ctc gradchecks"
maetshju Dec 20, 2020
bccef7a
Change typedZero to typed_zero
maetshju Dec 20, 2020
00d4125
Update gradcheck comment; remove some whitespace
maetshju Dec 22, 2020
9b37c8f
Reduce allocations for better performance
maetshju Dec 24, 2020
75bb3c1
Transpose alpha and beta to match GPU kernel
maetshju Dec 24, 2020
b00487a
Update add_blanks to use fill
maetshju Dec 24, 2020
d96a53f
Split CPU loss and gradient calculation
maetshju Dec 31, 2020
6ca07e2
Rejig CPU CTC API
maetshju Jan 3, 2021
bc7ab03
Split GPU CTC kernel and update API
maetshju Jan 3, 2021
9bb245d
Merge branch 'master' into ctc
maetshju Jan 3, 2021
807aefa
Move to onecold representation for ctc input
maetshju Jan 14, 2021
6f108c6
Merge branch 'master' into ctc
maetshju Jan 14, 2021
e1e8cc8
Apply suggestions from code review
maetshju Jan 16, 2021
6e5fb17
Remove F in ctc tests; update ctc-gpu test syntax
maetshju Jan 16, 2021
bc94a16
Fix indentation in ctc.jl
maetshju Jan 19, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
* The Dense layer now supports inputs with [multiple batch dimensions](https://github.com/FluxML/Flux.jl/pull/1405)
* Excise datasets in favour of other providers in the julia ecosystem.
* Added option to set `bias` to [false](https://github.com/FluxML/Flux.jl/pull/1379) to eliminating `bias` from being trained.
* Add [CTC loss function](https://github.com/FluxML/Flux.jl/pull/1287) to Losses module
* Removed kwarg only constructors for [`convolutional layers`](https://github.com/FluxML/Flux.jl/pull/1379).
* Add [sparse initialization](https://github.com/FluxML/Flux.jl/pull/1454) as described in [Deep learning via Hessian-free optimization](https://dl.acm.org/doi/abs/10.5555/3104322.3104416).
* Moved GPU CI to use buildkite instead of GitLab
Expand Down
7 changes: 5 additions & 2 deletions src/losses/Losses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@ export mse, mae, msle,
tversky_loss,
dice_coeff_loss,
poisson_loss,
hinge_loss, squared_hinge_loss
hinge_loss, squared_hinge_loss,
ctc_loss

include("utils.jl")
include("functions.jl")
include("ctc.jl")
if CUDA.functional() include("ctc-gpu.jl") end
CarloLucibello marked this conversation as resolved.
Show resolved Hide resolved

end #module
end #module
232 changes: 232 additions & 0 deletions src/losses/ctc-gpu.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
# GPU implementation

# a port of the GPU kernels from Baidu's C++ warp-ctc package,
# which itself is Copyright 2015-2016 Baidu USA LLC
# and available under the Apache 2.0 license
#
# Apache 2.0 license: https://www.apache.org/licenses/LICENSE-2.0
# GitHub: https://github.com/baidu-research/warp-ctc/
CarloLucibello marked this conversation as resolved.
Show resolved Hide resolved
# paper: https://arxiv.org/pdf/1512.02595.pdf

using Flux
using Statistics
using CUDA
using NNlib

const MAX_THREADS = 256

function log_plus_f(p1, p2)
isinf(p1) && return p2
isinf(p2) && return p1
if p1 < p2
p1, p2 = p2, p1
end
return p1 + CUDA.log(1+CUDA.exp(p2 - p1))
end

function count_repeats(A)
CarloLucibello marked this conversation as resolved.
Show resolved Hide resolved
repeats = 0
for (i,elem) in enumerate(A)
if i > 1 && A[i] == A[i-1]
repeats += 1
end
end
return repeats
end

function compute_alpha_kernel(probs, labelSize, uttLength, repeats, labelsWithoutBlanks, labelsWithBlanks, alpha, blankLabel)

tid = threadIdx().x
L = labelSize
T = uttLength
S = length(labelsWithBlanks)

if L + repeats > T
return nothing
end
labels = labelsWithBlanks

# Corner-case checking
start = (L + repeats <= T) ? 0 : 1
last = S > 1 ? 2 : 1

# Fill in first column (time step)
i = tid
while i <= last - start
alpha[start+i, 1] = probs[labels[start+i], 1]
i += blockDim().x
end
sync_threads()

# Fill in coefficients for each time step
for t=2:T
# Corner-case checking
if tid == 1 && !(1 < S - 2*(T-t) - 1)
if start == 0
alpha[1, t] = alpha[1, t-1] + probs[blankLabel, t]
elseif start == 1
alpha[1, t] = alpha[1, t-1]
end
end
sync_threads()

# Fill in coefficients for each label class in the target output sequence;
# each thread will process the calculations for one class
idx = tid+1
while idx <= S
prevSum = log_plus_f(alpha[idx, t-1], alpha[idx-1, t-1])
if labels[idx] != blankLabel && idx != 2 && labels[idx] != labels[idx-2]
prevSum = log_plus_f(prevSum, alpha[idx-2, t-1])
end
if idx < S - 2*(T-t) - 1
alpha[idx, t] = -Inf32
else
alpha[idx, t] = prevSum + probs[labels[idx], t]
end
idx += blockDim().x
end
sync_threads()
end
return nothing
end

function compute_beta_and_grad_kernel(probs, labelSize, uttLength,
repeatsInLabel, labelsWithBlanks,
alphas, beta, output, accum,
grad, blankLabel, loss)

tid = threadIdx().x
L = labelSize
T = uttLength
S = 2*L + 1
repeats = repeatsInLabel
labels = labelsWithBlanks

if (L+repeats) > T
return nothing
end

# Corner-case checking
start = S > 1 ? S-2 : 0
last = L + repeats < T ? S : S-1
sync_threads()
i = tid

# Calculate coefficients for last column (time step)
# then determine alpha and beta product
while i <= last - start
beta[i+start, T] = 0
output[i+start, T] = beta[i+start, T] + alphas[i+start, T]
i += blockDim().x
end
sync_threads()

# Fill in `accum` for last column (time step)
if tid == 1
for i=1:S
labelIdx = labels[i]
accum[labelIdx, T] = log_plus_f(accum[labelIdx, T], output[i, T])
end
end
sync_threads()

# Fill in `grad` for last column (time step)
idx = tid
while idx <= size(grad, 1)
s = -Inf32
for i=1:S
s = log_plus_f(s, output[i, T])
end

# ∂L/∂a (where a is activation before logsoftmax)
grad[idx, T] = CUDA.exp(probs[idx, T]) - CUDA.exp(accum[idx, T] - s)
idx += blockDim().x
end
sync_threads()

# Fill in the rest of the coefficients
t = T-1
while t >= 1
if t < T
idx = tid
while idx <= S
nextSum = probs[labels[idx], t+1] + beta[idx, t+1]
if idx < S
nextSum = log_plus_f(nextSum,
probs[labels[idx+1], t+1] + beta[idx+1, t+1])
end
if labels[idx] != blankLabel && idx != S-1 && labels[idx] != labels[idx+2]
nextSum = log_plus_f(nextSum,
probs[labels[idx+2], t+1] + beta[idx + 2, t+1])
end
if idx > 2*t
beta[idx, t] = -Inf32
else
beta[idx, t] = nextSum
end
idx += blockDim().x
end
sync_threads()
idx = tid
while idx <= S
output[idx, t] = alphas[idx, t] + beta[idx, t]
idx += blockDim().x
end
sync_threads()
end
sync_threads()

# Calculate accumulated alpha-beta products for each label class for
# each time step; used in calculating gradients
if tid == 1
for i=1:S
labelIdx = labels[i]
accum[labelIdx, t] = log_plus_f(accum[labelIdx, t], output[i, t])
end
end
sync_threads()
idx = tid

# Calculate gradients
while idx <= size(grad, 1)

# ∂L/∂a (where a is activation before logsoftmax)
grad[idx, t] = CUDA.exp(probs[idx, t]) - CUDA.exp(accum[idx, t] + loss)
idx += blockDim().x
end
sync_threads()
t -= 1
sync_threads()
end
return nothing
end

function ctc_alpha(ŷ::CuArray, y)
ŷ = logsoftmax(ŷ)
blank = size(ŷ, 1)
z′ = fill(blank, 2 * length(y) + 1)
z′[eachindex(y) .* 2] = y
T = size(ŷ, 2)
U′ = 2*length(y) + 1
alphas = CUDA.fill(log(zero(ŷ[1])), U′,T)
nRepeats = count_repeats(y)
nThreads = min(U′, MAX_THREADS)
@cuda blocks=1 threads=nThreads compute_alpha_kernel(ŷ, length(y), T, nRepeats, CuArray(y), CuArray(z′), alphas, blank)
return (loss=-1 * logsumexp(alphas[end-1:end]), alpha=alphas, z′=z′, yhat=ŷ, nRepeats=nRepeats)
end

ctc_loss(ŷ::CuArray, y) = ctc_alpha(ŷ::CuArray, y).loss

function ∇ctc_loss(ŷ::CuArray, y, out)
loss, alphas, z′, ŷ, nRepeats = out
U′, T = size(alphas)
blank = size(ŷ, 1)
typed_zero = zero(first(ŷ))
betas = CUDA.fill(log(typed_zero), U′, T)
output = CUDA.fill(log(typed_zero), U′, T)
nThreads = min(U′, MAX_THREADS)
grads = CUDA.fill(log(typed_zero), size(ŷ))
accum = CUDA.fill(log(typed_zero), size(ŷ))
@cuda blocks=1 threads=nThreads compute_beta_and_grad_kernel(ŷ, length(y), T, nRepeats, CuArray(z′), alphas, betas, output, accum, grads, blank, loss)
return grads
end
138 changes: 138 additions & 0 deletions src/losses/ctc.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
using Flux
using Zygote: @adjoint
using Statistics
using NNlib

# CPU implementation
"""
logaddexp(a, b)
Adds log-space `a` and `b` such that the result equals `log(exp(a)+exp(b))`
"""
function logaddexp(a, b)
isinf(a) && return b
isinf(b) && return a

# always want the greater number on the left in the exponentiation;
# the magnitude difference may end up making the number very positive
# which will cause exp() to return Inf
# E.g., a = -900, b = -800, will give exp(-800 - -900), which will be
# Inf for Float32 values
if a < b
a, b = b, a
end
return a + log(1+exp(b-a))
end

"""
add_blanks(z)

Adds blanks to the start and end of `z`, and between items in `z`
"""
function add_blanks(z, blank)
CarloLucibello marked this conversation as resolved.
Show resolved Hide resolved
z′ = fill(blank, 2*length(z) + 1)
z′[2 .* eachindex(z)] = z
return z′
end

function ctc_alpha(ŷ::AbstractArray, y)
typed_zero = zero(ŷ[1])
ŷ = logsoftmax(ŷ)
blank = size(ŷ, 1)
z′ = add_blanks(y, blank)
T = size(ŷ, 2)
U′ = length(z′)

α = fill(log(typed_zero), U′, T)
α[1,1] = ŷ[blank, 1]
α[2,1] = ŷ[z′[2], 1]
for t=2:T
bound = max(1, U′ - 2(T - t) - 1)
for u=bound:U′
if u == 1
α[u,t] = α[u, t-1]
else
α[u,t] = logaddexp(α[u, t-1], α[u-1, t-1])

# array bounds check and f(u) function from Eq. 7.9
if u > 2 && !(z′[u] == blank || z′[u-2] == z′[u])
α[u,t] = logaddexp(α[u,t], α[u-2,t-1])
end
end
α[u,t] += ŷ[z′[u], t]
end
end
return (loss=-1 * logaddexp(α[end,T], α[end-1, T]), alpha=α, zprime=z′, logsoftyhat=ŷ)
end

function ∇ctc_loss(ŷ::AbstractArray, y, out)
loss, α, z′, ŷ = out
U′, T = size(α)
blank = size(ŷ, 1)
typed_zero = zero(first(α))

# Calculate beta coefficients, from the bottom-right, to the upper-left
β = fill(log(typed_zero), U′, T)

# Fill bottom-right corner so bounding errors can be avoided
# by starting `u` at `U′-1`
β[U′, T] = typed_zero
β[U′-1, T] = typed_zero

# start at T-1 so that β(T, u) = log(0) for all u < U′ - 1
for t=(T-1):-1:1
bound = min(U′, 2t)
for u=bound:-1:1
if u == U′
β[u,t] = ŷ[z′[u], t+1] + β[u, t+1]
else
β[u,t] = logaddexp(ŷ[z′[u], t+1] + β[u, t+1], ŷ[z′[u+1], t+1] + β[u+1,t+1])

# array bounds check and g(u) function from Eq. 7.16
if u+2 <= U′ && z′[u] != blank && z′[u] != z′[u+2]
β[u,t] = logaddexp(β[u,t], ŷ[z′[u+2], t+1] + β[u+2, t+1])
end
end
end
end

# Accumulate alpha-beta products for each category,
# then calculate gradients
accum = fill(log(typed_zero), size(ŷ))
for t=1:T
for u=1:U′
accum[z′[u], t] = logaddexp(accum[z′[u], t], α[u,t] + β[u,t])
end
end
grads = exp.(ŷ) .- exp.(accum .+ loss)
return grads
end

"""
ctc_loss(ŷ, y)

Computes the connectionist temporal classification loss between `ŷ`
and `y`.

`ŷ` must be a classes-by-time matrices, i.e., each row
represents a class and each column represents a time step.
Additionally, the `logsoftmax` function will be applied to `ŷ`, so
`ŷ` must be the raw activation values from the neural network and
not, for example, the activations after being passed through a
`softmax` activation function. `y` must be a 1D array of the labels
associated with `ŷ`. The blank label is assumed to be the last label
category in `ŷ`, so it is equivalent to `size(ŷ, 1)`.

Used for sequence-to-sequence classification problems such as
speech recognition and handwriting recognition where the exact
time-alignment of the output (e.g., letters) is not needed to
solve the problem. See [Graves et al. (2006)](https://www.cs.toronto.edu/~graves/icml_2006.pdf)
or [Graves (2012)](https://www.cs.toronto.edu/~graves/preprint.pdf#chapter.7)
for mathematical details.
"""
ctc_loss(ŷ::AbstractArray, y) = ctc_alpha(ŷ, y).loss

@adjoint function ctc_loss(ŷ, y)
out = ctc_alpha(ŷ, y)
ctc_loss_pullback(Δ) = (Δ .* ∇ctc_loss(ŷ, y, out), nothing)
return out.loss, ctc_loss_pullback
end
Loading