
Map experimental C (actually C++) API for gradient tape #283

Draft · wants to merge 1 commit into master

Conversation

@saudet (Contributor) commented Apr 9, 2021

I've only tested the build on Linux for now, but it should also work on Mac and Windows.

Note that the API has already changed with 2.5.0, so we should probably upgrade to that version before looking at this too closely.
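For a sense of what this is aiming at, here is a minimal sketch of the kind of eager-mode gradient-tape workflow the mapping could eventually enable from Java. Only `EagerSession`, `Ops`, and the ops/tensor calls below are existing TF Java API; `GradientTape`, `watch()`, and `gradient()` are hypothetical placeholder names, not the API this PR actually maps.

```java
import org.tensorflow.EagerSession;
import org.tensorflow.Operand;
import org.tensorflow.op.Ops;
import org.tensorflow.types.TFloat32;

// Hypothetical sketch: GradientTape, watch(), and gradient() are placeholder
// names; only EagerSession, Ops, and the math/constant ops are real TF Java API.
try (EagerSession session = EagerSession.create()) {
    Ops tf = Ops.create(session);
    Operand<TFloat32> x = tf.constant(3.0f);

    GradientTape tape = new GradientTape(session); // hypothetical
    tape.watch(x);                                 // hypothetical
    Operand<TFloat32> y = tf.math.mul(x, x);       // y = x^2
    Operand<TFloat32> dy = tape.gradient(y, x);    // hypothetical; dy/dx = 2x = 6
    System.out.println(dy.asTensor().getFloat());
}
```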

@karllessard (Collaborator)

@saudet, is this PR still just a draft, or do you think it is ready to be reviewed and merged?

@saudet (Contributor, Author) commented Apr 25, 2021

Since it doesn't look like we're going to do anything with this for TF 2.4.x, I think I'll upgrade this PR to 2.5.0-rc1, and then we can merge it after the 2.5.0 release? I don't think it makes sense to start doing anything with this API for 2.4.x.

@saudet (Contributor, Author) commented Jun 25, 2021

I've finally rebased this on master and upgraded it for TF 2.5.0! I've also undone the unreadable reformatting of presets/tensorflow.java, but feel free to redo it if necessary. I'd still consider this a WIP, but if it doesn't break any builds, it should be fine to merge so that people can start playing with it, as long as we're ready to maintain an unstable experimental API...

@dosier commented Nov 16, 2021

Hey, is this still being worked on, @saudet? This is the missing ingredient for me to implement an RL model in KotlinDL.
If it is not being worked on anymore, what is left to do? Maybe I could help with it.

@rnett (Contributor) commented Nov 16, 2021

As far as I know, we're waiting on full support in TensorFlow. The RFC is here and there seems to be work happening gradually, but I don't think it's at the point where we can use it as a full solution for gradients. The actual gradients are here, and as you can see there are quite a few missing, and there's currently no registration method or anything like that.

@saudet (Contributor, Author) commented Nov 16, 2021

@dosier FWIW, you may be better off with PyTorch.
Its C++ API is apparently rich enough to get anything RL going:
https://github.com/Omegastick/pytorch-cpp-rl
https://github.com/navneet-nmk/Pytorch-RL-CPP
https://github.com/mhubii/ppo_libtorch

And the JavaCPP Presets for PyTorch provide full access to that C++ API:
https://github.com/bytedeco/javacpp-presets/tree/master/pytorch
Please do let me know if there's anything missing though!

@dosier commented Nov 16, 2021

> @dosier FWIW, you may be better off with PyTorch. Its C++ API is apparently rich enough to get anything RL going: https://github.com/Omegastick/pytorch-cpp-rl https://github.com/navneet-nmk/Pytorch-RL-CPP https://github.com/mhubii/ppo_libtorch
>
> And the JavaCPP Presets for PyTorch provide full access to that C++ API: https://github.com/bytedeco/javacpp-presets/tree/master/pytorch Please do let me know if there's anything missing though!

Cheers! I've been hoping to see the GradientTape integrated into the TF Java API for a while, mainly so that I can contribute RL stuff to KotlinDL :D. But I also need to implement an RL model for my studies this block, so the PyTorch Java wrapper is a pleasant surprise (I love my static types too much).

@saudet (Contributor, Author) commented Nov 16, 2021

> Cheers! I've been hoping to see the GradientTape integrated into the TF Java API for a while, mainly so that I can contribute RL stuff to KotlinDL :D. But I also need to implement an RL model for my studies this block, so the PyTorch Java wrapper is a pleasant surprise (I love my static types too much).

BTW, it looks like the author of KotlinDL would be open to integrating PyTorch as well, see pytorch/pytorch#58973 (comment).
However, I'm guessing he would like to get Facebook and/or Microsoft to cooperate a bit before doing anything with it.

/cc @zaleslaw

@zaleslaw (Contributor)

Yeah, there are a few ways to integrate Torch in Kotlin:

  1. JNI
  2. JavaCPP
  3. Pure PyTorch Java API with IValue for inference only

Hope that in 2022 KotlinDL will be able to support training of Torch models via JNI (or via JavaCPP).

Good luck, @dosier, with your experiments with RL, and I hope to see you in the future with running RL models.
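For reference, option 3 from the list above (the pure PyTorch Java API) looks roughly like this for inference; the model file name is just a placeholder:

```java
import org.pytorch.IValue;
import org.pytorch.Module;
import org.pytorch.Tensor;

// Inference-only use of the pure PyTorch Java API (option 3 above):
// load a TorchScript model and run a forward pass. "model.pt" is a placeholder.
Module module = Module.load("model.pt");
Tensor input = Tensor.fromBlob(new float[]{1f, 2f, 3f, 4f}, new long[]{1, 4});
Tensor output = module.forward(IValue.from(input)).toTensor();
float[] scores = output.getDataAsFloatArray();
```

As the option says, this path is inference only; training Torch models from Kotlin/Java would still require JNI or JavaCPP.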

@saudet (Contributor, Author) commented Nov 16, 2021

@zaleslaw May I ask why you're considering writing JNI code manually? What's missing from JavaCPP?

@karllessard (Collaborator)

> As far as I know, we're waiting on full support in TensorFlow. The RFC is here and there seems to be work happening gradually, but I don't think it's at the point where we can use it as a full solution for gradients. The actual gradients are here, and as you can see there are quite a few missing, and there's currently no registration method or anything like that.

@rnett, are you saying that with the custom gradient support you added not long ago, plus this new internal API, we are still not able to register our own gradients in eager mode?

@rnett (Contributor) commented Nov 17, 2021

> As far as I know, we're waiting on full support in TensorFlow. The RFC is here and there seems to be work happening gradually, but I don't think it's at the point where we can use it as a full solution for gradients. The actual gradients are here, and as you can see there are quite a few missing, and there's currently no registration method or anything like that.

> @rnett, are you saying that with the custom gradient support you added not long ago, plus this new internal API, we are still not able to register our own gradients in eager mode?

I'm saying that for this new API, there are no built-in registries (i.e. global, or Graph/EagerSession based), so we would have to create and manage our own. Once we do that, it would be easy enough to add custom Java-side gradients.

I haven't seen confirmation anywhere that some sort of registry and auto-registration is planned, but I would expect it. I'm not sure how Python does it, or if it's using this setup at all.
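Purely to illustrate what "create and manage our own" could mean on the Java side (everything below is hypothetical; nothing like this exists in the codebase, and `RawGradientFunction` is a placeholder for whatever callback type the mapped C++ API would actually require):

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

// Hypothetical sketch of a self-managed, global gradient registry.
public final class GradientRegistry {

    /** Placeholder for the native gradient callback expected by the mapped API. */
    public interface RawGradientFunction {}

    private static final Map<String, RawGradientFunction> GRADIENTS = new ConcurrentHashMap<>();

    private GradientRegistry() {}

    /** Registers a custom gradient for the given op type, e.g. "MatMul". */
    public static void register(String opType, RawGradientFunction fn) {
        GRADIENTS.put(opType, fn);
    }

    /** Returns the registered gradient function, or null if none exists. */
    public static RawGradientFunction lookup(String opType) {
        return GRADIENTS.get(opType);
    }
}
```

A Graph- or EagerSession-scoped variant would simply move the map onto those objects instead of a static field.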
