From 3cc37191ff58da1db335a7f69936e9273d09dadd Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Sun, 12 Jun 2022 13:57:32 +0200 Subject: [PATCH 1/2] update for SciMLOperators --- src/DiffEqBase.jl | 5 +++-- src/nlsolve/newton.jl | 2 +- src/nlsolve/utils.jl | 4 ++-- test/affine_operators_tests.jl | 18 +++++++++--------- test/basic_operators_interface.jl | 4 ++-- 5 files changed, 17 insertions(+), 16 deletions(-) diff --git a/src/DiffEqBase.jl b/src/DiffEqBase.jl index cc20d6106..cbc1bf567 100644 --- a/src/DiffEqBase.jl +++ b/src/DiffEqBase.jl @@ -34,8 +34,7 @@ using ForwardDiff @reexport using SciMLBase -using SciMLBase: @def, DEIntegrator, DEProblem, AbstractDiffEqOperator, - AbstractDiffEqLinearOperator, AbstractDiffEqInterpolation, +using SciMLBase: @def, DEIntegrator, DEProblem, AbstractDiffEqInterpolation, DECallback, AbstractDEOptions, DECache, AbstractContinuousCallback, AbstractDiscreteCallback, AbstractLinearProblem, AbstractNonlinearProblem, AbstractOptimizationProblem, AbstractSteadyStateProblem, AbstractJumpProblem, @@ -81,6 +80,8 @@ import SciMLBase: solve, init, solve!, __init, __solve, update_coefficients!, update_coefficients, isadaptive, wrapfun_oop, wrapfun_iip, unwrap_fw, promote_tspan, set_u!, set_t!, set_ut! +import SciMLOperators: MatrixOperator, AbstractSciMLOperator, AbstractSciMLLinearOperator + SciMLBase.isfunctionwrapper(x::FunctionWrapper) = true """ diff --git a/src/nlsolve/newton.jl b/src/nlsolve/newton.jl index 5bb0b21d9..6febe6a9e 100644 --- a/src/nlsolve/newton.jl +++ b/src/nlsolve/newton.jl @@ -157,7 +157,7 @@ end @.. broadcast=false ztmp = (dt*k - ztmp) * invγdt end - if W isa AbstractDiffEqLinearOperator + if W isa AbstractSciMLLinearOperator update_coefficients!(W,dz,p,tstep) end nlsolver.linsolve(vecdz,W,vecztmp,iter == 1 && new_W; Pl=ScaleVector(weight, true), Pr=ScaleVector(weight, false), tol=lintol) diff --git a/src/nlsolve/utils.jl b/src/nlsolve/utils.jl index 20fa126b0..91d37c0ac 100644 --- a/src/nlsolve/utils.jl +++ b/src/nlsolve/utils.jl @@ -190,8 +190,8 @@ DiffEqBase.@def oopnlsolve begin if islinear(f) || DiffEqBase.has_jac(f) # get the operator J = islinear(f) ? nf.f : f.jac(uprev, p, t) - if !isa(J, DiffEqBase.AbstractDiffEqLinearOperator) - J = DiffEqArrayOperator(J) + if !isa(J, DiffEqBase.AbstractSciMLLinearOperator) + J = MatrixOperator(J) end W = WOperator(f.mass_matrix, dt, J, false) else diff --git a/test/affine_operators_tests.jl b/test/affine_operators_tests.jl index 7f2772403..f21f289b9 100644 --- a/test/affine_operators_tests.jl +++ b/test/affine_operators_tests.jl @@ -2,25 +2,25 @@ using DiffEqBase using Test using Random -mutable struct TestDiffEqOperator{T} <: DiffEqBase.AbstractDiffEqLinearOperator{T} +mutable struct TestSciMLOperator{T} <: DiffEqBase.AbstractSciMLLinearOperator{T} m::Int n::Int end -TestDiffEqOperator(A::AbstractMatrix{T}) where {T} = - TestDiffEqOperator{T}(size(A)...) +TestSciMLOperator(A::AbstractMatrix{T}) where {T} = + TestSciMLOperator{T}(size(A)...) -Base.size(A::TestDiffEqOperator) = (A.m, A.n) +Base.size(A::TestSciMLOperator) = (A.m, A.n) -A = TestDiffEqOperator([0 0; 0 1]) -B = TestDiffEqOperator([0 0 0; 0 1 0; 0 0 2]) +A = TestSciMLOperator([0 0; 0 1]) +B = TestSciMLOperator([0 0 0; 0 1 0; 0 0 2]) -@test_throws ErrorException AffineDiffEqOperator{Int64}((A,B),()) +@test_throws ErrorException AffineSciMLOperator{Int64}((A,B),()) @testset "DiffEq linear operators" begin Random.seed!(0) - M = rand(2,2); A = DiffEqArrayOperator(M) + M = rand(2,2); A = MatrixOperator(M) b = rand(2) u = rand(2) p = rand(1) @@ -28,7 +28,7 @@ B = TestDiffEqOperator([0 0 0; 0 1 0; 0 0 2]) As_list = [(A,), (A, A)]#, (A, α)] bs_list = [(), (b,), (2b,), (b, 2b)] @testset "combinations of A's and b's" for As in As_list, bs in bs_list - L = AffineDiffEqOperator{Float64}(As, bs, zeros(2)) + L = AffineSciMLOperator{Float64}(As, bs, zeros(2)) mysum = sum(A*u for A in As) for b in bs; mysum .+= b; end @test L(u,p,t) == mysum diff --git a/test/basic_operators_interface.jl b/test/basic_operators_interface.jl index cc4da9ae7..8f27799c9 100644 --- a/test/basic_operators_interface.jl +++ b/test/basic_operators_interface.jl @@ -21,7 +21,7 @@ end @testset "Array Operators" begin Random.seed!(0); A = rand(2,2); u = rand(2); du = zeros(2) - L = DiffEqArrayOperator(A) + L = MatrixOperator(A) @test Matrix(L) == A @test size(L) == size(A) @test L * u == A * u @@ -44,7 +44,7 @@ end @testset "Mutable Array Operators" begin Random.seed!(0); A = rand(2,2); u = rand(2); du = zeros(2) update_func = (_A,u,p,t) -> _A .= t * A - Lt = DiffEqArrayOperator(zeros(2,2); update_func=update_func) + Lt = MatrixOperator(zeros(2,2); update_func=update_func) t = 5.0 @test isconstant(Lt) == false @test Lt(u,nothing,t) ≈ (t*A) * u From 5760cf95373f9e878e2db2432066b9aacafa6d8f Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Sun, 12 Jun 2022 14:13:36 +0200 Subject: [PATCH 2/2] fix import --- src/nlsolve/utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nlsolve/utils.jl b/src/nlsolve/utils.jl index 91d37c0ac..3f71f0d4f 100644 --- a/src/nlsolve/utils.jl +++ b/src/nlsolve/utils.jl @@ -190,7 +190,7 @@ DiffEqBase.@def oopnlsolve begin if islinear(f) || DiffEqBase.has_jac(f) # get the operator J = islinear(f) ? nf.f : f.jac(uprev, p, t) - if !isa(J, DiffEqBase.AbstractSciMLLinearOperator) + if !isa(J, SciMLOperators.AbstractSciMLLinearOperator) J = MatrixOperator(J) end W = WOperator(f.mass_matrix, dt, J, false)