RFC: return full gradient of all arguments in gradient #535
What if I do not want the gradient with respect to …?
> What if I don't want the gradient with respect to …?

I am gently pro this.
I think my main point here is to be consistent and transparent with the internals; this will make things much simpler. I have a few reasons for this change:

Firstly, returning the full gradient and using something like …

Secondly, this will simplify the logic inside the internals and make it easier to understand things.

Lastly, the semantics of this …
my point is that I don't want to waste computation on gradients I don't need. Sometimes I need the gradient of …
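To make the trade-off concrete, here is a minimal sketch (mine, not from the thread) of the two calling styles the thread contrasts: passing an argument to `gradient` requests its gradient, while closing over it means no gradient for it is returned.

```julia
using Zygote

W = randn(3, 2); x = randn(2)

# passing both arguments asks Zygote to return both gradients
dW, dx = gradient((W, x) -> sum(W * x), W, x)

# if only dW is wanted, close over x instead of passing it;
# no gradient for x is returned (though, as discussed below,
# it may still be computed internally)
dW2, = gradient(W -> sum(W * x), W)
```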
@CarloLucibello No, Zygote will always compute all the gradients whether you need them or not, since what Zygote does here is simply compose the different pullbacks. The pullback of matrix multiplication is defined as https://github.com/FluxML/Zygote.jl/blob/master/src/lib/array.jl#L291

```julia
@adjoint function(A::AbstractMatrix * B::AbstractMatrix)
  return A * B, Δ::AbstractMatrix->(Δ * B', A' * Δ)
end
```

When you call the backward pass of this pullback, both values are calculated, and I don't think the Julia compiler will always drop the unused one, since the way Zygote drops it is to simply use https://github.com/FluxML/Zygote.jl/blob/master/src/compiler/interface.jl#L31

If you benchmark it, this can actually be a bit slower and allocate more due to the use of `Base.tail`:

```julia
julia> A = rand(100, 100); B = rand(100, 100);

julia> _, back = Zygote._pullback((A, B)->sum(A * B), A, B)
(249662.01236290898, ∂(#37))

julia> @benchmark back(1.0)
BenchmarkTools.Trial:
  memory estimate:  234.98 KiB
  allocs estimate:  14
  --------------
  minimum time:     879.487 μs (0.00% GC)
  median time:      988.103 μs (0.00% GC)
  mean time:        1.037 ms (3.23% GC)
  maximum time:     8.125 ms (87.60% GC)
  --------------
  samples:          4805
  evals/sample:     1

julia> @benchmark Base.tail(back(1.0))
BenchmarkTools.Trial:
  memory estimate:  235.02 KiB
  allocs estimate:  15
  --------------
  minimum time:     879.982 μs (0.00% GC)
  median time:      988.425 μs (0.00% GC)
  mean time:        1.044 ms (3.26% GC)
  maximum time:     8.084 ms (87.52% GC)
  --------------
  samples:          4777
  evals/sample:     1
```

The best solution for avoiding such computation so far, I think, is to make use of thunks.
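To illustrate the thunk idea (this sketch is mine, not from the thread, and `mul_pullback` is a made-up helper rather than Zygote's actual adjoint): ChainRulesCore's `@thunk`/`unthunk` let a pullback return lazy cotangents, so a gradient that is never unthunked is never materialised.

```julia
using ChainRulesCore: @thunk, unthunk

# hypothetical thunked pullback for C = A * B
function mul_pullback(A, B, Ȳ)
    Ā = @thunk(Ȳ * B')   # not computed yet
    B̄ = @thunk(A' * Ȳ)   # not computed yet
    return Ā, B̄
end

A = rand(100, 100); B = rand(100, 100); Ȳ = ones(100, 100)
Ā, B̄ = mul_pullback(A, B, Ȳ)
unthunk(Ā)   # only this matrix product is ever evaluated; B̄ stays lazy
```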
thanks for the explanation, that wasn't clear to me. So, when we have thunks, how would you compute gradients only for selected parameters under your proposal?
With thunks, anything not actually used (unthunked) will not be computed.
still, I think …
you specify via …

Back on topic: the point of mentioning thunking at all was that right now `gradient` does compute all derivatives, even though it doesn't return them all.
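Roughly, and glossing over details (this is a simplified sketch rather than the actual source in interface.jl), the point is that `gradient` already runs the full backward pass and then discards the first entry:

```julia
using Zygote

# simplified sketch of what Zygote.gradient does internally:
# run the backward pass for every input (including the function
# object itself) and then drop the first entry with Base.tail
function gradient_sketch(f, args...)
    y, back = Zygote._pullback(f, args...)
    grads = back(one(y))      # all derivatives are computed here
    return Base.tail(grads)   # ...and the one for `f` is simply dropped
end

gradient_sketch((A, B) -> sum(A * B), rand(3, 3), rand(3, 3))
```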
I found another related case: the current API prevents us from defining the gradient for callable objects, since the first argument will be ignored, e.g.

```julia
struct Linear
    W
    b
end

@adjoint function (m::Linear)(x)
    y = m.W * x .+ m.b   # assuming the usual affine forward pass
    function pullback(Δ)
        grad_x = m.W' * Δ
        # only the gradient with respect to x can be returned here
        return (grad_x,)
    end
    return y, pullback
end
```

but since this would add a `nothing` for this adjoint, we are not able to define the gradient of the callable object itself.
@Roger-luo that syntax already does what you want. It's also unrelated to the …

Thunking is relevant, though, since once we have that ability, we'll have to provide some way to communicate what not to calculate. Currently that's done by closing over variables rather than passing them (even though it doesn't actually improve performance as yet, it could). With this change we'd have to expose …

The main advantage of the current API is that you can write things like …

The original motivating example is not very convincing, since in general you actually write something like …
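For context (my own sketch, not from the comment, with a hypothetical `Affine` layer): differentiating with respect to a callable struct already yields a structural gradient, and the current API lets you choose what to differentiate by what you pass versus close over.

```julia
using Zygote

struct Affine
    W
    b
end
(m::Affine)(x) = m.W * x .+ m.b   # assumed affine forward pass

m = Affine(randn(3, 2), randn(3)); x = randn(2)

# differentiating w.r.t. the model gives a NamedTuple of field gradients
dm, = gradient(m -> sum(m(x)), m)          # dm.W, dm.b

# pass both to get both; close over what you don't need
dm, dx = gradient((m, x) -> sum(m(x)), m, x)
```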
One thing we might think about is something like returning a …

There is a related issue: JuliaDiff/ChainRulesCore.jl#121
Yeah, in the context of "function"s it kinda makes sense. I'll close this issue then.
I feel it should return the gradient of each argument. When FluxML/Flux.jl#1073 is merged, since that will allow one to use optimizers directly on structures (maybe also related to FluxML/Flux.jl#637), it would be more convenient to just return the gradient of all arguments, since we could have …

(which is actually my case: the output of the model is a probability and I need similar but more complicated code to do policy gradient). Currently one has to work around this with

```julia
gradient((m, x) -> m(x), m, x)
```

which I find less convenient… The only thing that needs to change is the following line, though:

Zygote.jl/src/compiler/interface.jl, line 34 (at b0ea130)

This would also simplify the logic of `pullback` on the Zygote side a bit, I think, and there would be no need to have both `_pullback` and `pullback`.
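As a rough illustration of the proposal (my sketch, not code from the issue): dropping the `Base.tail` step in `gradient` would mean the backward pass also returns the gradient for the function/model object itself.

```julia
using Zygote

# sketch of a `gradient` variant that keeps every gradient the pullback
# produces, including the one for the callable model itself
function gradient_all(f, args...)
    y, back = Zygote._pullback(f, args...)
    return back(one(y))   # no Base.tail here
end

struct Model
    W
    b
end
(m::Model)(x) = sum(m.W * x .+ m.b)   # assumed scalar, probability-like output

m = Model(randn(1, 2), randn(1)); x = randn(2)
dm, dx = gradient_all(m, x)   # the model's field gradients come back too
```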