Adding CTC loss #342
Conversation
Thanks a lot for this, and apologies for the long wait in reviewing. The main thing it needs now is some additions to the docs so that it's totally clear what this is for and how to use it.
src/layers/ctc.jl (outdated)
end

@require CUDAnative begin
@require CuArrays begin
This can just be `@require CuArrays`, since CUDAnative will always be a dependency. These blocks will also need an update for Julia 1.0.
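For reference, a minimal sketch of what a Julia 1.0-compatible Requires.jl block could look like; the UUID and the included file name here are illustrative, not taken from this PR:

```julia
using Requires

function __init__()
    # On Julia 1.0, Requires.jl expects the dependency's UUID, and @require
    # blocks must run from the module's __init__. The UUID below is illustrative;
    # confirm it against CuArrays' Project.toml.
    @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" include("ctc-gpu.jl")
end
```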
src/layers/ctc.jl (outdated)
@grad function ctc(ŷ, y)
    ls, gs = ctc(Flux.Tracker.data(ŷ), Flux.Tracker.data(y))
    return ls, Δ -> (Δ .* gpu(gs), Δ)
Why is the `gpu` call needed here? We should find a different way to do this.
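For context, the snippet uses Tracker's custom-gradient mechanism. A minimal self-contained sketch of that pattern, with an illustrative function rather than anything from this PR:

```julia
using Flux
using Flux.Tracker: TrackedArray, track, data, @grad

# track() routes the call through Tracker's graph; @grad then supplies the
# forward value together with a pullback mapping the output sensitivity Δ
# to a gradient for each argument.
mysquare(x::TrackedArray) = track(mysquare, x)

@grad function mysquare(x)
    xd = data(x)                        # unwrap the tracked array
    return xd .^ 2, Δ -> (2 .* xd .* Δ,)
end
```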
test/layers/ctc.jl (outdated)
@require CuArrays begin
    using CuArrays
    lossvalue = 3.6990738
    l, gs = ctc(Flux.gpu(x), Flux.gpu(y))
Can we test the CPU version as well, and make sure the outputs are the same?
Worth also having a numerical gradcheck? Might not be necessary if we have known-good values.
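For reference, a finite-difference check along these lines could compare the analytic gradient against a numerical one. A sketch: `ngradient` is written out here rather than imported from Flux's test utilities, and the tolerances are assumptions:

```julia
using Test

# Central-difference numerical gradient of a scalar-valued f with respect to x.
function ngradient(f, x)
    grad = zero(x)
    δ = sqrt(eps(eltype(x)))
    for i in eachindex(x)
        tmp = x[i]
        x[i] = tmp - δ / 2; y1 = f(x)
        x[i] = tmp + δ / 2; y2 = f(x)
        x[i] = tmp
        grad[i] = (y2 - y1) / δ
    end
    return grad
end

# Hypothetical usage against this PR's untracked ctc, which returns (loss, gradient):
# l_cpu, gs_cpu = ctc(x, y)
# @test isapprox(gs_cpu, ngradient(x -> ctc(x, y)[1], copy(x)); rtol = 1e-5, atol = 1e-5)
# A CPU/GPU consistency check could similarly assert l_cpu ≈ l and gs_cpu ≈ Array(gs).
```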
I should be able to get to this after a conference I'm going to this week. It doesn't look like it should be too much work, and I think it's a good idea to add it to the docs pages.
CTC functions updated to work with current versions of Flux, CuArrays, and CUDAnative. GPU functions split into a new file to allow conditional loading. Test cases using gradchecks developed.
GPU threads desynchronized on the call to `div` when calculating gradients. Changing `div` to `CUDAnative.div_fast` keeps the threads synchronized.
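Illustratively, the change amounts to something like the following inside an element-wise gradient kernel (a sketch; the kernel name, arguments, and launch configuration are hypothetical, not this PR's actual kernel):

```julia
using CUDAnative

# Hypothetical element-wise kernel: each thread computes one gradient entry.
function grad_div_kernel!(out, num, den)
    i = (blockIdx().x - 1) * blockDim().x + threadIdx().x
    if i <= length(out)
        # A plain num[i] / den[i] here is where threads desynchronized;
        # the fast-math division intrinsic avoids that.
        out[i] = CUDAnative.div_fast(num[i], den[i])
    end
    return nothing
end

# Launch sketch: @cuda threads=256 blocks=cld(length(out), 256) grad_div_kernel!(out, num, den)
```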
I wonder if the core kernel should be added to CuArrays so we can make sure that's well tested, and then we can add the interface parts to Flux separately. (@maleadt does that sound reasonable to you?)
Unless there's an NNlib interface to implement with this functionality, I'd rather not. We should think about where to put functionality like that, and I think of CuArrays as implementing existing array-like interfaces for the GPU rather than housing "arbitrary" functionality.
Seems fair. @maetshju, what's the status on this? Is it ready on your end, or are you still hacking on it?
@MikeInnes the core functionality should be there. I am going to make a pass through all the code over the weekend to do a few more tests and make sure the documentation/comments are sufficient and correct, but I think it should be good to go after that.
I believe this is ready to go. It consistently passes the tests I've written, and I've checked that the methods properly cover all 16 possible combinations of tracked/untracked and CuArray/non-CuArray arguments. The docs should be up to date as well. If it looks good on your end, would you like me to squash the commits together before merging?
I'm going to do a quick run on the GPU as well. bors try
try: Build failed
I'm not sure I understand why that previous build failed. It doesn't look like Docker got everything installed before failing.
We've occasionally had CI flakiness, so let's just try again: bors try
try: Merge conflict
These last few commits fix the merge conflict, and I think they get everything working with Zygote. Admittedly, I'm not super familiar with Zygote yet, but the CPU tests are passing. I haven't had a chance to check on the GPU yet, though.
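For context, moving from Tracker to Zygote means replacing the `@grad` definition with an adjoint. A rough sketch of what that migration can look like (illustrative only; `ctc_` is a hypothetical internal function returning the loss and its analytic gradient, and the target is treated as non-differentiable here):

```julia
using Zygote

Zygote.@adjoint function ctc(ŷ, y)
    ls, gs = ctc_(ŷ, y)               # hypothetical internal forward pass
    return ls, Δ -> (Δ .* gs, nothing)
end
```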
Thanks a lot. Will give the GPU tests a quick go to see where we are: bors try
try: Build failed
That most recent commit should fix the …
bors try
try: Build succeeded
Manifest.toml (outdated)
@@ -74,6 +80,12 @@ git-tree-sha1 = "efdaf19ab11c7889334ca247ff4c9f7c322817b0"
uuid = "bbf7d656-a473-5ed7-a52c-81e309532950"
version = "0.2.0"

[[Conda]]
Is this necessary? I haven't followed the PR, though.
That's a good catch. The Conda dependency came in through an older version of FFTW: FFTW 1.1.0 no longer needs Conda, but its MKL_jll and FFTW_jll dependencies require Julia 1.3, and I was running 1.1, so the resolver fell back to the older FFTW. I upgraded my version of Julia and then updated packages, which should remove the dependency on Conda.
It should still install on older versions of Julia now, since it can pull in the older version of FFTW that requires Conda. Or at least, I was able to install this new version from the pull branch on Julia 1.1.
I think so. I wouldn't have imagined FFTW could induce the Conda dependency indirectly; I'll keep that in mind!
Starting a new PR from scratch may be easier than rebasing.
@CarloLucibello Thanks for letting me know about those PRs. I think opening a new PR once the others have been merged is probably easier, as you suggested. I would love to get this committed; it's been lingering in my mind for a while.
1287: Add CTC loss to new Losses module r=CarloLucibello a=maetshju

This is a redux of adding the connectionist temporal classification loss from #342, now that the Losses module has been merged in #1264. Discussion in #342 suggested that a new PR would be easier than rebasing. Since the last commit in #342, functions and data structures from `CUDAnative.jl` and `CuArrays.jl` have been updated to work with `CUDA.jl`. This is in addition to incorporating the loss function into the Losses module.

### PR Checklist

- [X] Tests are added
- [X] Entry in NEWS.md
- [X] Documentation, if applicable
- [ ] Final review from `@dhairyagandhi96` (for API changes).

Co-authored-by: Matt Kelley <matthew.curtis.kelley@gmail.com>
Co-authored-by: Matthew C. Kelley <matthew.curtis.kelley@gmail.com>
Superseded by #1287.
This is the CTC loss function and associated tests. It should interface with the tracker just like the other loss functions, and it uses the `@require` macro to allow the GPU functionality to be optional.
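For orientation, a rough usage sketch pieced together from the snippets in this thread; the input shapes, the `softmax` preprocessing, and the exact return convention are all assumptions rather than a documented API:

```julia
using Flux

# Hypothetical inputs: 5 output classes (including the CTC blank) × 10 time steps.
x = softmax(rand(Float32, 5, 10))
y = Flux.onehotbatch([2, 2, 3, 3, 5, 5, 5, 5, 5, 5], 1:5)  # hypothetical frame-level targets

l, gs = ctc(x, y)       # untracked call: loss and analytic gradient, per the @grad snippet
lt = ctc(param(x), y)   # tracked call: a tracked loss that backpropagates like other losses
```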