Add CTC loss to new Losses module #1287

Merged · 36 commits · Jan 20, 2021

The diff shown below is from the first 10 of the 36 commits.

Commits:
9e31e53
Add CTC loss and tests
maetshju Jul 19, 2020
37efaa0
Add ctc to Losses module
maetshju Jul 20, 2020
f471337
General updates
maetshju Oct 13, 2020
b19b88c
Reverting bad merge
maetshju Oct 13, 2020
da3564b
Revert "General updates"
maetshju Oct 13, 2020
d8242c0
General ctc updates
maetshju Oct 13, 2020
5bf2635
Get test cases working
maetshju Oct 15, 2020
1707255
Merge branch 'master' into ctc
maetshju Oct 17, 2020
e4123ab
Update NEWS.md
maetshju Oct 17, 2020
3045ed2
Re-pull from main Flux repo
maetshju Dec 7, 2020
46898ed
Change ctc to ctc_loss
maetshju Dec 14, 2020
110a608
Fix typo
maetshju Dec 14, 2020
5145222
Remove camel-casing
maetshju Dec 14, 2020
e002027
Remove some whitespace from functions
maetshju Dec 14, 2020
d0bd3bd
Adding info to comply with Apache license
maetshju Dec 14, 2020
5dafa05
Use logsumexp in CPU CTC
maetshju Dec 18, 2020
3950464
Change logsum to logsumexp
maetshju Dec 18, 2020
282cb23
Re-add logaddexp function to CPU ctc
maetshju Dec 19, 2020
50fc561
Merge branch 'master' into ctc
maetshju Dec 19, 2020
d9caac4
Regnerate Manifest.toml
maetshju Dec 19, 2020
c043855
Remove time indexing from ctc gradchecks
maetshju Dec 20, 2020
3eb9e51
Revert "Remove time indexing from ctc gradchecks"
maetshju Dec 20, 2020
bccef7a
Change typedZero to typed_zero
maetshju Dec 20, 2020
00d4125
Update gradcheck comment; remove some whitespace
maetshju Dec 22, 2020
9b37c8f
Reduce allocations for better performance
maetshju Dec 24, 2020
75bb3c1
Transpose alpha and beta to match GPU kernel
maetshju Dec 24, 2020
b00487a
Update add_blanks to use fill
maetshju Dec 24, 2020
d96a53f
Split CPU loss and gradient calculation
maetshju Dec 31, 2020
6ca07e2
Rejig CPU CTC API
maetshju Jan 3, 2021
bc7ab03
Split GPU CTC kernel and update API
maetshju Jan 3, 2021
9bb245d
Merge branch 'master' into ctc
maetshju Jan 3, 2021
807aefa
Move to onecold representation for ctc input
maetshju Jan 14, 2021
6f108c6
Merge branch 'master' into ctc
maetshju Jan 14, 2021
e1e8cc8
Apply suggestions from code review
maetshju Jan 16, 2021
6e5fb17
Remove F in ctc tests; update ctc-gpu test syntax
maetshju Jan 16, 2021
bc94a16
Fix indentation in ctc.jl
maetshju Jan 19, 2021
1 change: 1 addition & 0 deletions NEWS.md
@@ -6,6 +6,7 @@
* Excise datasets in favour of other providers in the julia ecosystem.
* other new features and bug fixes (see GitHub's releases page)
* Added option to set `bias` to [false](https://github.com/FluxML/Flux.jl/pull/1379) to exclude `bias` from being trained.
* Add [CTC loss function](https://github.com/FluxML/Flux.jl/pull/1287) to Losses module

## v0.11.2

7 changes: 5 additions & 2 deletions src/losses/Losses.jl
@@ -16,9 +16,12 @@ export mse, mae, msle,
tversky_loss,
dice_coeff_loss,
poisson_loss,
- hinge_loss, squared_hinge_loss
+ hinge_loss, squared_hinge_loss,
+ ctc

include("utils.jl")
include("functions.jl")
+ include("ctc.jl")
+ if CUDA.functional() include("ctc-gpu.jl") end

- end #module
+ end #module
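For reference, a minimal usage sketch of the newly exported loss. Sizes and values are purely illustrative, and the call relies on the GPU methods added in src/losses/ctc-gpu.jl below (the CPU path lives in the included ctc.jl, which is not shown in this view):

using CUDA, Flux
using Flux.Losses: ctc

ŷ = CUDA.rand(Float32, 5, 30)   # 4 label classes + blank (last row), 30 time steps
y = zeros(Float32, 5, 3)        # one-hot columns encoding the target sequence 1, 2, 3
y[1, 1] = y[2, 2] = y[3, 3] = 1
loss = ctc(ŷ, y)                # dispatches to ctc(::CuArray, ::Array)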
291 changes: 291 additions & 0 deletions src/losses/ctc-gpu.jl
@@ -0,0 +1,291 @@
# GPU implementation

# a port of the GPU kernels from Baidu's C++ warp-ctc package
# GitHub: https://github.com/baidu-research/warp-ctc/
# paper: https://arxiv.org/pdf/1512.02595.pdf

using Flux
using Statistics
using CUDA

const MAX_THREADS = 256

# Numerically stable log(exp(p1) + exp(p2)) for log-space values (two-term log-sum-exp).
function log_plus_f(p1, p2)
isinf(p1) && return p2
isinf(p2) && return p1

if p1 < p2
p1, p2 = p2, p1
end

return p1 + CUDA.log(1+CUDA.exp(p2 - p1))
end
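# (Equivalently, log_plus_f(a, b) == log(exp(a) + exp(b)); e.g. log(0.25) and log(0.25) combine to log(0.5).)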

# Number of adjacent repeated labels in A; used for the `L + repeats > T` feasibility checks in the kernels below.
function countRepeats(A)
repeats = 0
for (i,elem) in enumerate(A)
if i > 1 && A[i] == A[i-1]
repeats += 1
end
end
return repeats
end
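# e.g. countRepeats([1, 1, 2, 3, 3]) == 2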

function computeAlphaKernel(probs, labelSize, uttLength, repeats, labelsWithoutBlanks, labelsWithBlanks, alpha, blankLabel)

tid = threadIdx().x
L = labelSize
T = uttLength
S = length(labelsWithBlanks)

if L + repeats > T
return nothing
end

labels = labelsWithBlanks

# Corner-case checking
start = (L + repeats <= T) ? 0 : 1
last = S > 1 ? 2 : 1

# Fill in first column (time step)
i = tid
while i <= last - start
alpha[start+i, 1] = probs[labels[start+i], 1]
i += blockDim().x
end

sync_threads()

# Fill in coefficients for each time step
for t=2:T

# Corner-case checking
if tid == 1 && !(1 < S - 2*(T-t) - 1)
if start == 0
alpha[1, t] = probs[blankLabel, t] + alpha[1, t-1]
elseif start == 1
alpha[1, t] = alpha[1, t-1]
end
end

sync_threads()

# Fill in coefficients for each label class in the target output sequence;
# each thread will process the calculations for one class
idx = tid+1
while idx <= S

prevSum = log_plus_f(alpha[idx, t-1], alpha[idx-1, t-1])

if labels[idx] != blankLabel && idx != 2 && labels[idx] != labels[idx-2]
prevSum = log_plus_f(prevSum, alpha[idx-2, t-1])
end

if idx < S - 2*(T-t) - 1
alpha[idx, t] = -Inf32
else
alpha[idx, t] = prevSum + probs[labels[idx], t]
end

idx += blockDim().x
end

sync_threads()
end
return nothing
end
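# For reference, the time loop above is the standard CTC forward (alpha) recursion in
# log space (Graves et al., 2006), with ⊕ denoting log_plus_f:
#
#   alpha[s, t] = ( alpha[s, t-1] ⊕ alpha[s-1, t-1] ⊕ alpha[s-2, t-1] ) + probs[labels[s], t]
#
# where the alpha[s-2, t-1] term is included only when labels[s] is not the blank and
# labels[s] != labels[s-2], and entries that cannot reach the end of the label sequence
# in the remaining time steps are set to -Inf32.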

function computeBetasAndGradKernel(probs, labelSize, uttLength,
repeatsInLabel, labelsWithBlanks,
alphas, beta, output, accum,
grad, blankLabel)

tid = threadIdx().x
L = labelSize
T = uttLength
S = 2*L + 1
repeats = repeatsInLabel

labels = labelsWithBlanks

if (L+repeats) > T
return nothing
end

# Corner-case checking
start = S > 1 ? S-2 : 0
last = L + repeats < T ? S : S-1

sync_threads()

i = tid

# Calculate coefficients for last column (time step)
# then determine alpha and beta product
while i <= last - start + 1
beta[i+start, T] = 0
output[i+start, T] = beta[i+start, T] + alphas[i+start, T]
i += blockDim().x
end

sync_threads()

# Fill in `accum` for last column (time step)
if tid == 1
for i=1:S
labelIdx = labels[i]
accum[labelIdx, T] = log_plus_f(accum[labelIdx, T], output[i, T])
end
end

sync_threads()

# Fill in `grad` for last column (time step)
idx = tid
while idx <= size(grad, 1)

s = -Inf32

for i=1:S
s = log_plus_f(s, output[i, T])
end

# ∂L/∂a (where a is activation before logsoftmax)
grad[idx, T] = CUDA.exp(probs[idx, T]) - CUDA.exp(accum[idx, T] - s)
idx += blockDim().x
end

sync_threads()

# Fill in the rest of the coefficients
t = T-1
while t >= 1
if t < T

idx = tid
# while idx <= S-1
while idx <= S

nextSum = beta[idx, t+1] + probs[labels[idx], t+1]

if idx < S

nextSum = log_plus_f(nextSum,
beta[idx+1, t+1] + probs[labels[idx+1], t+1])
end

if labels[idx] != blankLabel && idx != S-1 && labels[idx] != labels[idx+2]
nextSum = log_plus_f(nextSum,
beta[idx + 2, t+1] + probs[labels[idx+2], t+1])
end

if idx > 2*t
beta[idx, t] = -Inf32
else
beta[idx, t] = nextSum

end

idx += blockDim().x
end

sync_threads()

if tid == 1 && last == S
beta[S, t] = beta[S, t] + probs[blankLabel, t+1]
end

sync_threads()

idx = tid
while idx <= S
output[idx, t] = alphas[idx, t] + beta[idx, t]
idx += blockDim().x
end

sync_threads()
end


sync_threads()

# Calculate accumulated alpha-beta products for each label class for
# each time step; used in calculating gradients
if tid == 1
for i=1:S
labelIdx = labels[i]
accum[labelIdx, t] = log_plus_f(accum[labelIdx, t], output[i, t])
end
end

sync_threads()

idx = tid

# Calculate gradients
while idx <= size(grad, 1)

s = -Inf32

for i=1:S
s = log_plus_f(s, output[i, t])
end

# ∂L/∂a (where a is activation before logsoftmax)
grad[idx, t] = CUDA.exp(probs[idx, t]) - CUDA.exp(accum[idx, t] - s)
idx += blockDim().x
end

sync_threads()

t -= 1
sync_threads()
end

return nothing
end
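# For reference: in the kernel above, output[s, t] holds alpha[s, t] + beta[s, t],
# accum[k, t] is the log-sum-exp of output[s, t] over the positions s with labels[s] == k,
# and the value written into grad is the usual CTC gradient with respect to the
# pre-softmax activations a (probs = logsoftmax(a)):
#
#   dL/da[k, t] = softmax(a[:, t])[k] - (1 / p(z|x)) * sum over {s : labels[s] == k} of exp(alpha[s, t] + beta[s, t])
#
# where L = -log p(z|x) and log p(z|x) is the log-sum-exp of output[:, t]
# (computed into the local variable s in the loop).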

# `ctc` methods for the GPU path: each converts its arguments as needed and takes the
# mean of the per-time-step loss values returned by the `ctc_` helper below, which also
# returns the gradients with respect to ŷ.
ctc(ŷ::CuArray, y::Array) = ctc_(ŷ, y)[1] |> mean
ctc(ŷ::Array, y::CuArray) = ctc_(CuArray(ŷ), collect(y))[1] |> mean
ctc(ŷ::CuArray, y::CuArray) = ctc_(ŷ, collect(y))[1] |> mean
ctc_(ŷ::Array, y::CuArray) = ctc_(CuArray(ŷ), collect(y))

function ctc_(ŷ::CuArray, y)

ŷ = logsoftmax(ŷ)

blank = size(ŷ, 1)
labels = [Base.argmax(y[:,i]) for i in 1:size(y, 2)]
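# NOTE: F (and logsum, used further down) are assumed to be defined in src/losses/ctc.jl,
# the CPU implementation included just before this file; F presumably collapses repeated
# labels and removes blanks from the label sequence.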
z = F(labels, blank)
z′ = [blank]
for label in z
push!(z′, label)
push!(z′, blank)
end

T = size(ŷ, 2)
U′ = 2*length(z) + 1

alphas = CUDA.fill(log(zero(ŷ[1])), U′, T)
betas = CUDA.fill(log(zero(ŷ[1])), U′, T)
output = CUDA.fill(log(zero(ŷ[1])), U′, T)

nRepeats = countRepeats(labels)
nThreads = min(U′, MAX_THREADS)

@cuda blocks=1 threads=nThreads computeAlphaKernel(ŷ, length(z), size(ŷ,2), nRepeats, CuArray(z), CuArray(z′), alphas, blank)

grads = CUDA.fill(log(zero(ŷ[1])), size(ŷ))
accum = CUDA.fill(log(zero(ŷ[1])), size(ŷ))

@cuda blocks=1 threads=nThreads computeBetasAndGradKernel(ŷ, length(z), size(ŷ,2), nRepeats, CuArray(z′), alphas, betas, output, accum, grads, blank)

ls = collect(output)
ls = vec(-1 .* [logsum(ls[:,i]) for i in 1:size(ls, 2)])

ŷ = alphas = betas = output = accum = nothing
return ls, grads
end