Skip to content

Commit

Permalink
Remove more redundancy
Browse files Browse the repository at this point in the history
  • Loading branch information
willtebbutt committed Nov 27, 2024
1 parent f5b5894 commit 2b2ac55
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 8 deletions.
6 changes: 4 additions & 2 deletions src/adjoint_common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ function adjointdiffcache(g::G, sensealg, discrete, sol, dgdu::DG1, dgdp::DG2, f
pf = get_pf(autojacvec; _f = unwrappedf, isinplace = isinplace, isRODE = isRODE)
paramjac_config = (paramjac_config..., Enzyme.make_zero(pf))
elseif autojacvec isa MooncakeVJP
pf = get_pf(autojacvec; _f = unwrappedf, isinplace, isRODE)
pf = get_pf(autojacvec, prob, unwrappedf)
paramjac_config = get_paramjac_config(autojacvec, pf, p, f, y, _t)
elseif SciMLBase.has_paramjac(f) || quad || !(autojacvec isa Bool) ||
autojacvec isa EnzymeVJP
Expand Down Expand Up @@ -506,7 +506,9 @@ function get_pf(autojacvec::EnzymeVJP; _f, isinplace, isRODE)
end
end

function get_pf(autojacvec::MooncakeVJP; _f, isinplace, isRODE)
function get_pf(::MooncakeVJP, prob, _f)
isinplace = DiffEqBase.isinplace(prob)
isRODE = isa(prob, RODEProblem)
pf = let f = _f
if isinplace && isRODE
function (out, u, _p, t, W)
Expand Down
4 changes: 1 addition & 3 deletions src/gauss_adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -429,9 +429,7 @@ function GaussIntegrand(sol, sensealg, checkpoints, dgdp = nothing)
paramjac_config = zero(y), zero(y), Enzyme.make_zero(pf)
pJ = nothing
elseif sensealg.autojacvec isa MooncakeVJP
isinplace = DiffEqBase.isinplace(prob)
isRODE = isa(prob, RODEProblem)
pf = get_pf(sensealg.autojacvec; _f = f, isinplace, isRODE)
pf = get_pf(sensealg.autojacvec, prob, f)
paramjac_config = get_paramjac_config(sensealg.autojacvec, pf, p, f, y, tspan[2])
pJ = nothing
elseif isautojacvec # Zygote
Expand Down
4 changes: 1 addition & 3 deletions src/quadrature_adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -236,9 +236,7 @@ function AdjointSensitivityIntegrand(sol, adj_sol, sensealg, dgdp = nothing)
paramjac_config = zero(y), zero(y), Enzyme.make_zero(pf)
pJ = nothing
elseif sensealg.autojacvec isa MooncakeVJP
isinplace = DiffEqBase.isinplace(prob)
isRODE = isa(prob, RODEProblem)
pf = get_pf(sensealg.autojacvec; _f = f, isinplace, isRODE)
pf = get_pf(sensealg.autojacvec, prob, f)
paramjac_config = get_paramjac_config(sensealg.autojacvec, pf, p, f, y, tspan[2])
pJ = nothing
elseif isautojacvec # Zygote
Expand Down

0 comments on commit 2b2ac55

Please sign in to comment.