Skip to content

Commit

Permalink
chore: don't return structural tangent
Browse files Browse the repository at this point in the history
  • Loading branch information
DhairyaLGandhi committed May 20, 2024
1 parent 44bfc91 commit de2d6cd
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 11 deletions.
20 changes: 12 additions & 8 deletions ext/SciMLBaseZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ import Zygote: literal_getproperty
using SciMLBase
using SciMLBase: ODESolution, remake,
getobserved, build_solution, EnsembleSolution,
NonlinearSolution, AbstractTimeseriesSolution
NonlinearSolution, AbstractTimeseriesSolution,
SciMLStructures
using SymbolicIndexingInterface: symbolic_type, NotSymbolic, variable_index, is_observed,
observed, parameter_values
observed, parameter_values, state_values, current_time
using RecursiveArrayTools

# This method resolves the ambiguity with the pullback defined in
Expand Down Expand Up @@ -111,10 +112,13 @@ end
function ODESolution_getindex_pullback(Δ)
i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym
if is_observed(VA, sym)
y, back = Zygote.pullback(VA) do sol
f = observed(sol, sym)
p = parameter_values(sol)
f.(sol.u, Ref(p), sol.t)
f = observed(VA, sym)
p = parameter_values(VA)
tunables, _, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)
u = state_values(VA)
t = current_time(VA)
y, back = Zygote.pullback(u, tunables) do u, tunables
f.(u, Ref(tunables), t)

Check warning on line 121 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L114-L121

Added lines #L114 - L121 were not covered by tests
end
gs = back(Δ)
(gs[1], nothing)
Expand Down Expand Up @@ -154,8 +158,7 @@ function not_obs_grads(VA::ODESolution{T}, sym, not_obss_idx, i, Δ) where {T}
end
end

nt = Zygote.nt_nothing(VA)
Zygote.accum(nt, (u = Δ′,))
Δ′

Check warning on line 161 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L161

Added line #L161 was not covered by tests
end

@adjoint function Base.getindex(
Expand All @@ -171,6 +174,7 @@ end
gs_not_obs = not_obs_grads(VA, sym, not_obs_idx, i, Δ)

Check warning on line 174 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L173-L174

Added lines #L173 - L174 were not covered by tests

a = Zygote.accum(gs_obs[1], gs_not_obs)

Check warning on line 176 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L176

Added line #L176 was not covered by tests

(a, nothing)

Check warning on line 178 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L178

Added line #L178 was not covered by tests
end
VA[sym], ODESolution_getindex_pullback
Expand Down
6 changes: 3 additions & 3 deletions test/downstream/observables_autodiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,15 @@ sol = solve(prob, Tsit5())
end
du_ = [0.0, 1.0, 1.0, 1.0]
du = [du_ for _ in sol.u]
@test du == gs.u
@test du == gs

# Observable in a vector
gs, = gradient(sol) do sol
sum(sum.(sol[[sys.w, sys.x]]))
end
du_ = [0.0, 1.0, 1.0, 2.0]
du = [du_ for _ in sol.u]
@test du == gs.u
@test du == gs
end

# DAE
Expand Down Expand Up @@ -84,7 +84,7 @@ end
end
du_ = [0.2, 1.0]
du = [du_ for _ in sol.u]
@test gs.u == du
@test gs == du
end

# @testset "Adjoints with DAE" begin
Expand Down

0 comments on commit de2d6cd

Please sign in to comment.