Representing (co)tangents of structured matrices #191
Comments
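i.e. this would be the entirety of the …

```julia
using LinearAlgebra, ChainRulesCore
import ChainRulesCore: frule, rrule

function frule((_, ΔA, _), T::Type{<:LinearAlgebra.HermOrSym}, A::AbstractMatrix, uplo)
    Y = T(A, uplo)
    # the tangent of Y is structural: only its `data` field varies
    return Y, Composite{typeof(Y)}(data = ΔA)
end

function rrule(T::Type{<:LinearAlgebra.HermOrSym}, A::AbstractMatrix, uplo)
    # one cotangent per argument: the type T, the matrix A, and the non-differentiable uplo
    HermOrSym_pullback(ΔY) = (NO_FIELDS, ΔY.data, DoesNotExist())
    return T(A, uplo), HermOrSym_pullback
end
```

What these don't cover are the cases where one calls …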
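Because `Symmetric(A, uplo)` and `Hermitian(A, uplo)` are linear in `A`, the directional derivative along `ΔA` is just the same wrapper applied to `ΔA`, which is why storing `ΔA` in the `data` field of the tangent should be enough. A minimal finite-difference check, assuming real matrices and using only the `LinearAlgebra` stdlib; `jvp_fd` is a hypothetical helper, not part of ChainRules or FiniteDifferences:

```julia
using LinearAlgebra

# Hypothetical helper: forward-difference directional derivative of `f` at `A` along `ΔA`.
jvp_fd(f, A, ΔA; h=1e-6) = (f(A + h * ΔA) - f(A)) / h

A  = [1.0 2.0 3.0; 4.0 5.0 6.0; 7.0 8.0 10.0]
ΔA = [0.5 -1.0 2.0; 3.0 0.25 -0.5; 1.5 -2.0 1.0]

for T in (Symmetric, Hermitian), uplo in (:U, :L)
    f = X -> Matrix(T(X, uplo))        # materialize so both sides are plain matrices
    expected = Matrix(T(ΔA, uplo))     # what a structural tangent `(data = ΔA)` represents
    @assert isapprox(jvp_fd(f, A, ΔA), expected; rtol=1e-6)
end
```

Since the map is linear, the finite difference agrees with `T(ΔA, uplo)` up to roundoff for every combination of wrapper and triangle.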
Here's an example of a function whose first operation is …

```julia
julia> using FiniteDifferences, LinearAlgebra

julia> x = collect(reshape(1.0:9.0, 3, 3))
3×3 Array{Float64,2}:
 1.0  4.0  7.0
 2.0  5.0  8.0
 3.0  6.0  9.0

julia> function g(x)
           s = Symmetric(x)
           d = s.data
           h = s * d
           z = h[1, 2]
           return z
       end
g (generic function with 1 method)

julia> only(j′vp(central_fdm(5, 1), g, 1.0, x))
3×3 Array{Float64,2}:
 4.0  6.0  6.0
 0.0  4.0  0.0
 0.0  7.0  0.0
```
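For reference, the finite-difference result can be reproduced by composing pullbacks by hand. In this sketch `upper_sym_back` is a hypothetical name for the pullback of `Symmetric(x)` with the default `uplo = :U`, and `*` is pulled back with the usual `ΔA = ΔY * B'`, `ΔB = A' * ΔY`. The point is that `x` receives cotangent along two paths, through `s` and through `s.data`, and the two contributions are summed:

```julia
using LinearAlgebra

# Pullback of `A -> Symmetric(A)` (uplo = :U) for a plain-matrix cotangent:
# the upper triangle collects both mirrored entries, the lower triangle gets zero.
upper_sym_back(ΔS) = UpperTriangular(ΔS) + transpose(LowerTriangular(ΔS)) - Diagonal(ΔS)

x = collect(reshape(1.0:9.0, 3, 3))
s = Symmetric(x)

# Cotangent of z = h[1, 2] with respect to h = s * x
Δh = zeros(3, 3)
Δh[1, 2] = 1.0

Δs = Δh * x'   # pullback of `*` with respect to the left factor s
Δd = s' * Δh   # pullback of `*` with respect to the right factor d = s.data (which is x)

# Sum the contribution through Symmetric(x) and the one through s.data
Δx = upper_sym_back(Δs) + Δd
```

`Δx` comes out as `[4.0 6.0 6.0; 0.0 4.0 0.0; 0.0 7.0 0.0]`, matching `j′vp` above; dropping either path gives a wrong answer.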
Yeah, it's just by virtue of the accessing of the …
Yeah, so we have two cases:

Because the …

```julia
# when a matrix sensitivity is added to a `Composite{<:Symmetric}`, first pull it back
# to the correct triangle of `.data`, then combine.
function Base.:+(a::P, b::Composite{P}) where {P<:Symmetric}
    return Composite{P}(data=_symmetric_back(a)) + b
end
function Base.:+(a::AbstractMatrix, b::Composite{P}) where {P<:Symmetric}
    return Composite{P}(data=_symmetric_back(a)) + b
end

function rrule(::Type{<:Symmetric}, A::AbstractMatrix)
    function Symmetric_pullback(ȳ)
        return (NO_FIELDS, @thunk(_symmetric_back(ȳ)))
    end
    return Symmetric(A), Symmetric_pullback
end

# If no composites were accumulated, pull back to the used triangle and return.
# Assumes the default `uplo = :U`; `transpose` (not adjoint) because `Symmetric`
# mirrors entries without conjugation.
_symmetric_back(ΔΩ) = UpperTriangular(ΔΩ) + transpose(LowerTriangular(ΔΩ)) - Diagonal(ΔΩ)
# If any composites, pullback has already been called, so just return
_symmetric_back(ΔΩ::Composite{<:Symmetric}) = ΔΩ.data
```

One problem with this approach is that the information about which triangle is storing the data is not contained in the type of … Is there anything like a …
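To make the `uplo` problem concrete: `Symmetric(A, :U)` and `Symmetric(A, :L)` have the same type, and the stored triangle is only available at runtime through the `uplo` field (a `Char`). A pullback helper therefore has to branch on that value rather than dispatch on the type. A sketch, where `symmetric_back` is a hypothetical name and not part of ChainRules, checked against the adjoint identity `dot(ΔY, J(ΔA)) == dot(J'(ΔY), ΔA)` for the linear map `J(ΔA) = Symmetric(ΔA, uplo)`:

```julia
using LinearAlgebra

A = [1.0 2.0; 3.0 4.0]
# The triangle in use is not part of the type...
@assert typeof(Symmetric(A, :U)) == typeof(Symmetric(A, :L))
# ...only of the value.
@assert Symmetric(A, :U).uplo == 'U' && Symmetric(A, :L).uplo == 'L'

# Hypothetical uplo-aware pullback of `A -> Symmetric(A, uplo)` for a dense cotangent ΔΩ.
function symmetric_back(ΔΩ::AbstractMatrix, uplo::Char)
    if uplo == 'U'
        # Upper triangle collects ΔΩ[i,j] + ΔΩ[j,i]; the unused lower triangle gets zero.
        return UpperTriangular(ΔΩ) + transpose(LowerTriangular(ΔΩ)) - Diagonal(ΔΩ)
    else
        # Mirror image: lower triangle collects, upper triangle gets zero.
        return LowerTriangular(ΔΩ) + transpose(UpperTriangular(ΔΩ)) - Diagonal(ΔΩ)
    end
end

ΔA = [0.5 -1.0; 2.0 3.0]   # tangent direction
ΔY = [1.0 4.0; -2.0 5.0]   # cotangent of the output
for uplo in (:U, :L)
    JΔA = Symmetric(ΔA, uplo)
    lhs = dot(ΔY, JΔA)
    rhs = dot(symmetric_back(ΔY, JΔA.uplo), ΔA)   # uplo read from the value, not the type
    @assert lhs ≈ rhs
end
```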
This follows up a discussion on Slack.
@sethaxen:
@ChrisRackauckas:
@sethaxen:
@willtebbutt:
There are several points for discussion here. Under this perspective (@willtebbutt's newer thinking, which I tend to agree with), if downstream rules and AD have done everything right, then the pullback for `Y = Symmetric(A)` should always receive an object `ΔY` with a `data` field (either `Composite{<:Symmetric}` or `Symmetric`), and its pullback should just be `ΔY.data`. If the pullback is passed an `UpperTriangular`, `LowerTriangular` or `Diagonal` matrix, as in the current `rrule` implementation and as in #178, then something is wrong somewhere else. Moreover, we don't need to do anything to `data`, such as zeroing a triangle, because if that triangle should be zeroed, it is already zeroed in `ΔY.data`. For example, if `Matrix(Y)` was called, then the unused triangle was overwritten by the used triangle in the forward pass; consequently, a correctly implemented `Matrix_pullback` will zero out the unused triangle in the cotangent before wrapping it with `Symmetric` or `Composite{Symmetric}`. Thus a custom rule is probably not even necessary for the `Symmetric` constructor. Have I got this right?

One thing that worries me is, e.g., what happens if a user defines an override like `(::Diagonal * ::MyDiagonal)::MyDiagonal`. This would trigger our general `rrule`. The pullback would expect an `AbstractMatrix`, but it will be passed a `Composite{MyDiagonal}` (or `MyDiagonal`). To do the right thing, it would need to produce a `Composite{Diagonal}` (or `Diagonal`) and a `Composite{MyDiagonal}` (or `MyDiagonal`). So how should we define generic rules that handle such cases?

Also, should we adopt a convention regarding whether the (co)tangents of structured matrices should be matrices or `Composite`s? A point for the former is that we can automatically multiply them by other matrices and add them, and things should just work. A point for the latter is that in many cases the (co)tangent doesn't share the same structure as the primal (e.g. the (co)tangent of a unitary matrix is a unitary transformation of a skew-Hermitian matrix). A compromise is a utility method that in most cases is a no-op but is meant to convert from a composite type to a primal when possible.

Relates #52
cc @mcabbott @oxinabox
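The claim that a correctly implemented `Matrix_pullback` already zeroes the unused triangle, so that the constructor pullback can simply return `ΔY.data`, can be sanity-checked numerically. In this sketch a `NamedTuple` `(data = …,)` stands in for `Composite{<:Symmetric}(data = …)`, `matrix_of_symmetric_back` and `symmetric_ctor_back` are hypothetical names rather than ChainRules functions, and only `uplo = :U` is handled:

```julia
using LinearAlgebra

# Hypothetical pullback of `Y -> Matrix(Y)` for `Y::Symmetric` with uplo = :U.
# Matrix(Y) copies Y.data's upper triangle into both halves, so the cotangent of
# `data` gathers ΔM[i,j] + ΔM[j,i] in the upper triangle and is zero below it.
function matrix_of_symmetric_back(ΔM::AbstractMatrix)
    Δdata = Matrix(UpperTriangular(ΔM) + transpose(LowerTriangular(ΔM)) - Diagonal(ΔM))
    return (data = Δdata,)  # stand-in for Composite{<:Symmetric}(data = Δdata)
end

# Under the proposed convention the constructor pullback just unwraps `data`.
symmetric_ctor_back(ΔY) = ΔY.data

ΔM = [1.0 -2.0 0.5; 3.0 1.0 4.0; -1.0 2.0 -3.0]   # cotangent of Matrix(Symmetric(A))
ΔA = symmetric_ctor_back(matrix_of_symmetric_back(ΔM))

# The unused (strictly lower) triangle is already zero, with no extra work in the constructor rule.
@assert istriu(ΔA)

# Adjoint check for the composite map A -> Matrix(Symmetric(A)) along a tangent dA.
dA = [0.5 1.0 -1.0; 2.0 0.0 1.5; -0.5 3.0 1.0]
@assert dot(ΔM, Matrix(Symmetric(dA))) ≈ dot(ΔA, dA)
```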