Skip to content

Commit

Permalink
more extra rules for static arrays
Browse files Browse the repository at this point in the history
more overloads for StaticArrays
  • Loading branch information
oxinabox committed Jan 19, 2024
1 parent e1c7c7e commit 09589cd
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion src/extra_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -172,9 +172,14 @@ function ChainRules.rrule(::DiffractorRuleConfig, ::Type{SArray{S, T, N, L}}, x:
end

function ChainRules.frule((_, ∂x), ::Type{SArray{S, T, N, L}}, x::NTuple{L,T}) where {S, T, N, L}
SArray{S, T, N, L}(x), SArray{S, T, N, L}(∂x.backing)
#TODO: we really shouldn't actually see the isa(∂x, AbstractZero) case since the frule should be called then
Δx = isa(∂x, AbstractZero) ? ∂x : SArray{S, T, N, L}(ChainRulesCore.backing(∂x))
SArray{S, T, N, L}(x), Δx
end

Base.view(t::Tangent{T}, inds) where T<:SVector = view(T(ChainRulesCore.backing(t.data)), inds)
Base.getindex(t::Tangent{<:SVector, <:NamedTuple}, ind::Int) = ChainRulesCore.backing(t.data)[ind]

function ChainRules.frule((_, ∂x), ::Type{SArray{S, T, N, L}}, x::NTuple{L,Any}) where {S, T, N, L}
SArray{S, T, N, L}(x), SArray{S}(∂x)
end
Expand Down

0 comments on commit 09589cd

Please sign in to comment.