using Statistics
using Flux

# GPU implementation
# a port of the GPU kernels from Baidu's C++ warp-ctc package
# GitHub: https://github.com/baidu-research/warp-ctc/
# paper: https://arxiv.org/pdf/1512.02595.pdf

function logadd(a, b)
    isinf(a) && return b
    isinf(b) && return a

    # always want the greater number on the left in the exponentiation;
    # otherwise the magnitude of the difference may make the exponent very
    # positive, which will cause exp() to return Inf.
    # E.g., a = -900, b = -800 would give exp(-800 - -900) = exp(100),
    # which is Inf for Float32 values
    if a < b
        a, b = b, a
    end

    return a + log(1+exp(b-a))
end

function logsum(a)
    local s
    s = a[1]
    for item in a[2:end]
        s = logadd(s, item)
    end
    return s
end

function F(A, blank)
    prev = A[1]
    z = [prev]
    for curr in A[2:end]
        if curr != prev && curr != blank
            push!(z, curr)
        end
        prev = curr
    end
    return z
end

"""
    addBlanks(z, blank)

Adds blanks to the start and end of `z`, and between items in `z`.
"""
function addBlanks(z, blank)
    z′ = [blank]
    for label in z
        push!(z′, label)
        push!(z′, blank)
    end
    return z′
end
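# Quick CPU-side sanity check of the label-preprocessing helpers (illustration
# only; not part of the original warp-ctc port, and the example labels and blank
# index below are made up). `F` collapses repeated labels and drops blanks;
# `addBlanks` interleaves the blank symbol around the remaining labels.
let blank = 4
    @assert F([1, 1, 4, 2, 2, 3], blank) == [1, 2, 3]
    @assert addBlanks([1, 2, 3], blank) == [4, 1, 4, 2, 4, 3, 4]
end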
function log_plus_f(p1, p2)
    isinf(p1) && return p2
    isinf(p2) && return p1

    if p1 < p2
        p1, p2 = p2, p1
    end

    return p1 + CUDAnative.log(1+CUDAnative.exp(p2 - p1))
end

function countRepeats(A)
    repeats = 0
    for (i,elem) in enumerate(A)
        if i > 1 && A[i] == A[i-1]
            repeats += 1
        end
    end
    return repeats
end

using CUDAnative, CuArrays

function computeAlphaKernel(probs, labelSize, uttLength, repeats, labelsWithoutBlanks, labelsWithBlanks, alpha, blankLabel)

    tid = threadIdx().x
    L = labelSize
    T = uttLength
    S = length(labelsWithBlanks)

    if L + repeats > T
        return nothing
    end

    labels = labelsWithBlanks

    # Corner-case checking
    # hard-code 'start' to be equal to 0 here to make behavior the same as the CPU implementation
    start = 0 #(L + repeats <= T) ? 0 : 1
    last = S > 1 ? 2 : 1

    # Fill in first row
    i = tid
    while i <= last - start
        alpha[start + i] = probs[labels[start + i]]
        i += blockDim().x
    end

    sync_threads()

    # Fill in coefficients for each time step
    for t=2:T
        startCurRow = (t-1) * S
        startPrevRow = (t-2) * S
        startProbCol = (t-1) * div(length(probs), T)

        # Corner-case checking
        if tid == 1 && !(1 < S - 2*(T-t) - 1)
            if start == 0
                alpha[startCurRow + 1] = probs[startProbCol + blankLabel] + alpha[startPrevRow + 1]
            elseif start == 1
                alpha[startCurRow + 1] = alpha[startPrevRow + 1]
            end
        end

        sync_threads()

        # Fill in coefficients for each label class in the target output sequence;
        # each thread will process the calculations for one class
        idx = tid+1
        while idx <= S
            prevSum = log_plus_f(alpha[startPrevRow + idx], alpha[startPrevRow + idx-1])

            if labels[idx] != blankLabel && idx != 2 && labels[idx] != labels[idx-2]
                prevSum = log_plus_f(prevSum, alpha[startPrevRow + idx-2])
            end

            if idx < S - 2*(T-t) - 1
                alpha[idx + startCurRow] = -Inf32
            else
                alpha[startCurRow + idx] = prevSum + probs[startProbCol + labels[idx]]
            end

            idx += blockDim().x
        end

        sync_threads()
    end
    return nothing
end

function computeBetasAndGradKernel(probs, labelSize, uttLength, repeatsInLabel, labelsWithBlanks, alphas, beta, output, accum, grad, blankLabel)

    tid = threadIdx().x
    L = labelSize
    T = uttLength
    S = 2*L + 1
    repeats = repeatsInLabel
    labels = labelsWithBlanks

    if (L+repeats) > T
        return nothing
    end

    # Corner-case checking
    start = S > 1 ? S-2 : 0
    # hard-code 'last' to be equal to S here to make behavior the same as the CPU implementation
    last = S #L + repeats < T ? S : S-1

    sync_threads()

    startCurRow = (T-1)*S
    startProbCol = (T-1) * div(length(probs), T)

    i = tid
    # Calculate coefficients for last row, then determine alpha and beta product
    while i <= last - start
        beta[startCurRow + i + start] = 0
        output[startCurRow + i + start] = beta[startCurRow + i + start] + alphas[startCurRow + i + start]
        i += blockDim().x
    end

    sync_threads()

    # Fill in `accum` for last row
    if tid == 1
        startAccRow = startProbCol
        startOutputRow = startCurRow

        for i=1:S
            labelIdx = labels[i]
            accum[startAccRow + labelIdx] = log_plus_f(accum[startAccRow + labelIdx], output[startOutputRow + i])
            @cuprintf("thread %d accum[%ld] = %f\n", tid, startAccRow+labelIdx, accum[startAccRow+labelIdx])
        end
    end

    sync_threads()
    #@cuprintf("sync 1 thread: %ld\n", threadIdx().x)

    # Fill in `grad` for last row
    idx = tid
    while idx <= div(length(grad), T)
        startProbRow = (T - 1) * div(length(probs), T)
        startOutputRow = (T - 1) * S

        s = -Inf32
        for i=1:S
            s = log_plus_f(s, output[startOutputRow + i])
        end

        # ∂L/∂a (where a is activation before logsoftmax)
        grad[startProbRow + idx] = CUDAnative.exp(probs[startProbRow + idx]) - CUDAnative.exp(accum[startProbRow + idx] - s)
        @cuprintf("thread %d grad[%ld]: %f accum[%ld]: %f\n", tid, startProbRow+idx, grad[startProbRow+idx], startProbRow+idx, accum[startProbRow+idx])
        idx += blockDim().x
    end

    sync_threads()

    # Fill in the rest of the coefficients
    t = T-1
    while t >= 1
        startCurRow = (t-1)*S
        startNextRow = t*S
        startProbCol = t * div(length(probs), T)

        if t < T
            idx = tid
            while idx <= S-1
                nextSum = log_plus_f(beta[startNextRow + idx] + probs[startProbCol + labels[idx]],
                                     beta[startNextRow + idx+1] + probs[startProbCol + labels[idx+1]])

                if labels[idx] != blankLabel && idx != S-1 && labels[idx] != labels[idx+2]
                    nextSum = log_plus_f(nextSum,
                                         beta[startNextRow + idx + 2] + probs[startProbCol + labels[idx+2]])
                end

                if idx > 2*t
                    beta[idx + startCurRow] = -Inf32
                else
                    beta[idx + startCurRow] = nextSum
                end

                idx += blockDim().x
            end

            sync_threads()

            if tid == 1 && last == S
                beta[startCurRow + S] = beta[startNextRow + S] + probs[startProbCol + blankLabel]
            end

            sync_threads()

            idx = tid
            while idx <= S
                output[startCurRow + idx] = alphas[idx+startCurRow] + beta[startCurRow + idx]
                idx += blockDim().x
            end

            sync_threads()
        end

        sync_threads()

        # Calculate accumulated alpha-beta products for each label class for
        # each time step; used in calculating gradients
        if tid == 1
            startAccRow = (t-1) * div(length(accum), T)
            startOutputRow = (t-1) * S

            for i=1:S
                labelIdx = labels[i]
                accum[startAccRow + labelIdx] = log_plus_f(accum[startAccRow + labelIdx], output[startOutputRow + i])
                @cuprintf("t=%ld accum[%ld] = %f\n", t, startAccRow+labelIdx, accum[startAccRow+labelIdx])
            end
        end

        sync_threads()
        #@cuprintf("sync 2 thread: %ld\n", threadIdx().x)

        idx = tid
        # Calculate gradients
        while idx <= div(length(grad), T)
            startProbRow = (t - 1) * div(length(probs), T)
            startOutputRow = (t - 1) * S

            s = -Inf32
            for i=1:S
                s = log_plus_f(s, output[startOutputRow + i])
            end

            # ∂L/∂a (where a is activation before logsoftmax)
            grad[startProbRow + idx] = CUDAnative.exp(probs[startProbRow + idx]) - CUDAnative.exp(accum[startProbRow + idx] - s)
            @cuprintf("t=%ld thread %d grad[%ld] = %f\n", t, threadIdx().x, startProbRow+idx, grad[startProbRow+idx])
            idx += blockDim().x
        end

        sync_threads()

        t -= 1
        sync_threads()
        # for some reason the kernel doesn't work without this early return;
        # otherwise, some of the gradient values become 0
        t == 0 && return
    end

    return nothing
end
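# For reference, the gradient filled in by the kernel above is, per time step t and class k,
#     ∂L/∂a_k = exp(logprob_k) - exp(accum_k - log Σ_s output_t[s])
# i.e. softmax(a)_k minus the normalized total of the α·β path products passing through k.
# The one-liner below is a minimal CPU sketch of that same expression (added here for
# illustration; the kernels do not call it). `lp`, `acc`, and `lZ` are hypothetical names
# for one column of log-probabilities, the accumulated log(α·β) per class, and log Σ α·β.
ctc_grad_column(lp, acc, lZ) = exp.(lp) .- exp.(acc .- lZ)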
function ctc(ŷ::CuArrays.CuArray, y, nthreads)

    ŷ = logsoftmax(ŷ)

    blank = Int32(size(ŷ, 1))
    labels = argmax.([y[i,:] for i=1:size(y,1)])
    z = F(labels, blank)
    z′ = addBlanks(z, blank)
    T = size(ŷ, 2)
    U′ = 2*length(z) + 1

    alphas = gpu([-Inf32 for x in 1:(size(ŷ,2) * U′)])
    betas = gpu([-Inf32 for x in 1:(size(ŷ,2) * U′)])

    nRepeats = countRepeats(labels)

    @cuda blocks=1 threads=nthreads computeAlphaKernel(ŷ, length(z), size(ŷ,2), nRepeats, gpu(z), CuArrays.CuArray(z′), alphas, blank)

    grads = gpu([-Inf32 for x in 1:length(ŷ)])
    output = gpu([-Inf32 for x in 1:(size(ŷ,2) * U′)])
    accum = gpu([-Inf32 for x in 1:length(ŷ)])

    @cuda blocks=1 threads=nthreads computeBetasAndGradKernel(ŷ, length(z), size(ŷ,2), nRepeats, CuArrays.CuArray(z′), alphas, betas, output, accum, grads, blank)

    ls = Array(reshape(Array(output), U′, T)')
    ls = -1 .* mapslices(logsum, ls; dims=2)

    gs = reshape(Array(grads), size(ŷ,1), size(ŷ,2))

    ŷ = alphas = betas = output = accum = grads = nothing

    return mean(ls), gs
end

using Printf

x = [1. 2. 3.; 2. 1. 1.; 3. 3. 2.]

y = [1 0 0; 1 0 0; 0 1 0]

expected_loss = 3.6990738275138035

expected_grad = [-0.317671 -0.427729 0.665241;
                 0.244728 -0.0196172 -0.829811;
                 0.0729422 0.447346 0.16457]

println("single-thread intermediate printouts:")
l_single, gs_single = ctc(Flux.gpu(x), Flux.gpu(y), 1)
println()
println("multi-thread intermediate printouts:")
l_multi, gs_multi = ctc(Flux.gpu(x), Flux.gpu(y), 3)

l_diff = abs(l_single - l_multi)

g_max_idx = argmax(abs.(gs_single .- gs_multi))
single_max_g = gs_single[g_max_idx]
multi_max_g = gs_multi[g_max_idx]
g_max_diff = abs(single_max_g - multi_max_g)

#@printf("\nloss diff: %.2e\n", l_diff)
#@printf("\tsingle-thread loss: %.5f\n", l_single)
#@printf("\tmulti-thread loss: %.5f\n", l_multi)
@printf("\nmax grad diff: %.2e\n", g_max_diff)
println("single-thread grads: $gs_single")
println(" multi-thread grads: $gs_multi")
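# Optional check of the single-threaded GPU result against the hard-coded CPU
# reference values above (illustration only; the 1e-4 tolerance is an assumption
# chosen for Float32 GPU arithmetic, not taken from the original script).
@printf("\nsingle-thread loss matches expected_loss: %s\n", isapprox(l_single, expected_loss; atol=1e-4))
@printf("single-thread grads match expected_grad: %s\n", isapprox(gs_single, expected_grad; atol=1e-4))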