diff --git a/src/extra_rules.jl b/src/extra_rules.jl index b9bcff7e..303419cd 100644 --- a/src/extra_rules.jl +++ b/src/extra_rules.jl @@ -179,8 +179,14 @@ end Base.view(t::Tangent{T}, inds) where T<:SVector = view(T(ChainRulesCore.backing(t.data)), inds) Base.getindex(t::Tangent{<:SVector, <:NamedTuple}, ind::Int) = ChainRulesCore.backing(t.data)[ind] -function ChainRules.frule((_, ∂x), ::Type{SArray{S, T, N, L}}, x::NTuple{L,Any}) where {S, T, N, L} - SArray{S, T, N, L}(x), SArray{S}(∂x) +function ChainRules.frule( + (_, ∂x)::Tuple{Any, Tangent{TUP}}, + ::Type{SArray{S, T, N, L}}, + x::TUP, +) where {L, TUP<:NTuple{L, Number}, S, T<:Number, N} + y = SArray{S, T, N, L}(x) + ∂y = SArray{S, T, N, L}(ChainRulesCore.backing(∂x)) + return y, ∂y end @ChainRulesCore.non_differentiable StaticArrays.promote_tuple_eltype(T) diff --git a/test/extra_rules.jl b/test/extra_rules.jl new file mode 100644 index 00000000..2b860048 --- /dev/null +++ b/test/extra_rules.jl @@ -0,0 +1,35 @@ +using Diffractor +using StaticArrays +using ChainRulesCore +using Test + +@testset "StaticArrays constructor" begin + #frule(::Tuple{ChainRulesCore.NoTangent, ChainRulesCore.Tangent{Tuple{Int64, Vararg{Float64, 9}}, Tuple{Int64, Vararg{Float64, 9}}}}, ::Type{StaticArraysCore.SVector{10, Float64}}, x::Tuple{Int64, Vararg{Float64, 9}}) + # @ Diffractor ~/.julia/packages/Diffractor/yCsbI/src/extra_rules.jl:183 + + @testset "homogenious type" begin + x = (10.0, 20.0, 30.0) + ẋ = zero_tangent(x) + y, ẏ = frule((NoTangent(), ẋ), StaticArraysCore.SVector{3, Float64}, x) + @test y == @SVector [10.0, 20.0, 30.0] + @test ẏ == @SVector [0.0, 0.0, 0.0] + end + + @testset "convertable type" begin + x::Tuple{Int, Float64, Float64} = (10, 20.0, 30.0) + ẋ = zero_tangent(x) + y, ẏ = frule((NoTangent(), ẋ), StaticArraysCore.SVector{3, Float64}, x) + # all are float + @test y == @SVector [10.0, 20.0, 30.0] + @test ẏ == @SVector [0.0, 0.0, 0.0] + end + + @testset "convertable type with ZeroTangent()" begin + x = (10, 20.0, 30.0) + ẋ = Tangent{typeof(x)}(ZeroTangent(), 1.0, 2.0) + y, ẏ = frule((NoTangent(), ẋ), StaticArraysCore.SVector{3, Float64}, x) + # all are float + @test y == @SVector [10.0, 20.0, 30.0] + @test ẏ == @SVector [0.0, 1.0, 2.0] + end +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 0acd3416..01cbc825 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -14,6 +14,7 @@ const bwd = Diffractor.PrimeDerivativeBack @testset verbose=true "Diffractor.jl" begin # overall testset, ensures all tests run @testset "$file" for file in ( + "extra_rules.jl" "stage2_fwd.jl", "tangent.jl", "forward_diff_no_inf.jl",