Skip to content

Commit

Permalink
Simplify static test scenarios (#581)
Browse files Browse the repository at this point in the history
* Simplify static scenarios

* No conversion

* Exclude derivative for Zygote

* Fix JLArrays

* Unskip
  • Loading branch information
gdalle authored Oct 16, 2024
1 parent a97b432 commit 94f9bc5
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,24 @@ module DifferentiationInterfaceTestJLArraysExt
import DifferentiationInterface as DI
using DifferentiationInterfaceTest
import DifferentiationInterfaceTest as DIT
using JLArrays: JLArray, jl
using JLArrays: JLArray, JLVector, JLMatrix, jl
using Random: AbstractRNG, default_rng

myjl(f::Function) = f
function myjl(::DIT.NumToArr{A}) where {T,N,A<:AbstractArray{T,N}}
return DIT.NumToArr(JLArray{T,N})
end

function (f::DIT.NumToArr{JLVector{T}})(x::Number) where {T}
a = JLVector{T}(Vector(1:6)) # avoid mutation
return sin.(x .* a)
end

function (f::DIT.NumToArr{JLMatrix{T}})(x::Number) where {T}
a = JLMatrix{T}(Matrix(reshape(1:6, 2, 3))) # avoid mutation
return sin.(x .* a)
end

myjl(f::DIT.MultiplyByConstant) = f
myjl(f::DIT.WritableClosure) = f

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,10 @@ using SparseArrays: SparseArrays, SparseMatrixCSC, nnz, spdiagm
using StaticArrays: MArray, MMatrix, MVector, SArray, SMatrix, SVector

mySArray(f::Function) = f
myMArray(f::Function) = f

mySArray(::DIT.NumToArr{A}) where {T,A<:AbstractVector{T}} = DIT.NumToArr(SVector{6,T})
myMArray(::DIT.NumToArr{A}) where {T,A<:AbstractVector{T}} = DIT.NumToArr(MVector{6,T})

mySArray(::DIT.NumToArr{A}) where {T,A<:AbstractMatrix{T}} = DIT.NumToArr(SMatrix{2,3,T,6})
myMArray(::DIT.NumToArr{A}) where {T,A<:AbstractMatrix{T}} = DIT.NumToArr(MMatrix{2,3,T,6})

mySArray(f::DIT.MultiplyByConstant) = f
myMArray(f::DIT.MultiplyByConstant) = f

mySArray(f::DIT.WritableClosure) = f
myMArray(f::DIT.WritableClosure) = f

mySArray(x::Number) = x
myMArray(x::Number) = x
Expand All @@ -36,13 +27,8 @@ function myMArray(x::AbstractMatrix{T}) where {T}
end

mySArray(x::Tuple) = map(mySArray, x)
myMArray(x::Tuple) = map(myMArray, x)

mySArray(x::DI.Constant) = DI.Constant(mySArray(DI.unwrap(x)))
myMArray(x::DI.Constant) = DI.Constant(myMArray(DI.unwrap(x)))

mySArray(::Nothing) = nothing
myMArray(::Nothing) = nothing

function mySArray(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun}
(; f, x, y, tang, contexts, res1, res2) = scen
Expand All @@ -57,22 +43,9 @@ function mySArray(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun}
)
end

function myMArray(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun}
(; f, x, y, tang, contexts, res1, res2) = scen
return Scenario{op,pl_op,pl_fun}(
myMArray(f);
x=myMArray(x),
y=pl_fun == :in ? myMArray(y) : myMArray(y),
tang=myMArray(tang),
contexts=myMArray(contexts),
res1=myMArray(res1),
res2=myMArray(res2),
)
end

function DIT.static_scenarios(args...; kwargs...)
scens = DIT.default_scenarios(args...; kwargs...)
return vcat(mySArray.(scens), myMArray.(scens))
return mySArray.(scens)
end

end
4 changes: 2 additions & 2 deletions DifferentiationInterfaceTest/src/scenarios/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ end

## Number to array

multiplicator(::Type{A}) where {A<:AbstractVector} = convert(A, float.(1:6))
multiplicator(::Type{A}) where {A<:AbstractMatrix} = convert(A, reshape(float.(1:6), 2, 3))
multiplicator(::Type{A}) where {A<:AbstractVector} = A(1:6)
multiplicator(::Type{A}) where {A<:AbstractMatrix} = A(reshape(1:6, 2, 3))

struct NumToArr{A} end
NumToArr(::Type{A}) where {A} = NumToArr{A}()
Expand Down

0 comments on commit 94f9bc5

Please sign in to comment.