diff --git a/Project.toml b/Project.toml index 1651aab2a..dfc4201f7 100644 --- a/Project.toml +++ b/Project.toml @@ -6,9 +6,11 @@ version = "3.15.1" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" +BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" @@ -32,6 +34,7 @@ SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5" SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" diff --git a/src/internal/jacobian.jl b/src/internal/jacobian.jl index b78eb7383..aa1a3c83b 100644 --- a/src/internal/jacobian.jl +++ b/src/internal/jacobian.jl @@ -1,3 +1,5 @@ +using Enzyme + """ JacobianCache(prob, alg, f::F, fu, u, p; autodiff = nothing, vjp_autodiff = nothing, jvp_autodiff = nothing, linsolve = missing) where {F} @@ -61,7 +63,9 @@ function JacobianCache(prob, alg, f::F, fu_, u, p; stats, autodiff = nothing, if !has_analytic_jac && needs_jac autodiff = construct_concrete_adtype(f, autodiff) - di_extras = if iip + di_extras = if !iip && autodiff isa AutoEnzyme + Enzyme.onehot(u) + elseif iip DI.prepare_jacobian(f, fu, autodiff, u, Constant(prob.p)) else DI.prepare_jacobian(f, autodiff, u, Constant(prob.p)) @@ -88,7 +92,11 @@ function JacobianCache(prob, alg, f::F, fu_, u, p; stats, autodiff = nothing, if iip DI.jacobian(f, fu, di_extras, autodiff, u, Constant(p)) else - DI.jacobian(f, di_extras, autodiff, u, Constant(p)) + if autodiff isa AutoEnzyme + hcat(Enzyme.autodiff(Forward, f, BatchDuplicated(u, di_extras), Const(p))[1]...) + else + DI.jacobian(f, di_extras, autodiff, u, Constant(p)) + end end end else @@ -153,6 +161,8 @@ function (cache::JacobianCache{iip})( else if SciMLBase.has_jac(cache.f) return cache.f.jac(u, p) + elseif cache.autodiff isa AutoEnzyme + hcat(Enzyme.autodiff(Forward, cache.f, BatchDuplicated(u, cache.di_extras), Const(p))[1]...) else return DI.jacobian(cache.f, cache.di_extras, cache.autodiff, u, Constant(p)) end