Skip to content

Commit

Permalink
Add parametrization module
Browse files Browse the repository at this point in the history
  • Loading branch information
goerz committed Jun 5, 2021
1 parent cfe6b10 commit 809d94e
Show file tree
Hide file tree
Showing 6 changed files with 1,201 additions and 10 deletions.
1,049 changes: 1,049 additions & 0 deletions docs/notebooks/01_example_simple_state_to_state_parametrization.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions src/krotov/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
mu,
objectives,
parallelization,
parametrization,
propagators,
result,
second_order,
Expand Down
19 changes: 14 additions & 5 deletions src/krotov/conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

import numpy as np

from .parametrization import ParametrizedArray


__all__ = [
'control_onto_interval',
Expand Down Expand Up @@ -116,25 +118,28 @@ def discretize(control, tlist, args=(None,), kwargs=None, via_midpoints=False):
kwargs=kwargs,
via_midpoints=False,
)
return pulse_onto_tlist(pulse_on_midpoints)
result = pulse_onto_tlist(pulse_on_midpoints)
else:
# relies on np.ComplexWarning being thrown as an error
return np.array(
result = np.array(
[float(control(t, *args, **kwargs)) for t in tlist],
dtype=np.float64,
)
elif isinstance(control, (np.ndarray, list)):
# relies on np.ComplexWarning being thrown as an error
control = np.array([float(v) for v in control], dtype=np.float64)
if len(control) != len(tlist):
result = np.array([float(v) for v in control], dtype=np.float64)
if len(result) != len(tlist):
raise ValueError(
"If control is an array, it must of the same length as tlist"
)
return control
else:
raise TypeError(
"control must be either a callable func(t, args) or a numpy array"
)
if hasattr(control, 'parametrization'):
return ParametrizedArray(result, control.parametrization)
else:
return result


def extract_controls(objectives):
Expand Down Expand Up @@ -354,6 +359,8 @@ def control_onto_interval(control):
if isinstance(control, np.ndarray):
assert len(control.shape) == 1 # must be 1D array
pulse = np.zeros(len(control) - 1, dtype=control.dtype.type)
if hasattr(control, 'parametrization'):
pulse = ParametrizedArray(pulse, control.parametrization)
pulse[0] = control[0]
for i in range(1, len(control) - 1):
pulse[i] = 2.0 * control[i] - pulse[i - 1]
Expand Down Expand Up @@ -383,6 +390,8 @@ def pulse_onto_tlist(pulse):
of the input values before and after the point.
"""
control = np.zeros(len(pulse) + 1, dtype=pulse.dtype.type)
if hasattr(pulse, 'parametrization'):
control = ParametrizedArray(control, pulse.parametrization)
control[0] = pulse[0]
for i in range(1, len(control) - 1):
control[i] = 0.5 * (pulse[i - 1] + pulse[i])
Expand Down
3 changes: 3 additions & 0 deletions src/krotov/mu.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,4 +137,7 @@ def derivative_wrt_pulse(
raise NotImplementedError(
"Time-dependent collapse operators not implemented"
)
if hasattr(pulses[i_pulse], 'parametrization'):
ϵ = pulses[i_pulse][time_index]
mu *= pulses[i_pulse].parametrization.derivative(ϵ)
return mu
21 changes: 16 additions & 5 deletions src/krotov/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from .info_hooks import chain
from .mu import derivative_wrt_pulse
from .parallelization import USE_THREADPOOL_LIMITS
from .parametrization import ParametrizedArray
from .propagators import Propagator, expm
from .result import Result
from .second_order import _overlap
Expand Down Expand Up @@ -441,7 +442,7 @@ def optimize_pulses(
if second_order:
for i_obj in range(len(objectives)):
forward_states[i_obj][0] = objectives[i_obj].initial_state
delta_eps = [
delta_u = [
np.zeros(len(tlist) - 1, dtype=np.complex128) for _ in guess_pulses
]
optimized_pulses = copy.deepcopy(guess_pulses)
Expand All @@ -467,12 +468,12 @@ def optimize_pulses(
update *= chi_norms[i_obj]
if second_order:
update += 0.5 * σ * overlap(delta_phis[i_obj], μ(Ψ))
delta_eps[i_pulse][time_index] += update
delta_u[i_pulse][time_index] += update
λₐ = lambda_vals[i_pulse]
S_t = shape_arrays[i_pulse][time_index]
Δϵ = (S_t / λₐ) * delta_eps[i_pulse][time_index].imag # ∈ ℝ
g_a_integrals[i_pulse] += abs(Δϵ) ** 2 * dt # dt may vary!
optimized_pulses[i_pulse][time_index] += Δϵ
Δu = (S_t / λₐ) * delta_u[i_pulse][time_index].imag # ∈ ℝ
g_a_integrals[i_pulse] += abs(Δu) ** 2 * dt # dt may vary!
_add_update(optimized_pulses[i_pulse], time_index, Δu)
# forward propagation
fw_states = parallel_map[2](
_forward_propagation_step,
Expand Down Expand Up @@ -884,6 +885,16 @@ def _backward_propagation(
return storage_array


def _add_update(pulse, time_index, Δu):
if isinstance(pulse, ParametrizedArray):
ϵ = pulse[time_index]
u = pulse.parametrization.parametrize(ϵ)
pulse[time_index] = pulse.parametrization.unparametrize(u + Δu)
else:
# ϵ = u ⇒ Δϵ = Δu
pulse[time_index] += Δu


def _forward_propagation_step(
i_state,
states,
Expand Down
118 changes: 118 additions & 0 deletions src/krotov/parametrization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
r"""Classes to realized parametrized optimization pulses."""
import sys
import warnings
from abc import ABCMeta, abstractmethod

import numpy as np


class ParametrizedFunction(metaclass=ABCMeta):
"""Wrapped function, adding a `parametrization` attribute."""

def __init__(self, func, parametrization):
self._func = func
self.parametrization = parametrization

def __call__(self, t, args):
return self._func(t, args)


class ParametrizedArray(np.ndarray):
"""Wrapped numpy array, adding a `parametrization` attribute."""

# See https://numpy.org/doc/stable/user/basics.subclassing.html
def __new__(cls, input_array, parametrization):
obj = np.asarray(input_array).view(cls)
obj.parametrization = parametrization
if not isinstance(obj.parametrization, Parametrization):
raise ValueError(
"parametrization must be a Parametrization instance, not %r"
% type(parametrization)
)
return obj

def __array_finalize__(self, obj):
if obj is None:
return
self.parametrization = getattr(obj, 'parametrization', None)


class Parametrization(metaclass=ABCMeta):
"""Abstract base class for a parametrizations."""

@abstractmethod
def parametrize(self, eps_val):
return NotImplementedError

@abstractmethod
def unparametrize(self, u_val):
return NotImplementedError

@abstractmethod
def derivative(self):
return NotImplementedError


class TanhParametrization(Parametrization):
def __init__(self, *, eps_max, eps_min):
self.eps_max = eps_max
self.eps_min = eps_min

def parametrize(self, eps_val):
ϵ_max = self.eps_max
ϵ_min = self.eps_min
ϵ = eps_val
if ϵ >= ϵ_max or ϵ <= ϵ_min:
warnings.warn(
"Pulse value %r out of range (%r, %r) for %s. "
"Value will be clipped."
% (ϵ, ϵ_min, ϵ_max, self.__class__.__name__)
)
Δ = ϵ_max - ϵ_min
a = np.clip(
2 * ϵ / Δ - (ϵ_max + ϵ_min) / Δ,
-1 + sys.float_info.epsilon,
1 - sys.float_info.epsilon,
)
u = np.arctanh(a) # -18.4 < u < 18.4
return u

def unparametrize(self, u_val):
ϵ_max = self.eps_max
ϵ_min = self.eps_min
u = u_val
cp = 0.5 * (ϵ_max + ϵ_min)
cm = 0.5 * (ϵ_max - ϵ_min)
ϵ = cm * np.tanh(u) + cp
return ϵ

def derivative(self, eps_val):
ϵ_max = self.eps_max
ϵ_min = self.eps_min
ϵ = eps_val
Δ = ϵ_max - ϵ_min
a = np.clip(
2 * ϵ / Δ - (ϵ_max + ϵ_min) / Δ,
-1 + sys.float_info.epsilon,
1 - sys.float_info.epsilon,
)
u = np.arctanh(a)
return 0.5 * Δ / np.cosh(u) ** 2


class SquareParametrization(Parametrization):
def parametrize(self, eps_val):
if eps_val < 0:
warnings.warn(
"Pulse value %r < 0 out of range for %s. Clip to 0."
% (eps_val, self.__class__.__name__)
)
eps_val = 0
return np.sqrt(eps_val)

def unparametrize(self, u_val):
return u_val ** 2

def derivative(self, eps_val):
u = self.parametrize(eps_val)
return 2 * u

0 comments on commit 809d94e

Please sign in to comment.