Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: indexing rework with new SymbolicIndexingInterface #532

Merged
merged 1 commit into from
Dec 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,14 @@ PyCall = "1.96"
PythonCall = "0.9.15"
RCall = "0.13.18"
RecipesBase = "1.0"
RecursiveArrayTools = "2.38"
RecursiveArrayTools = "3.0"
Reexport = "1"
RuntimeGeneratedFunctions = "0.5"
SciMLOperators = "0.3.7"
StaticArrays = "1.7"
StaticArraysCore = "1.4"
Statistics = "1.9"
SymbolicIndexingInterface = "0.2"
SymbolicIndexingInterface = "0.3"
Tables = "1.11"
TruncatedStacktraces = "1.4"
QuasiMonteCarlo = "0.2.19, 0.3"
Expand Down
4 changes: 2 additions & 2 deletions ext/SciMLBaseChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
sym,
j::Integer)
function ODESolution_getindex_pullback(Δ)
i = issymbollike(sym) ? sym_to_index(sym, VA) : sym
i = symbolic_type(sym) != NotSymbolic() ? sym_to_index(sym, VA) : sym

Check warning on line 15 in ext/SciMLBaseChainRulesCoreExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseChainRulesCoreExt.jl#L15

Added line #L15 was not covered by tests
if i === nothing
getter = getobserved(VA)
grz = rrule_via_ad(config, getter, sym, VA.u[j], VA.prob.p, VA.t[j])[2](Δ)
Expand Down Expand Up @@ -65,7 +65,7 @@

function ChainRulesCore.rrule(::typeof(getindex), VA::ODESolution, sym)
function ODESolution_getindex_pullback(Δ)
i = issymbollike(sym) ? sym_to_index(sym, VA) : sym
i = symbolic_type(sym) != NotSymbolic() ? sym_to_index(sym, VA) : sym

Check warning on line 68 in ext/SciMLBaseChainRulesCoreExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseChainRulesCoreExt.jl#L68

Added line #L68 was not covered by tests
if i === nothing
throw(error("AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated."))
else
Expand Down
9 changes: 5 additions & 4 deletions ext/SciMLBaseZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
using Zygote: @adjoint, pullback
import Zygote: literal_getproperty
using SciMLBase
using SciMLBase: ODESolution, issymbollike, sym_to_index, remake,
using SciMLBase: ODESolution, sym_to_index, remake,
getobserved, build_solution, EnsembleSolution,
NonlinearSolution, AbstractTimeseriesSolution
using SymbolicIndexingInterface: symbolic_type, NotSymbolic

# This method resolves the ambiguity with the pullback defined in
# RecursiveArrayToolsZygoteExt
Expand All @@ -32,7 +33,7 @@

@adjoint function getindex(VA::ODESolution, sym, j::Int)
function ODESolution_getindex_pullback(Δ)
i = issymbollike(sym) ? sym_to_index(sym, VA) : sym
i = symbolic_type(sym) != NotSymbolic() ? sym_to_index(sym, VA) : sym

Check warning on line 36 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L36

Added line #L36 was not covered by tests
du, dprob = if i === nothing
getter = getobserved(VA)
grz = pullback(getter, sym, VA.u[j], VA.prob.p, VA.t[j])[2](Δ)
Expand Down Expand Up @@ -81,7 +82,7 @@
for (x, j) in zip(VA.u, 1:length(VA))]
(Δ′, nothing)
end
VA[i], ODESolution_getindex_pullback
VA[:, i], ODESolution_getindex_pullback

Check warning on line 85 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L85

Added line #L85 was not covered by tests
end

