From 94f9bc5d82df013f040d4cad6eaa6da724ba122b Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 16 Oct 2024 09:29:40 +0200 Subject: [PATCH] Simplify static test scenarios (#581) * Simplify static scenarios * No conversion * Exclude derivative for Zygote * Fix JLArrays * Unskip --- ...DifferentiationInterfaceTestJLArraysExt.jl | 12 +++++++- ...erentiationInterfaceTestStaticArraysExt.jl | 29 +------------------ .../src/scenarios/default.jl | 4 +-- 3 files changed, 14 insertions(+), 31 deletions(-) diff --git a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestJLArraysExt/DifferentiationInterfaceTestJLArraysExt.jl b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestJLArraysExt/DifferentiationInterfaceTestJLArraysExt.jl index 4db2a7bb6..2f0afbc58 100644 --- a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestJLArraysExt/DifferentiationInterfaceTestJLArraysExt.jl +++ b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestJLArraysExt/DifferentiationInterfaceTestJLArraysExt.jl @@ -3,7 +3,7 @@ module DifferentiationInterfaceTestJLArraysExt import DifferentiationInterface as DI using DifferentiationInterfaceTest import DifferentiationInterfaceTest as DIT -using JLArrays: JLArray, jl +using JLArrays: JLArray, JLVector, JLMatrix, jl using Random: AbstractRNG, default_rng myjl(f::Function) = f @@ -11,6 +11,16 @@ function myjl(::DIT.NumToArr{A}) where {T,N,A<:AbstractArray{T,N}} return DIT.NumToArr(JLArray{T,N}) end +function (f::DIT.NumToArr{JLVector{T}})(x::Number) where {T} + a = JLVector{T}(Vector(1:6)) # avoid mutation + return sin.(x .* a) +end + +function (f::DIT.NumToArr{JLMatrix{T}})(x::Number) where {T} + a = JLMatrix{T}(Matrix(reshape(1:6, 2, 3))) # avoid mutation + return sin.(x .* a) +end + myjl(f::DIT.MultiplyByConstant) = f myjl(f::DIT.WritableClosure) = f diff --git a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestStaticArraysExt/DifferentiationInterfaceTestStaticArraysExt.jl b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestStaticArraysExt/DifferentiationInterfaceTestStaticArraysExt.jl index 0e7f02d4b..0b8d0d9d2 100644 --- a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestStaticArraysExt/DifferentiationInterfaceTestStaticArraysExt.jl +++ b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestStaticArraysExt/DifferentiationInterfaceTestStaticArraysExt.jl @@ -8,19 +8,10 @@ using SparseArrays: SparseArrays, SparseMatrixCSC, nnz, spdiagm using StaticArrays: MArray, MMatrix, MVector, SArray, SMatrix, SVector mySArray(f::Function) = f -myMArray(f::Function) = f - mySArray(::DIT.NumToArr{A}) where {T,A<:AbstractVector{T}} = DIT.NumToArr(SVector{6,T}) -myMArray(::DIT.NumToArr{A}) where {T,A<:AbstractVector{T}} = DIT.NumToArr(MVector{6,T}) - mySArray(::DIT.NumToArr{A}) where {T,A<:AbstractMatrix{T}} = DIT.NumToArr(SMatrix{2,3,T,6}) -myMArray(::DIT.NumToArr{A}) where {T,A<:AbstractMatrix{T}} = DIT.NumToArr(MMatrix{2,3,T,6}) - mySArray(f::DIT.MultiplyByConstant) = f -myMArray(f::DIT.MultiplyByConstant) = f - mySArray(f::DIT.WritableClosure) = f -myMArray(f::DIT.WritableClosure) = f mySArray(x::Number) = x myMArray(x::Number) = x @@ -36,13 +27,8 @@ function myMArray(x::AbstractMatrix{T}) where {T} end mySArray(x::Tuple) = map(mySArray, x) -myMArray(x::Tuple) = map(myMArray, x) - mySArray(x::DI.Constant) = DI.Constant(mySArray(DI.unwrap(x))) -myMArray(x::DI.Constant) = DI.Constant(myMArray(DI.unwrap(x))) - mySArray(::Nothing) = nothing -myMArray(::Nothing) = nothing function mySArray(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun} (; f, x, y, tang, contexts, res1, res2) = scen @@ -57,22 +43,9 @@ function mySArray(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun} ) end -function myMArray(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun} - (; f, x, y, tang, contexts, res1, res2) = scen - return Scenario{op,pl_op,pl_fun}( - myMArray(f); - x=myMArray(x), - y=pl_fun == :in ? myMArray(y) : myMArray(y), - tang=myMArray(tang), - contexts=myMArray(contexts), - res1=myMArray(res1), - res2=myMArray(res2), - ) -end - function DIT.static_scenarios(args...; kwargs...) scens = DIT.default_scenarios(args...; kwargs...) - return vcat(mySArray.(scens), myMArray.(scens)) + return mySArray.(scens) end end diff --git a/DifferentiationInterfaceTest/src/scenarios/default.jl b/DifferentiationInterfaceTest/src/scenarios/default.jl index 74506155e..a6c9f8c3b 100644 --- a/DifferentiationInterfaceTest/src/scenarios/default.jl +++ b/DifferentiationInterfaceTest/src/scenarios/default.jl @@ -71,8 +71,8 @@ end ## Number to array -multiplicator(::Type{A}) where {A<:AbstractVector} = convert(A, float.(1:6)) -multiplicator(::Type{A}) where {A<:AbstractMatrix} = convert(A, reshape(float.(1:6), 2, 3)) +multiplicator(::Type{A}) where {A<:AbstractVector} = A(1:6) +multiplicator(::Type{A}) where {A<:AbstractMatrix} = A(reshape(1:6, 2, 3)) struct NumToArr{A} end NumToArr(::Type{A}) where {A} = NumToArr{A}()