Skip to content

Commit

Permalink
Add AbstractOptimizationResult
Browse files Browse the repository at this point in the history
This allows for more seamless conversion between the result objects of
different methods. To convert between two result types, the first result
is converted to a Dict of field names to values, and then that dict is
converted to the target result type. This assumes that all result types
have a common set of field names, and for any field in a result that is
not in that common set, a custom convert method must be defined that
sets default values for those fields in the target result type.
  • Loading branch information
goerz committed Sep 16, 2024
1 parent 71b6197 commit b18e3e2
Show file tree
Hide file tree
Showing 7 changed files with 219 additions and 11 deletions.
1 change: 1 addition & 0 deletions docs/generate_api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@ open(outfile, "w") do out
```@docs
QuantumControl.set_default_ad_framework
QuantumControl.AbstractOptimizationResult
```
""")
write(out, raw"""
Expand Down
1 change: 1 addition & 0 deletions src/QuantumControl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ include("functionals.jl") # submodule Functionals

include("print_versions.jl")
include("set_default_ad_framework.jl")
include("result.jl")

include("deprecate.jl")

Expand Down
3 changes: 2 additions & 1 deletion src/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ where `:Krotov` is the name of the module implementing the method. The above is
also the method signature that a `Module` wishing to implement a control method
must define.
The returned `result` object is specific to the optimization method.
The returned `result` object is specific to the optimization method, but should
be a subtype of [`QuantumControl.AbstractOptimizationResult`](@ref).
"""
function optimize(
problem::ControlProblem;
Expand Down
25 changes: 15 additions & 10 deletions src/propagate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,15 @@ function init_prop_trajectory(
_prefixes=["prop_"],
_msg="Initializing propagator for trajectory",
_filter_kwargs=false,
_kwargs_dict::Dict{Symbol,Any}=Dict{Symbol,Any}(),
initial_state=traj.initial_state,
verbose=false,
kwargs...
)
#
# The private keyword arguments, `_prefixes`, `_msg`, `_filter_kwargs` are
# for internal use when setting up optimal control workspace objects (see,
# e.g., Krotov.jl and GRAPE.jl)
# The private keyword arguments, `_prefixes`, `_msg`, `_filter_kwargs`,
# `_kwargs_dict` are for internal use when setting up optimal control
# workspace objects (see, e.g., Krotov.jl and GRAPE.jl)
#
# * `_prefixes`: which prefixes to translate into `init_prop` kwargs. For
# example, in Krotov/GRAPE, we have propagators both for the forward and
Expand All @@ -117,12 +118,16 @@ function init_prop_trajectory(
# allows to pass the keyword arguments from `optimize` directly to
# `init_prop_trajectory`. By convention, these use the same
# `prop`/`fw_prop`/`bw_prop` prefixes as the properties of `traj`.
# * `_kwargs_dict`: A dictionary Symbol => Any that collects the arguments
# for `init_prop`. This allows to keep a copy of those arguments,
# especially for arguments that cannot be obtained from the resulting
# propagator, like the propagation callback.
#
kwargs_dict = Dict{Symbol,Any}()
empty!(_kwargs_dict)
for prefix in _prefixes
for key in propertynames(traj)
if startswith(string(key), prefix)
kwargs_dict[Symbol(string(key)[length(prefix)+1:end])] =
_kwargs_dict[Symbol(string(key)[length(prefix)+1:end])] =
getproperty(traj, key)
end
end
Expand All @@ -131,20 +136,20 @@ function init_prop_trajectory(
for prefix in _prefixes
for (key, val) in kwargs
if startswith(string(key), prefix)
kwargs_dict[Symbol(string(key)[length(prefix)+1:end])] = val
_kwargs_dict[Symbol(string(key)[length(prefix)+1:end])] = val
end
end
end
else
merge!(kwargs_dict, kwargs)
merge!(_kwargs_dict, kwargs)
end
level = verbose ? Logging.Info : Logging.Debug
@logmsg level _msg kwargs = kwargs_dict
@logmsg level _msg kwargs = _kwargs_dict
try
return init_prop(initial_state, traj.generator, tlist; verbose, kwargs_dict...)
return init_prop(initial_state, traj.generator, tlist; verbose, _kwargs_dict...)
catch exception
msg = "Cannot initialize propagation for trajectory"
@error msg exception kwargs = kwargs_dict
@error msg exception kwargs = _kwargs_dict
rethrow()
end
end
Expand Down
108 changes: 108 additions & 0 deletions src/result.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
"""
Abstract type for the result object returned by [`optimize`](@ref). Any
optimization method implemented on top of `QuantumControl` should subtype
from `AbstractOptimizationResult`. This enables conversion between the results
of different methods, allowing one method to continue an optimization from
another method.
In order for this to work seamlessly, result objects should use a common set of
field names as much as a possible. When a result object requires fields that
cannot be provided by all other result objects, it should have default values
for these field, which can be defined in a custom `Base.convert` method, as,
e.g.,
```julia
function Base.convert(::Type{MyResult}, result::AbstractOptimizationResult)
defaults = Dict{Symbol,Any}(
:f_calls => 0,
:fg_calls => 0,
)
return convert(MyResult, result, defaults)
end
```
Where `f_calls` and `fg_calls` are fields of `MyResult` that are not present in
a given `result` of a different type. The three-argument `convert` is defined
internally for any `AbstractOptimizationResult`.
"""
abstract type AbstractOptimizationResult end

function Base.convert(
::Type{Dict{Symbol,Any}},
result::R
) where {R<:AbstractOptimizationResult}
return Dict{Symbol,Any}(field => getfield(result, field) for field in fieldnames(R))
end


struct MissingResultDataException{R} <: Exception
missing_fields::Vector{Symbol}
end


function Base.showerror(io::IO, err::MissingResultDataException{R}) where {R}
msg = "Missing data for fields $(err.missing_fields) to instantiate $R."
print(io, msg)
end


struct IncompatibleResultsException{R1,R2} <: Exception
missing_fields::Vector{Symbol}
end


function Base.showerror(io::IO, err::IncompatibleResultsException{R1,R2}) where {R1,R2}
msg = "$R2 cannot be converted to $R1: $R2 does not provide required fields $(err.missing_fields). $R1 may need a custom implementation of `Base.convert` that sets values for any field names not provided by all results."
print(io, msg)
end


function Base.convert(
::Type{R},
data::Dict{Symbol,<:Any},
defaults::Dict{Symbol,<:Any}=Dict{Symbol,Any}(),
) where {R<:AbstractOptimizationResult}

function _get(data, field, defaults)
# Can't use `get`, because that would try to evaluate the non-existing
# `defaults[field]` for `fields` that actually exist in `data`.
if haskey(data, field)
return data[field]
else
return defaults[field]
end
end

args = try
[_get(data, field, defaults) for field in fieldnames(R)]
catch exc
if exc isa KeyError
missing_fields = [
field for field in fieldnames(R) if
!(haskey(data, field) || haskey(defaults, field))
]
throw(MissingResultDataException{R}(missing_fields))
else
rethrow()

Check warning on line 86 in src/result.jl

View check run for this annotation

Codecov / codecov/patch

src/result.jl#L86

Added line #L86 was not covered by tests
end
end
return R(args...)
end


function Base.convert(
::Type{R1},
result::R2,
defaults::Dict{Symbol,<:Any}=Dict{Symbol,Any}(),
) where {R1<:AbstractOptimizationResult,R2<:AbstractOptimizationResult}
data = convert(Dict{Symbol,Any}, result)
try
return convert(R1, data, defaults)
catch exc
if exc isa MissingResultDataException{R1}
throw(IncompatibleResultsException{R1,R2}(exc.missing_fields))
else
rethrow()

Check warning on line 105 in src/result.jl

View check run for this annotation

Codecov / codecov/patch

src/result.jl#L105

Added line #L105 was not covered by tests
end
end
end
5 changes: 5 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,11 @@ end
include("test_pulse_parameterizations.jl")
end

println("\n* Result Conversion (test_result_conversion.jl):")
@time @safetestset "Result Conversion" begin
include("test_result_conversion.jl")
end

println("* Invalid interfaces (test_invalid_interfaces.jl):")
@time @safetestset "Invalid interfaces" begin
include("test_invalid_interfaces.jl")
Expand Down
87 changes: 87 additions & 0 deletions test/test_result_conversion.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
using Test
using IOCapture

using QuantumControl:
AbstractOptimizationResult, MissingResultDataException, IncompatibleResultsException

struct _TestOptimizationResult1 <: AbstractOptimizationResult
iter_start::Int64
iter_stop::Int64
end

struct _TestOptimizationResult2 <: AbstractOptimizationResult
iter_start::Int64
J_T::Float64
J_T_prev::Float64
end

struct _TestOptimizationResult3 <: AbstractOptimizationResult
iter_start::Int64
iter_stop::Int64
end

@testset "Dict conversion" begin

R = _TestOptimizationResult1(0, 100)

data = convert(Dict{Symbol,Any}, R)
@test data isa Dict{Symbol,Any}
@test Set(keys(data)) == Set((:iter_stop, :iter_start))
@test data[:iter_start] == 0
@test data[:iter_stop] == 100

@test _TestOptimizationResult1(0, 100) _TestOptimizationResult1(0, 50)

_R = convert(_TestOptimizationResult1, data)
@test _R == R

captured = IOCapture.capture(; passthrough=false, rethrow=Union{}) do
convert(_TestOptimizationResult2, data)
end
@test captured.value isa MissingResultDataException
msg = begin
io = IOBuffer()
showerror(io, captured.value)
String(take!(io))
end
@test startswith(msg, "Missing data for fields [:J_T, :J_T_prev]")
@test contains(msg, "_TestOptimizationResult2")

end


@testset "Result conversion" begin

R = _TestOptimizationResult1(0, 100)

_R = convert(_TestOptimizationResult1, R)
@test _R == R

_R = convert(_TestOptimizationResult3, R)
@test _R isa _TestOptimizationResult3
@test convert(Dict{Symbol,Any}, _R) == convert(Dict{Symbol,Any}, R)

captured = IOCapture.capture(; passthrough=false, rethrow=Union{}) do
convert(_TestOptimizationResult2, R)
end
@test captured.value isa IncompatibleResultsException
msg = begin
io = IOBuffer()
showerror(io, captured.value)
String(take!(io))
end
@test contains(msg, "does not provide required fields [:J_T, :J_T_prev]")

R2 = _TestOptimizationResult2(0, 0.1, 0.4)
captured = IOCapture.capture(; passthrough=false, rethrow=Union{}) do
convert(_TestOptimizationResult1, R2)
end
@test captured.value isa IncompatibleResultsException
msg = begin
io = IOBuffer()
showerror(io, captured.value)
String(take!(io))
end
@test contains(msg, "does not provide required fields [:iter_stop]")

end

0 comments on commit b18e3e2

Please sign in to comment.