
Adding CTC loss #342

Closed
wants to merge 25 commits into from

Conversation

maetshju
Contributor

@maetshju maetshju commented Aug 7, 2018

This adds the CTC loss function and associated tests. It should interface with the tracker just like the other loss functions, and it uses the `@require` macro to make the GPU functionality optional.
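The conditional-loading idea can be sketched in a self-contained way; `ctc_impl` and the `GPU_AVAILABLE` flag below are illustrative stand-ins (the PR itself uses Requires.jl's `@require CuArrays begin ... end` blocks):

```julia
# Self-contained sketch of conditional GPU loading: a CPU method is always
# defined, and a GPU method is added only when the GPU package is available.
# `GPU_AVAILABLE` is a stand-in for `CuArrays` being loadable.
ctc_impl(x) = ("cpu", sum(x))      # CPU path, always defined

const GPU_AVAILABLE = false
if GPU_AVAILABLE                   # in the PR: `@require CuArrays begin ... end`
    ctc_impl(x) = ("gpu", sum(x))  # GPU method shadows the CPU one when loaded
end

ctc_impl([1.0, 2.0])               # ("cpu", 3.0) when no GPU package is present
```

The point of the pattern is that code calling `ctc_impl` never needs to know whether the GPU package was loaded; dispatch picks up whichever method is defined.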

@MikeInnes MikeInnes (Member) left a comment

Thanks a lot for this, and apologies for the long wait in reviewing. The main thing it needs now is some additions to the docs so that it's totally clear what this is for and how to use it.

end

@require CUDAnative begin
  @require CuArrays begin
Member

This can just be @require CuArrays, since CUDAnative will always be a dependency.

These blocks will also need an update for 1.0.


@grad function ctc(ŷ, y)
  ls, gs = ctc(Flux.Tracker.data(ŷ), Flux.Tracker.data(y))
  return ls, Δ -> (Δ .* gpu(gs), Δ)
Member

Why is the gpu call needed here? We should find a different way to do this.

@require CuArrays begin
  using CuArrays
  lossvalue = 3.6990738
  l, gs = ctc(Flux.gpu(x), Flux.gpu(y))
Member

Can we test the CPU version as well, and make sure the outputs are the same?

Worth also having a numerical gradcheck? Might not be necessary if we have known-good values.
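A numerical gradcheck along these lines is straightforward with central finite differences; `loss` and `grad` below are illustrative placeholders, not Flux API:

```julia
# Central finite-difference gradient check: compares an analytic gradient
# against (loss(x + ε) - loss(x - ε)) / 2ε, elementwise.
function gradcheck(loss, grad, x; eps=1e-5, tol=1e-4)
    g = grad(x)
    for i in eachindex(x)
        xp = copy(x); xp[i] += eps
        xm = copy(x); xm[i] -= eps
        numeric = (loss(xp) - loss(xm)) / (2eps)
        abs(numeric - g[i]) <= tol || return false
    end
    return true
end

# Known-good check: loss = Σ xᵢ², whose gradient is 2x.
gradcheck(v -> sum(abs2, v), v -> 2 .* v, [1.0, -2.0, 3.0])  # true
```

For the CTC loss itself the same check would compare the hand-written gradient kernel against finite differences of the forward pass, which is a useful complement to known-good values.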

@maetshju
Contributor Author

I should be able to get to this after a conference I'm going to this week. It doesn't look like it should be too much work, and I think it's a good idea to add it to the docs pages.

@jtmatamalas jtmatamalas mentioned this pull request Nov 12, 2018
CTC functions updated to work with current versions of Flux, CuArrays, and CUDAnative. GPU functions split into a new file to allow conditional loading. Test cases using gradchecks developed.

GPU threads desynchronized on call to div when calculating gradients. Changing div to CUDAnative.div_fast keeps the threads synchronized.
@MikeInnes
Member

I wonder if the core kernel should be added to CuArrays so we can make sure that's well tested, and then we can add the interface parts to Flux separately. (@maleadt does that sound reasonable to you?)

@maleadt
Collaborator

maleadt commented Sep 5, 2019

Unless there's an NNlib interface to implement with this functionality, I'd rather not. We should think about where to put functionality like that, and I think of CuArrays as implementing existing array-like interfaces for the GPU rather than housing "arbitrary" functionality.

@MikeInnes
Member

Seems fair.

@maetshju what's the status on this, is it ready on your end, or are you still hacking on it?

@maetshju
Contributor Author

maetshju commented Sep 5, 2019

@MikeInnes the core functionality should be there. I am going to make a pass through all the code over the weekend to do a few more tests and make sure the documentation/comments are sufficient and correct, but I think it should be good to go after that.

@maetshju
Contributor Author

maetshju commented Sep 9, 2019

I believe this is ready to go. It consistently passes the tests I've written, and I've checked that the methods properly cover all 16 possible combinations of tracked/untracked and cu/non-cu arrays. The docs should be up to date as well.

If it looks good on your end, would you like me to squash the commits together before merging?
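The coverage claim above (tracked/untracked × cu/non-cu, independently for ŷ and y, giving 2⁴ = 16 method combinations) can be sketched as an enumeration; the symbols stand in for wrapping the arrays with `Flux.param` and `cu`, and the real suite would call `ctc` on each pair:

```julia
# Enumerate all tracked/untracked × cu/non-cu states for each of (ŷ, y).
# Each argument independently varies over 2 × 2 wrapper choices: 4 × 4 = 16 cases.
arg_states = Iterators.product((:tracked, :untracked), (:cu, :cpu))
cases = collect(Iterators.product(arg_states, arg_states))
length(cases)  # 16
```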

@DhairyaLGandhi
Member

I'm going to do a quick run on the GPU also

bors try

bors bot added a commit that referenced this pull request Sep 9, 2019
@bors
Contributor

bors bot commented Sep 9, 2019

try

Build failed

@maetshju
Contributor Author

I'm not sure I understand why that previous build failed. It doesn't look like Docker got everything installed before failing.

@MikeInnes
Member

We've occasionally had CI flakiness, so let's just try again:

bors try

@bors
Contributor

bors bot commented Jan 15, 2020

try

Merge conflict

@maetshju
Contributor Author

These last few commits fix the merge conflict, and I think they get everything working with Zygote. Admittedly, I'm not super familiar with Zygote yet, but the CPU tests are passing. I haven't had a chance to check on the GPU yet, though.

@MikeInnes
Member

Thanks a lot. Will give the GPU tests a quick go to see where we are:

bors try

bors bot added a commit that referenced this pull request Jan 20, 2020
@bors
Contributor

bors bot commented Jan 20, 2020

try

Build failed

@maetshju
Contributor Author

That most recent commit should fix the UndefVarError by using CuArrays.functional(). It's passing tests in my GPU environment, so hopefully it works with the CI.
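The guard pattern can be sketched with a stub; `gpu_functional` below stands in for `CuArrays.functional()`, which reports at runtime whether a usable CUDA device is present:

```julia
# Skip GPU-only tests when no CUDA device is usable, instead of erroring.
gpu_functional() = false   # stub; the real check is CuArrays.functional()

ran_gpu_tests = false
if gpu_functional()
    # the GPU @testset for ctc would run here
    ran_gpu_tests = true
end
ran_gpu_tests              # false on a CPU-only CI machine
```

Guarding on a runtime check rather than on whether the package loaded is what avoids the `UndefVarError` on CI machines that have the package installed but no working GPU.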

@MikeInnes
Member

bors try

bors bot added a commit that referenced this pull request Jan 21, 2020
@bors
Contributor

bors bot commented Jan 21, 2020

try

Build succeeded

Manifest.toml Outdated
@@ -74,6 +80,12 @@ git-tree-sha1 = "efdaf19ab11c7889334ca247ff4c9f7c322817b0"
uuid = "bbf7d656-a473-5ed7-a52c-81e309532950"
version = "0.2.0"

[[Conda]]
Contributor

Is this necessary? I haven't followed the PR, though.

Contributor Author

That's a good catch. The dependency was introduced by FFTW version 1.1.0, whose MKL_jll and FFTW_jll dependencies require Julia 1.3, while I was running 1.1. I upgraded my version of Julia and then updated packages, which should remove the dependency on Conda.

Contributor Author

It should still install on older versions of Julia now, since it can pull in the older version of FFTW that requires Conda. Or at least, I was able to install this new version from the pull branch on Julia 1.1.

Contributor

I think so. I wouldn't have imagined FFTW would induce the Conda dependency indirectly; I'll keep that in mind!

@CarloLucibello
Member

@maetshju it would be a pity if this gets lost. There are two changes impacting this PR, #1258 and #1264. If you have the patience to rebase once #1264 gets in, we can do a quick review and merge.

@CarloLucibello
Member

Starting a new PR all over may be easier than rebasing.

@maetshju
Contributor Author

maetshju commented Jul 2, 2020

@CarloLucibello Thanks for letting me know about those PRs. I think opening a new PR once the others have been merged is probably easier, as you suggested. I would love to get this committed; it's been lingering in my mind for a while.

@maetshju maetshju mentioned this pull request Jul 21, 2020
bors bot added a commit that referenced this pull request Jan 20, 2021
1287: Add CTC loss to new Losses module r=CarloLucibello a=maetshju

This is a redux of adding the connectionist temporal classification loss from #342, now that the Losses module has been merged in #1264. Discussion in #342 suggested that a new PR would be easier than rebasing.

Since the last commit in #342, functions and data structures from `CUDAnative.jl` and `CuArrays.jl` have been updated to work with `CUDA.jl`. This is in addition to incorporating the loss function into the Losses module.

### PR Checklist

- [X] Tests are added
- [X] Entry in NEWS.md
- [X] Documentation, if applicable
- [ ] Final review from `@dhairyagandhi96` (for API changes).


Co-authored-by: Matt Kelley <matthew.curtis.kelley@gmail.com>
Co-authored-by: Matthew C. Kelley <matthew.curtis.kelley@gmail.com>
@maetshju
Contributor Author

Superseded by #1287.

@maetshju maetshju closed this Jan 20, 2021