Front page example broken #17
Seems like something broke in NNlibCUDA for 2D batchnorm. MWE:
using Lux, NNlibCUDA, Zygote
x = randn(Float32, 2, 2) |> gpu;
z = randn(Float32, 2) |> gpu
gradient(sum ∘ NNlibCUDA.batchnorm, z, z, x, z, z, 0.1f0)
@ToucheSir are you aware of any recent change that might break this?
Not that I'm aware of. However, I don't think that MWE should work at all, because NNlibCUDA does not define an rrule for its batchnorm function (this is currently handled awkwardly in Flux, but it really belongs in NNlib).
I see. @maxfreu this should be fixed on
Ah, nice. But it does look like this belongs in NNlib.
FluxML/NNlib.jl#19 is nearing half a decade old, so perhaps we should get it done before then 😆. PRs are very much welcome.
Hi! Thanks for this interesting work! I just tried the front-page example, and it didn't work for me. Taking the gradient fails with:
Package status output & version: