-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This allows for more seamless conversion between the result objects of different methods. To convert between two result types, the first result is converted to a Dict of field names to values, and then that dict is converted to the target result type. This assumes that all result types have a common set of field names, and for any field in a result that is not in that common set, a custom convert method must be defined that sets default values for those fields in the target result type.
- Loading branch information
Showing
7 changed files
with
219 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
""" | ||
Abstract type for the result object returned by [`optimize`](@ref). Any | ||
optimization method implemented on top of `QuantumControl` should subtype | ||
from `AbstractOptimizationResult`. This enables conversion between the results | ||
of different methods, allowing one method to continue an optimization from | ||
another method. | ||
In order for this to work seamlessly, result objects should use a common set of | ||
field names as much as a possible. When a result object requires fields that | ||
cannot be provided by all other result objects, it should have default values | ||
for these field, which can be defined in a custom `Base.convert` method, as, | ||
e.g., | ||
```julia | ||
function Base.convert(::Type{MyResult}, result::AbstractOptimizationResult) | ||
defaults = Dict{Symbol,Any}( | ||
:f_calls => 0, | ||
:fg_calls => 0, | ||
) | ||
return convert(MyResult, result, defaults) | ||
end | ||
``` | ||
Where `f_calls` and `fg_calls` are fields of `MyResult` that are not present in | ||
a given `result` of a different type. The three-argument `convert` is defined | ||
internally for any `AbstractOptimizationResult`. | ||
""" | ||
abstract type AbstractOptimizationResult end | ||
|
||
function Base.convert( | ||
::Type{Dict{Symbol,Any}}, | ||
result::R | ||
) where {R<:AbstractOptimizationResult} | ||
return Dict{Symbol,Any}(field => getfield(result, field) for field in fieldnames(R)) | ||
end | ||
|
||
|
||
struct MissingResultDataException{R} <: Exception | ||
missing_fields::Vector{Symbol} | ||
end | ||
|
||
|
||
function Base.showerror(io::IO, err::MissingResultDataException{R}) where {R} | ||
msg = "Missing data for fields $(err.missing_fields) to instantiate $R." | ||
print(io, msg) | ||
end | ||
|
||
|
||
struct IncompatibleResultsException{R1,R2} <: Exception | ||
missing_fields::Vector{Symbol} | ||
end | ||
|
||
|
||
function Base.showerror(io::IO, err::IncompatibleResultsException{R1,R2}) where {R1,R2} | ||
msg = "$R2 cannot be converted to $R1: $R2 does not provide required fields $(err.missing_fields). $R1 may need a custom implementation of `Base.convert` that sets values for any field names not provided by all results." | ||
print(io, msg) | ||
end | ||
|
||
|
||
function Base.convert( | ||
::Type{R}, | ||
data::Dict{Symbol,<:Any}, | ||
defaults::Dict{Symbol,<:Any}=Dict{Symbol,Any}(), | ||
) where {R<:AbstractOptimizationResult} | ||
|
||
function _get(data, field, defaults) | ||
# Can't use `get`, because that would try to evaluate the non-existing | ||
# `defaults[field]` for `fields` that actually exist in `data`. | ||
if haskey(data, field) | ||
return data[field] | ||
else | ||
return defaults[field] | ||
end | ||
end | ||
|
||
args = try | ||
[_get(data, field, defaults) for field in fieldnames(R)] | ||
catch exc | ||
if exc isa KeyError | ||
missing_fields = [ | ||
field for field in fieldnames(R) if | ||
!(haskey(data, field) || haskey(defaults, field)) | ||
] | ||
throw(MissingResultDataException{R}(missing_fields)) | ||
else | ||
rethrow() | ||
end | ||
end | ||
return R(args...) | ||
end | ||
|
||
|
||
function Base.convert( | ||
::Type{R1}, | ||
result::R2, | ||
defaults::Dict{Symbol,<:Any}=Dict{Symbol,Any}(), | ||
) where {R1<:AbstractOptimizationResult,R2<:AbstractOptimizationResult} | ||
data = convert(Dict{Symbol,Any}, result) | ||
try | ||
return convert(R1, data, defaults) | ||
catch exc | ||
if exc isa MissingResultDataException{R1} | ||
throw(IncompatibleResultsException{R1,R2}(exc.missing_fields)) | ||
else | ||
rethrow() | ||
end | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
using Test | ||
using IOCapture | ||
|
||
using QuantumControl: | ||
AbstractOptimizationResult, MissingResultDataException, IncompatibleResultsException | ||
|
||
struct _TestOptimizationResult1 <: AbstractOptimizationResult | ||
iter_start::Int64 | ||
iter_stop::Int64 | ||
end | ||
|
||
struct _TestOptimizationResult2 <: AbstractOptimizationResult | ||
iter_start::Int64 | ||
J_T::Float64 | ||
J_T_prev::Float64 | ||
end | ||
|
||
struct _TestOptimizationResult3 <: AbstractOptimizationResult | ||
iter_start::Int64 | ||
iter_stop::Int64 | ||
end | ||
|
||
@testset "Dict conversion" begin | ||
|
||
R = _TestOptimizationResult1(0, 100) | ||
|
||
data = convert(Dict{Symbol,Any}, R) | ||
@test data isa Dict{Symbol,Any} | ||
@test Set(keys(data)) == Set((:iter_stop, :iter_start)) | ||
@test data[:iter_start] == 0 | ||
@test data[:iter_stop] == 100 | ||
|
||
@test _TestOptimizationResult1(0, 100) ≠ _TestOptimizationResult1(0, 50) | ||
|
||
_R = convert(_TestOptimizationResult1, data) | ||
@test _R == R | ||
|
||
captured = IOCapture.capture(; passthrough=false, rethrow=Union{}) do | ||
convert(_TestOptimizationResult2, data) | ||
end | ||
@test captured.value isa MissingResultDataException | ||
msg = begin | ||
io = IOBuffer() | ||
showerror(io, captured.value) | ||
String(take!(io)) | ||
end | ||
@test startswith(msg, "Missing data for fields [:J_T, :J_T_prev]") | ||
@test contains(msg, "_TestOptimizationResult2") | ||
|
||
end | ||
|
||
|
||
@testset "Result conversion" begin | ||
|
||
R = _TestOptimizationResult1(0, 100) | ||
|
||
_R = convert(_TestOptimizationResult1, R) | ||
@test _R == R | ||
|
||
_R = convert(_TestOptimizationResult3, R) | ||
@test _R isa _TestOptimizationResult3 | ||
@test convert(Dict{Symbol,Any}, _R) == convert(Dict{Symbol,Any}, R) | ||
|
||
captured = IOCapture.capture(; passthrough=false, rethrow=Union{}) do | ||
convert(_TestOptimizationResult2, R) | ||
end | ||
@test captured.value isa IncompatibleResultsException | ||
msg = begin | ||
io = IOBuffer() | ||
showerror(io, captured.value) | ||
String(take!(io)) | ||
end | ||
@test contains(msg, "does not provide required fields [:J_T, :J_T_prev]") | ||
|
||
R2 = _TestOptimizationResult2(0, 0.1, 0.4) | ||
captured = IOCapture.capture(; passthrough=false, rethrow=Union{}) do | ||
convert(_TestOptimizationResult1, R2) | ||
end | ||
@test captured.value isa IncompatibleResultsException | ||
msg = begin | ||
io = IOBuffer() | ||
showerror(io, captured.value) | ||
String(take!(io)) | ||
end | ||
@test contains(msg, "does not provide required fields [:iter_stop]") | ||
|
||
end |