How to represent structural derivatives #462
I initially agreed that this example was incorrect, but now I'm not so sure. The key point here is that the differentiation transform is highly structured, so each adjoint value, at each point in the program, corresponds to exactly one primal value at the corresponding point. For … Debugging was raised as another issue, but for the same reasons you can just do … The only exception I can think of is if you have multiple adjoint types for the same primal type, which you need to …
Except at the endpoint, and for values returned by custom adjoints.
You only know to add those hooks if you already know something is going wrong (and where). (* Aside: values that are semantically identical but actually different are an argument that we must support multiple differential types for the same primal type in general. But that's another topic.)
I don't think these are exceptions: the final gradients correspond to the inputs given to the function, the initial gradients correspond to the output, and the gradients produced by a custom adjoint correspond to the primal inputs, all of which you have access to at the relevant point. I agree that user error is possible when calling a pullback, and this could be worthwhile to prevent if it's likely in practice (and especially if it's silent or otherwise hard to debug). The edge case I can see is where you mistake what type of primal composite was produced, pull back a named tuple that has the same keys but the wrong internal structure, and then get an error down the line, potentially from a different adjoint. I'm not sure how likely this is in practice; people mostly use things like arrays and tuples as outputs, which have obvious adjoints.
Handling the easy cases is easy, so handling the hard cases right is a mark of distinction and worth spending effort on.
Consider our current case. The structural derivative here is `Δ::NamedTuple{(:parent,)}`. This gives correct behaviour if the primal type matching that derivative was `Adjoint`. It gives incorrect behaviour if the primal type is some other wrapper that also has a `.parent` field. Fortunately, `Adjoint` is an aberration among wrapper types: almost all other wrapper types call that field `.data` and use the `parent(...)` accessor method. But odds are that somewhere there is one that uses `parent` as the field name.

The question of this issue is how we should represent structural differentials.
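A minimal Julia sketch of the ambiguity described above, assuming a hypothetical wrapper `MyWrapper` (not from any package) that, like `LinearAlgebra.Adjoint`, stores its wrapped value in a field named `parent`. The helper `matches` is also hypothetical; it compares only field names, which is all a bare `NamedTuple` differential has to offer.

```julia
using LinearAlgebra

# Hypothetical wrapper (not from any package) that, like `Adjoint`,
# stores the wrapped value in a field named `parent`.
struct MyWrapper{T}
    parent::T
end

# A bare NamedTuple differential records field names, but not which
# primal type it belongs to.
Δ = (parent = [1.0, 2.0],)

A = adjoint([1.0, 2.0])   # LinearAlgebra.Adjoint, whose only field is `parent`
W = MyWrapper([1.0, 2.0])

# Hypothetical structural check: do the differential's keys match the
# primal's field names? This is all a NamedTuple lets us check.
matches(Δ::NamedTuple, x) = keys(Δ) == fieldnames(typeof(x))

matches(Δ, A)   # true
matches(Δ, W)   # true: the NamedTuple cannot tell the two primals apart
```

Both checks succeed, so code that consumes the differential by field name has no way to notice that it was meant for a different primal type.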
I am strongly of the opinion that we should use ChainRules' `Composite{Primal}` type, and if it is unsuitable, that should be fixed in ChainRules.
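To illustrate why carrying the primal type helps, here is a minimal sketch using a hypothetical `TaggedComposite{P}` type. It is a stand-in for the idea behind ChainRulesCore's `Composite{Primal}`, not its actual implementation or API; `structural_diff` and `MyWrapper` are likewise made up for illustration.

```julia
using LinearAlgebra

# Hypothetical stand-in for a primal-tagged structural differential.
# NOT ChainRulesCore's `Composite`: it only shows the idea of storing the
# primal type `P` alongside the field-wise differentials in `backing`.
struct TaggedComposite{P, B<:NamedTuple}
    backing::B
end

# Convenience constructor: TaggedComposite{P}(nt) infers B from nt.
TaggedComposite{P}(backing::NamedTuple) where {P} =
    TaggedComposite{P, typeof(backing)}(backing)

# Hypothetical accessor: return the field-wise differential only if it was
# built for a primal of type P; otherwise fail loudly.
function structural_diff(x, Δ::TaggedComposite{P}) where {P}
    x isa P || throw(ArgumentError("differential for $P used with a $(typeof(x))"))
    return Δ.backing
end

# Hypothetical wrapper that, like `Adjoint`, names its field `parent`.
struct MyWrapper{T}
    parent::T
end

A = adjoint([1.0, 2.0])
W = MyWrapper([1.0, 2.0])
Δ = TaggedComposite{Adjoint}((parent = [1.0, 2.0],))

structural_diff(A, Δ)   # (parent = [1.0, 2.0],)

# Using the same differential with the wrong primal raises an ArgumentError.
caught = try
    structural_diff(W, Δ)
    false
catch err
    err isa ArgumentError
end
```

Here the mismatch with `MyWrapper` becomes an immediate error at the point of use, rather than a wrong result or an error from some unrelated adjoint further down the line.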
I had been thinking of that changeover as step 2 of integrating ChainRules, which is to switch to using ChainRules' types internally, with step 1, #366 (switching to ChainRules' rules), done first. But actually the two are independent: using the types doesn't even require loading ChainRules, just ChainRulesCore.
Note: this discussion is not about whether we should be using structural differentials for array types in the first place. See JuliaDiff/ChainRulesCore.jl#85 and #445 for that.
I am sure we can find another example where there is no clear natural differential type.