Flux Integration #26

Open
MikeInnes opened this issue Feb 26, 2019 · 2 comments

Comments

@MikeInnes

I'm curious if you'd be interested in making Yota compatible with Flux layers and optimisers; then Yota could be used in place of Tracker for models without control flow.

Zygote does this by inserting calls to `unwrap`, which strip away Flux's tracking (of course, this won't be necessary once we get rid of Tracker).

@dfdx
Owner

dfdx commented Feb 28, 2019

Yes, I've thought about it, but I don't see a straightforward way to do it just yet: unwrapping tracked data helps with finding gradients, but how do you use those gradients afterwards?

For example, take `_update_params!(opt, xs)`: it expects each `x` to have `.data` and `.grad` properties, i.e. to be a `Tracked*` type. Zygote by itself doesn't use tracked data, so do you just write the computed gradients back to the original tracked variables? If so, could you point me to the relevant piece of code?

@MikeInnes
Author

I spent some time refactoring this stuff for exactly this reason. We now have a version of `update` that only requires a param -> grad mapping, and it should be pretty easy to get that mapping out of Yota.
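To make the idea concrete, here is a minimal, language-neutral sketch in Python (not Flux's actual Julia code) of an optimiser step that needs only a param -> grad mapping rather than parameters carrying `.data`/`.grad` fields. The function `sgd_update` is hypothetical, and keying the dictionary by `id(param)` is an assumption standing in for an identity-keyed dictionary (in Julia, something like an `IdDict`).

```python
from typing import Dict, List


def sgd_update(params: List[List[float]],
               grads: Dict[int, List[float]],
               lr: float = 0.1) -> None:
    """Plain SGD step driven only by a param -> grad mapping (hypothetical helper).

    `grads` maps id(param) to that parameter's gradient, so the parameters
    themselves carry no tracking metadata (no .data or .grad fields).
    Parameters are updated in place.
    """
    for p in params:
        g = grads.get(id(p))
        if g is None:
            continue  # no gradient recorded for this parameter; leave it unchanged
        for i, gi in enumerate(g):
            p[i] -= lr * gi


# Usage: the gradients could come from any AD tool that can produce the mapping.
W = [1.0, 2.0]
b = [0.5]
sgd_update([W, b], {id(W): [0.5, -1.0], id(b): [2.0]}, lr=0.5)
print(W, b)  # [0.75, 2.5] [-0.5]
```

Because the optimiser only looks up each parameter's gradient, the AD backend (Tracker, Zygote or Yota) becomes interchangeable as long as it can return such a mapping.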

The Param API is somewhat transitional; it represents a compromise between what Flux and Yota/Zygote can expose, but the idea is to eventually get rid of it in favour of more functional optimisers. Once that happens, it should be possible to use Yota's native API with Flux.
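As a rough illustration of what "more functional optimisers" can mean, the Python sketch below shows one plausible shape. It is an assumption, not Flux's eventual design: a pure step function takes a parameter, its gradient and the optimiser state, and returns a new parameter and new state without mutating anything. The names `momentum_init` and `momentum_step` are hypothetical, and momentum SGD is used only as an example optimiser.

```python
from typing import Tuple

Vector = Tuple[float, ...]


def momentum_init(param: Vector) -> Vector:
    """Initial optimiser state: a zero velocity with the same length as the param."""
    return tuple(0.0 for _ in param)


def momentum_step(param: Vector, grad: Vector, velocity: Vector,
                  lr: float = 0.1, rho: float = 0.9) -> Tuple[Vector, Vector]:
    """One pure momentum-SGD step; returns (new_param, new_velocity) and mutates nothing."""
    new_velocity = tuple(rho * v + g for v, g in zip(velocity, grad))
    new_param = tuple(p - lr * v for p, v in zip(param, new_velocity))
    return new_param, new_velocity


# Usage: state is threaded explicitly through each call instead of living on the param.
p0 = (1.0,)
v0 = momentum_init(p0)
p1, v1 = momentum_step(p0, (1.0,), v0, lr=0.5, rho=0.5)
print(p1, v1)  # (0.5,) (1.0,)
```

Since the step function neither reads nor writes fields on the parameter, it only needs plain values and gradients, which is the kind of interface an AD tool's native output could feed directly.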
