
How to represent structural derivatives #462

Open
oxinabox opened this issue Jan 15, 2020 · 4 comments

Comments

@oxinabox
Member

oxinabox commented Jan 15, 2020

Consider our current case.

@adjoint function Base.adjoint(x)
  back(Δ) = (Δ',)
  back(Δ::NamedTuple{(:parent,)}) = (Δ.parent,)
  return x', back
end

The structural derivative here is `Δ::NamedTuple{(:parent,)}`.
This gives correct behaviour if the primal type matching that derivative is Adjoint.
It gives incorrect behaviour if the primal type is some other wrapper that also has a .parent field.

Fortunately, Adjoint is an aberration among wrapper types.
Almost all other wrapper types call that field .data and use the parent(...) accessor method.
But odds are that somewhere there is one that uses parent as the field name.
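To make the ambiguity concrete, here is a toy sketch. Both wrapper types below are invented for illustration (`MyAdjoint` stands in for LinearAlgebra's Adjoint; `Rot180` is a hypothetical second wrapper that happens to use the same field name):

```julia
# Hypothetical wrapper types that both store their data in a field named `parent`.
struct MyAdjoint{T}
    parent::T
end

struct Rot180{T}
    parent::T
end

# A purely structural gradient for either wrapper is the same NamedTuple type,
# so a pullback dispatching on NamedTuple{(:parent,)} cannot tell them apart.
structural_grad(w) = (parent = w.parent,)   # toy "gradient" sharing the field names

ga = structural_grad(MyAdjoint([1.0, 2.0]))
gb = structural_grad(Rot180([1.0, 2.0]))

typeof(ga) == typeof(gb)   # true: the primal type has been erased
```

The pullback in the snippet above would fire for `gb` just as happily as for `ga`, which is exactly the incorrect behaviour described.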

The question of this issue is how we should represent structural differentials.
I am strongly of the opinion that we should use ChainRules' Composite{Primal} type.
And if it is unsuitable that should be fixed in ChainRules.

I had been thinking of that changeover as step 2 of integrating ChainRules, i.e. switching to using ChainRules' types internally,
whereas step 1 (#366, switching to using ChainRules' rules) would be done first.
But actually they are independent.
Using the types doesn't even require loading ChainRules -- just ChainRulesCore.
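For intuition, here is a minimal sketch of what carrying the primal type in the differential buys. This is a toy definition modelled on the idea behind ChainRulesCore's `Composite{Primal}`, not the real API; the wrapper types are the same hypothetical ones as before:

```julia
# Toy Composite-style differential that records the primal type P as a parameter.
struct Composite{P, T<:NamedTuple}
    backing::T
end
Composite{P}(; kwargs...) where {P} = Composite{P, typeof(values(kwargs))}(values(kwargs))

struct MyAdjoint{T}; parent::T; end   # stands in for LinearAlgebra.Adjoint
struct Rot180{T};    parent::T; end   # hypothetical second wrapper

primal_type(::Composite{P}) where {P} = P

# Same field names, but the differentials no longer collide:
da = Composite{MyAdjoint}(parent = [1.0, 2.0])
db = Composite{Rot180}(parent = [1.0, 2.0])

primal_type(da)  # MyAdjoint
primal_type(db)  # Rot180
```

A pullback can then dispatch on `Composite{Adjoint}` rather than on a bare `NamedTuple{(:parent,)}`, so a structurally identical differential for a different wrapper can never be confused with it.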

Note: this discussion is not about whether we should be using structural differentials for array types in the first place. See JuliaDiff/ChainRulesCore.jl#85 and #445 for that.
I am sure we can find another example where there is no clear natural differential type.

@MikeInnes
Member

I initially agreed that this example was incorrect, but now I'm not so sure. The key point here is that the differentiation transform is really structured, so each adjoint value, at each point in the program, corresponds to exactly one primal value at the corresponding point. For Adjoint that's either an array or a namedtuple, and the namedtuple is unambiguous about corresponding to Adjoint since if it weren't, it wouldn't turn up here in the first place. If an adjoint for a Transpose is ever passed to this pullback, something else has gone seriously wrong.

Debugging was raised as another issue, but for the same reasons you can just do hook(x, x̄ -> @show(x, x̄)) and see not only the primal type but the specific primal value corresponding to that adjoint. In other words, primal type information is always redundant; you have it already.

The only exception I can think of is if you have multiple adjoint types for the same primal type, which you need to accum, and you've used the same types for other primal types, and the rule for accum is different depending on the primal type. But this seems like it'd be such a convoluted abuse of the system that I don't think I'd want to support it anyway; there's always going to be a better way.

@oxinabox
Member Author

oxinabox commented Jan 15, 2020

The key point here is that the differentiation transform is really structured, so each adjoint value, at each point in the program, corresponds to exactly one primal value at the corresponding point.

Except at the end-point, and from values returned by custom adjoints.
The user could give a seed or custom adjoint value that is either semantically the wrong type, because they made a mistake,
or semantically right but augmented in some way, e.g. something that is supposed to act the same as the original differential type*, but that has, say, named dimensions, or runs on the GPU, etc.
I don't have a great example for structs there; maybe something to do with remote references, or forward-mode AD.
And the user may or may not have done it correctly.
Which brings us to:

Debugging was raised as another issue, but for the same reasons you can just do hook(x, x̄ -> @show(x, x̄))

You only know to add those hooks if you know something is going wrong (and where).
The big problem with a lot of AD-related bugs is that you don't notice when something is going wrong.
(E.g. dropped gradients on, say, a layer bypass or attention component in a NN: the main connection still works, so the network still trains, just worse.)
And even when you do notice, you often have trouble working out where.
And I don't want to see every intermediate gradient in my program.

(* Aside: this "semantically identical but actually different" case is an argument that we must support multiple differential types for the same primal type, in general. But that's another topic.)

@MikeInnes
Member

Except at the end-point, and from values returned by custom adjoints.

I don't think these are exceptions: the final gradients correspond to the inputs given to the function, the initial gradients correspond to the output, and the gradients produced by a custom adjoint correspond to the primal inputs, all of which you have access to at the relevant point.

I agree that user error is possible when calling a pullback, and this could be worthwhile to prevent if it's likely in practice (and especially if it's silent or otherwise hard to debug). The edge case I can see is where you mistake what type of primal composite was produced, pull back a named tuple that has the same keys but the wrong internal structure, and then get an error down the line, potentially from a different adjoint. I'm not sure how likely this is in practice; people mostly use things like arrays and tuples as outputs, which have obvious adjoints.

@oxinabox
Member Author

Handling the easy cases is easy; thus handling the hard cases right is a mark of distinction and a thing worth spending effort on.