@adjoint function Zygote.literal_getproperty(sim::EnsembleSolution,
Expand All @@ -91,7 +92,7 @@

@adjoint function getindex(VA::ODESolution, sym)
function ODESolution_getindex_pullback(Δ)
i = issymbollike(sym) ? sym_to_index(sym, VA) : sym
i = symbolic_type(sym) != NotSymbolic() ? sym_to_index(sym, VA) : sym

Check warning on line 95 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L95

Added line #L95 was not covered by tests
if i === nothing
throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated."))
else
Expand Down
20 changes: 5 additions & 15 deletions src/ensemble/ensemble_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@
function EnsembleSolution(sim::T, elapsedTime,
converged, stats=nothing) where {T <: AbstractVector{T2}
} where {T2 <:
AbstractArray}
EnsembleSolution{eltype(eltype(sim)), ndims(sim[1]) + 1, typeof(sim)}(
sim,
Union{AbstractArray,RecursiveArrayTools.AbstractVectorOfArray}}
EnsembleSolution{eltype(eltype(sim)), ndims(sim[1]) + 1,

Check warning on line 44 in src/ensemble/ensemble_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble/ensemble_solutions.jl#L44

Added line #L44 was not covered by tests
typeof(sim)}(sim,
elapsedTime,
converged,
stats)
Expand Down Expand Up @@ -209,18 +209,8 @@
end
end

Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, s, ::Colon)
return [xi[s] for xi in x]
end

Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am very happy you were able to get rid of this.

::Colon,
args::Colon...)
return invoke(getindex,
Tuple{RecursiveArrayTools.AbstractVectorOfArray, Colon, typeof.(args)...},
x,
:,
args...)
Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, ::Union{ScalarSymbolic,ArraySymbolic}, s, ::Colon)
return [xi[s] for xi in x.u]

Check warning on line 213 in src/ensemble/ensemble_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble/ensemble_solutions.jl#L212-L213

Added lines #L212 - L213 were not covered by tests
end

function (sol::AbstractEnsembleSolution)(args...; kwargs...)
Expand Down
184 changes: 87 additions & 97 deletions src/integrator_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -345,13 +345,13 @@
# So any error checking happens to ensure we actually _can_ set state
set_u!(integrator, integrator.u)

if !issymbollike(sym)
if symbolic_type(sym) == NotSymbolic()

Check warning on line 348 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L348

Added line #L348 was not covered by tests
error("sym must be a symbol")
end
i = sym_to_index(sym, integrator)
i = variable_index(integrator, sym)

Check warning on line 351 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L351

Added line #L351 was not covered by tests

if isnothing(i)
error("sym is not a state variable")
error("$sym is not a state variable")

Check warning on line 354 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L354

Added line #L354 was not covered by tests
end

integrator.u[i] = val
Expand Down Expand Up @@ -385,27 +385,27 @@

### Indexing
function getsyms(integrator::DEIntegrator)
if has_syms(integrator.f)
return integrator.f.syms
else
return keys(integrator.u[1])
syms = variable_symbols(integrator)
if isempty(syms)
syms = keys(integrator.u)

Check warning on line 390 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L388-L390

Added lines #L388 - L390 were not covered by tests
end
return syms

Check warning on line 392 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L392

Added line #L392 was not covered by tests
end

function getindepsym(integrator::DEIntegrator)
if has_indepsym(integrator.f)
return integrator.f.indepsym
else
syms = independent_variable_symbols(integrator)
if isempty(syms)

Check warning on line 397 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L396-L397

Added lines #L396 - L397 were not covered by tests
return nothing
end
return syms

Check warning on line 400 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L400

Added line #L400 was not covered by tests
end

function getparamsyms(integrator::DEIntegrator)
if has_paramsyms(integrator.f)
return integrator.f.paramsyms
else
psyms = parameter_symbols(integrator)
if isempty(psyms)

Check warning on line 405 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L404-L405

Added lines #L404 - L405 were not covered by tests
return nothing
end
return psyms

Check warning on line 408 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L408

Added line #L408 was not covered by tests
end

function getobserved(integrator::DEIntegrator)
Expand All @@ -417,58 +417,76 @@
end

function sym_to_index(sym, integrator::DEIntegrator)
if has_sys(integrator.f) && is_state_sym(integrator.f.sys, sym)
return state_sym_to_index(integrator.f.sys, sym)
idx = variable_index(integrator, sym)
if idx === nothing
idx = findfirst(isequal(sym), keys(integrator.u))

Check warning on line 422 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L420-L422

Added lines #L420 - L422 were not covered by tests
end
return idx

Check warning on line 424 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L424

Added line #L424 was not covered by tests
end

