-
Notifications
You must be signed in to change notification settings - Fork 14
/
Copy pathchainrules.jl
21 lines (18 loc) · 983 Bytes
/
chainrules.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# Pullback shared by the `NamedDimsArray` constructor rrules. Given the cotangent `ȳ` of
# the constructed array, return the cotangent triple `(constructor, values, names)`.
# The constructor and the `names` argument are non-differentiable and get `NoTangent()`.

# An array cotangent is passed straight through as the cotangent of `values`.
_NamedDimsArray_pullback(ȳ::AbstractArray) = (NoTangent(), ȳ, NoTangent())
# A structural `Tangent` of the wrapper: the cotangent of `values` is its `data` field.
_NamedDimsArray_pullback(ȳ::Tangent) = (NoTangent(), ȳ.data, NoTangent())
# A zero cotangent (`ZeroTangent`, `NoTangent`, ...) for the output means `values`
# receives the same zero. Without this method such inputs (including a thunk that
# unthunks to a zero) would raise a `MethodError`.
_NamedDimsArray_pullback(ȳ::AbstractZero) = (NoTangent(), ȳ, NoTangent())
# Lazy cotangents are materialised with `unthunk` and then re-dispatched.
_NamedDimsArray_pullback(ȳ::AbstractThunk) = _NamedDimsArray_pullback(unthunk(ȳ))
function ChainRulesCore.rrule(::Type{<:NamedDimsArray}, values, names)
    # Forward pass: wrap `values` with `names`. The specific `NamedDimsArray` type the
    # caller passed is not used; construction always goes through
    # `NamedDimsArray(values, names)`.
    wrapped = NamedDimsArray(values, names)
    # Backward pass: the shared pullback maps the output cotangent to the
    # `(constructor, values, names)` cotangent triple.
    return (wrapped, _NamedDimsArray_pullback)
end
function ChainRulesCore.rrule(T::Type{<:NamedDimsArray}, values)
    # This constructor form has no `names` argument, so reuse the three-argument pullback
    # and drop its trailing `names` cotangent. `Base.front` returns the leading two
    # entries as a concretely inferred 2-tuple. Slicing with `[1:2]` instead goes through
    # the generic `getindex(::Tuple, ::AbstractArray)`, which builds an intermediate
    # vector and may not infer a concrete tuple type.
    NamedDimsArray_values_pullback(ȳ) = Base.front(_NamedDimsArray_pullback(ȳ))
    return T(values), NamedDimsArray_values_pullback
end
function ChainRulesCore.ProjectTo(x::NamedDimsArray)
    # Store the dimension names in the projector's type parameter. Store a projector
    # for the wrapped (parent) array in the `data` field, so cotangents can later be
    # projected onto the underlying data and re-wrapped.
    names = dimnames(x)
    data_projector = ProjectTo(parent(x))
    return ProjectTo{NamedDimsArray{names}}(; data=data_projector)
end
# Zero cotangents (`ZeroTangent`, `NoTangent`, ...) need no projection onto a
# `NamedDimsArray`; return them unchanged rather than treating them as arrays.
function (::ProjectTo{NDA})(dx::AbstractZero) where {NDA<:NamedDimsArray}
    return dx
end
function (project::ProjectTo{NDA})(dx) where {NDA<:NamedDimsArray}
    # Combine the names recorded in the projector's type with whatever names `dx`
    # carries. Presumably `unify_names` throws on a clash; confirm against its definition.
    merged_names = unify_names(dimnames(NDA), dimnames(dx))
    # Project the underlying data with the stored parent-array projector, then re-wrap
    # it under the merged names.
    projected_data = project.data(parent(dx))
    return NamedDimsArray{merged_names}(projected_data)
end