# SymbolicIndexingInterface
SymbolicIndexingInterface.symbolic_container(A::DEIntegrator) = A.f
SymbolicIndexingInterface.parameter_values(A::DEIntegrator) = A.p

Check warning on line 429 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L428-L429

Added lines #L428 - L429 were not covered by tests

function SymbolicIndexingInterface.is_observed(A::DEIntegrator, sym)
return !is_variable(A, sym) && !is_parameter(A, sym) && !is_independent_variable(A, sym) && symbolic_type(sym) == ScalarSymbolic()

Check warning on line 432 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L431-L432

Added lines #L431 - L432 were not covered by tests
end

function SymbolicIndexingInterface.observed(A::DEIntegrator, sym)
(u, p, t) -> getobserved(A)(sym, u, p, t)

Check warning on line 436 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L435-L436

Added lines #L435 - L436 were not covered by tests
end

SymbolicIndexingInterface.is_time_dependent(::DEIntegrator) = true

Check warning on line 439 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L439

Added line #L439 was not covered by tests

# TODO make this nontrivial once dynamic state selection works
SymbolicIndexingInterface.constant_structure(::DEIntegrator) = true

Check warning on line 442 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L442

Added line #L442 was not covered by tests

function Base.getproperty(A::DEIntegrator, sym::Symbol)
if sym === :destats && hasfield(typeof(A), :stats)
@warn "destats has been deprecated for stats"
getfield(A, :stats)

Check warning on line 447 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L444-L447

Added lines #L444 - L447 were not covered by tests
else
return sym_to_index(sym, getsyms(integrator))
return getfield(A, sym)

Check warning on line 449 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L449

Added line #L449 was not covered by tests
end
end

Base.@propagate_inbounds function Base.getindex(A::DEIntegrator,
I::Union{Int, AbstractArray{Int},
Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, ::NotSymbolic, I::Union{Int, AbstractArray{Int},

Check warning on line 453 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L453

Added line #L453 was not covered by tests
CartesianIndex, Colon, BitArray,
AbstractArray{Bool}}...)
RecursiveArrayTools.VectorOfArray(A.u)[I...]
A.u[I...]

Check warning on line 456 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L456

Added line #L456 was not covered by tests
end

Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, ::ScalarSymbolic, sym)
if is_variable(A, sym)
return A[variable_index(A, sym)]
elseif is_parameter(A, sym)
Base.depwarn("Indexing with parameters is deprecated. Use `getp(sys, $sym)(integrator)` for parameter indexing.", :parameter_getindex)
return getp(A, sym)(A)
elseif is_independent_variable(A, sym)
return A.t
elseif is_observed(A, sym)
return SymbolicIndexingInterface.observed(A, sym)(A.u, A.p, A.t)

Check warning on line 468 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L459-L468

Added lines #L459 - L468 were not covered by tests
else
error("Tried to index integrator with a Symbol that was not found in the system.")

Check warning on line 470 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L470

Added line #L470 was not covered by tests
end
end

Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, ::ArraySymbolic, sym)
return A[collect(sym)]

Check warning on line 475 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L474-L475

Added lines #L474 - L475 were not covered by tests
end

Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, ::ScalarSymbolic, sym::Union{Tuple,AbstractArray})
return getindex.((A,), sym)

Check warning on line 479 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L478-L479

Added lines #L478 - L479 were not covered by tests
end

Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, sym)
if issymbollike(sym)
if sym isa AbstractArray
return A[collect(sym)]
end
i = sym_to_index(sym, A)
elseif all(issymbollike, sym)
return getindex.((A,), sym)
else
i = sym
end
symtype = symbolic_type(sym)
elsymtype = symbolic_type(eltype(sym))

Check warning on line 484 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L483-L484

Added lines #L483 - L484 were not covered by tests

if i === nothing
if issymbollike(sym)
if has_sys(A.f) && is_indep_sym(A.f.sys, sym) ||
Symbol(sym) == getindepsym(A)
return A.t
elseif has_sys(A.f) && is_param_sym(A.f.sys, sym)
return A.p[param_sym_to_index(A.f.sys, sym)]
elseif has_paramsyms(A.f) && Symbol(sym) in getparamsyms(A)
return A.p[findfirst(x -> isequal(x, Symbol(sym)), getparamsyms(A))]
elseif (sym isa Symbol) && has_sys(A.f) && hasproperty(A.f.sys, sym) # Handles input like :X (where X is a state).
return observed(A, getproperty(A.f.sys, sym))
elseif has_sys(A.f) && (count('₊', String(Symbol(sym))) == 1) &&
(count(isequal(Symbol(sym)),
Symbol.(A.f.sys.name, :₊, getparamsyms(A))) == 1) # Handles input like sys.X (where X is a parameter).
return A.p[findfirst(isequal(Symbol(sym)),
Symbol.(A.f.sys.name, :₊, getparamsyms(A)))]
else
return observed(A, sym)
end
else
observed(A, sym)
end
elseif i isa Base.Integer || i isa AbstractRange || i isa AbstractVector{<:Base.Integer}
A[i]
if symtype != NotSymbolic()
return getindex(A, symtype, sym)

Check warning on line 487 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L486-L487

Added lines #L486 - L487 were not covered by tests
else
error("Invalid indexing of integrator")
return getindex(A, elsymtype, sym)

Check warning on line 489 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L489

Added line #L489 was not covered by tests
end
end

Expand All @@ -477,52 +495,24 @@
end

function Base.setindex!(A::DEIntegrator, val, sym)
if has_sys(A.f)
if issymbollike(sym)
params = getparamsyms(A)
s = Symbol.(states(A.f.sys))
params = Symbol.(params)

i = findfirst(isequal(Symbol(sym)), s)
if !isnothing(i)
A.u[i] = val
return A
elseif sym isa Symbol # Handles input like :X.
s_f = Symbol.(getproperty.(states(A.f.sys), :f))
if count(isequal(Symbol(sym)), s_f) == 1
i = findfirst(isequal(sym), s_f)
A.u[i] = val
return A
elseif count(isequal(Symbol(sym)), s_f) > 1
error("The input symbol $(sym) occurs several times among integrator states. Please avoid use Symbol form (:$(sym)).")
end
elseif count('₊', String(Symbol(sym))) == 1 # Handles input like sys.X.
s_names = Symbol.(A.f.sys.name, :₊, s)
if count(isequal(Symbol(sym)), s_names) == 1
i = findfirst(isequal(Symbol(sym)), s_names)
A.u[i] = val
return A
end
end

i = findfirst(isequal(Symbol(sym)), params)
if !isnothing(i)
A.p[i] = val
return A
elseif count('₊', String(Symbol(sym))) == 1 # Handles input like sys.X.
p_names = Symbol.(A.f.sys.name, :₊, params)
if count(isequal(Symbol(sym)), p_names) == 1
i = findfirst(isequal(Symbol(sym)), p_names)
A.p[i] = val
return A
end
end
error("Invalid indexing of integrator: $sym is not a state or parameter, it may be an observed variable.")
has_sys(A.f) || error("Invalid indexing of integrator: Integrator does not support indexing without a system")
if symbolic_type(sym) == ScalarSymbolic()
if is_variable(A, sym)
A.u[variable_index(A, sym)] = val
elseif is_parameter(A, sym)
Base.depwarn("Parameter indexing is deprecated. Use `setp(sys, $sym)(integrator, $val)` to set parameter value.", :parameter_setindex)
setp(A, sym)(A, val)

Check warning on line 504 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L498-L504

Added lines #L498 - L504 were not covered by tests
else
error("Invalid indexing of integrator: $sym is not a symbol")
error("Invalid indexing of integrator: $sym is not a state or parameter, it may be an observed variable.")

Check warning on line 506 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L506

Added line #L506 was not covered by tests
end
return A
elseif symbolic_type(sym) == ArraySymbolic()
setindex!.((A,), val, collect(sym))
return A

Check warning on line 511 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L508-L511

Added lines #L508 - L511 were not covered by tests
else
error("Invalid indexing of integrator: Integrator does not support indexing without a system")
sym isa AbstractArray || error("Invalid indexing of integrator")
setindex!.((A,), val, sym)
return A

Check warning on line 515 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L513-L515

Added lines #L513 - L515 were not covered by tests
end
end

Expand Down
Loading
Loading