-
Notifications
You must be signed in to change notification settings - Fork 60
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add get samples function to InstructionToSignals for JAX-jit usage #149
Changes from 18 commits
b5e2fad
7e2c039
9c87d08
8adbc79
6ad5815
c309427
5713e13
9a3acaa
3c6a78a
ccb9bdc
42d4b8c
a1637e0
ef502d3
a73d315
32ad385
8839c78
bf66412
9bc4972
39a1e90
aa10fbe
454f119
2814cf0
4044ebd
f8b055e
c3338c5
5a5d2ac
c56c9ca
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -81,5 +81,4 @@ | |
|
||
InstructionToSignals | ||
""" | ||
|
||
from .pulse_to_signals import InstructionToSignals |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,8 +14,11 @@ | |
Pulse schedule to Signals converter. | ||
""" | ||
|
||
from typing import Dict, List, Optional | ||
from typing import Callable, Dict, List, Optional | ||
import functools | ||
|
||
import numpy as np | ||
import sympy as sym | ||
|
||
from qiskit.pulse import ( | ||
Schedule, | ||
|
@@ -30,8 +33,11 @@ | |
ControlChannel, | ||
AcquireChannel, | ||
) | ||
from qiskit.pulse.exceptions import PulseError | ||
from qiskit.pulse.library import SymbolicPulse | ||
from qiskit import QiskitError | ||
|
||
from qiskit_dynamics.array import Array | ||
from qiskit_dynamics.signals import DiscreteSignal | ||
|
||
|
||
|
@@ -40,19 +46,17 @@ class InstructionToSignals: | |
|
||
The :class:`InstructionsToSignals` class converts a pulse schedule to a list of signals that can | ||
be given to a model. This conversion is done by calling the :meth:`get_signals` method on a | ||
schedule. The converter applies to instances of :class:`~qiskit.pulse.Schedule`. Instances of | ||
:class:`~qiskit.pulse.ScheduleBlock` must first be converted to :class:`~qiskit.pulse.Schedule` | ||
using the :func:`~qiskit.pulse.transforms.block_to_schedule` function in Qiskit Pulse. | ||
|
||
The converter can be initialized with the optional arguments ``carriers`` and ``channels``. When | ||
``channels`` is given, only the signals specified by name in ``channels`` are returned. The | ||
``carriers`` dictionary specifies the analog carrier frequency of each channel. Here, the keys | ||
are the channel name, e.g. ``d12`` for drive channel number ``12``, and the values are the | ||
corresponding frequency. If a channel is not present in ``carriers`` it is assumed that the | ||
analog carrier frequency is zero. | ||
|
||
See the :meth:`get_signals` method documentation for a detailed description of how pulse | ||
schedules are interpreted and translated into :class:`.DiscreteSignal` objects. | ||
schedule. The converter applies to instances of :class:`Schedule`. Instances of | ||
:class:`ScheduleBlock` must first be converted to :class:`Schedule` using the | ||
:meth:`block_to_schedule` in Qiskit pulse. | ||
|
||
The converter can be initialized with the optional arguments ``carriers`` and ``channels``. | ||
These arguments change the returned signals of :meth:`get_signals`. When ``channels`` is given | ||
then only the signals specified by name in ``channels`` are returned. The ``carriers`` | ||
dictionary allows the user to specify the carrier frequency of the channels. Here, the keys are | ||
the channel name, e.g. ``d12`` for drive channel number 12, and the values are the corresponding | ||
frequency. If a channel is not present in ``carriers`` it is assumed that the carrier frequency | ||
is zero. | ||
""" | ||
|
||
def __init__( | ||
|
@@ -64,14 +68,15 @@ def __init__( | |
"""Initialize pulse schedule to signals converter. | ||
|
||
Args: | ||
dt: Length of the samples. This is required by the converter as pulse schedule are | ||
specified in units of dt and typically do not carry the value of dt with them. | ||
carriers: A dict of analog carrier frequencies. The keys are the names of the channels | ||
dt: Length of the samples. This is required by the converter as pulse | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These two code blocks are similarly changing some documentation I had changed previously. Maybe some issue with a previous attempt at merging |
||
schedule are specified in units of dt and typically do not carry the value of dt | ||
with them. | ||
carriers: A dict of carrier frequencies. The keys are the names of the channels | ||
and the values are the corresponding carrier frequency. | ||
channels: A list of channels that the :meth:`get_signals` method should return. This | ||
argument will cause :meth:`get_signals` to return the signals in the same order as | ||
the channels. Channels present in the schedule but absent from channels will not be | ||
included in the returned object. If None is given (the default) then all channels | ||
channels: A list of channels that the :meth:`get_signals` method should return. | ||
This argument will cause :meth:`get_signals` to return the signals in the same order | ||
as the channels. Channels present in the schedule but absent from channels will not | ||
be included in the returned object. If None is given (the default) then all channels | ||
present in the pulse schedule are returned. | ||
""" | ||
|
||
|
@@ -129,9 +134,8 @@ def get_signals(self, schedule: Schedule) -> List[DiscreteSignal]: | |
|
||
Args: | ||
schedule: The schedule to represent in terms of signals. Instances of | ||
:class:`~qiskit.pulse.ScheduleBlock` must first be converted to | ||
:class:`~qiskit.pulse.Schedule` using the | ||
:func:`~qiskit.pulse.transforms.block_to_schedule` function in Qiskit Pulse. | ||
:class:`ScheduleBlock` must first be converted to :class:`Schedule` using the | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similarly here, this is undoing some changes I had made. |
||
:meth:`block_to_schedule` in Qiskit pulse. | ||
|
||
Returns: | ||
A list of :class:`.DiscreteSignal` instances. | ||
|
@@ -167,29 +171,33 @@ def get_signals(self, schedule: Schedule) -> List[DiscreteSignal]: | |
if isinstance(inst.pulse, Waveform): | ||
inst_samples = inst.pulse.samples | ||
else: | ||
DanPuzzuoli marked this conversation as resolved.
Show resolved
Hide resolved
|
||
inst_samples = inst.pulse.get_waveform().samples | ||
inst_samples = get_samples(inst.pulse) | ||
|
||
# build sample array to append to signal | ||
times = self._dt * (start_sample + np.arange(len(inst_samples))) | ||
samples = inst_samples * np.exp( | ||
2.0j * np.pi * freq * times | ||
+ 1.0j * phi | ||
+ 2.0j * np.pi * phase_accumulations[chan] | ||
Array( | ||
2.0j * np.pi * freq * times | ||
+ 1.0j * phi | ||
+ 2.0j * np.pi * phase_accumulations[chan] | ||
) | ||
) | ||
signals[chan].add_samples(start_sample, samples) | ||
|
||
if isinstance(inst, ShiftPhase): | ||
phases[chan] += inst.phase | ||
|
||
if isinstance(inst, ShiftFrequency): | ||
frequency_shifts[chan] += inst.frequency | ||
phase_accumulations[chan] -= inst.frequency * start_sample * self._dt | ||
frequency_shifts[chan] = frequency_shifts[chan] + Array(inst.frequency) | ||
phase_accumulations[chan] = ( | ||
phase_accumulations[chan] - inst.frequency * start_sample * self._dt | ||
) | ||
|
||
if isinstance(inst, SetPhase): | ||
phases[chan] = inst.phase | ||
|
||
if isinstance(inst, SetFrequency): | ||
phase_accumulations[chan] -= ( | ||
phase_accumulations[chan] = phase_accumulations[chan] - ( | ||
(inst.frequency - (frequency_shifts[chan] + signals[chan].carrier_freq)) | ||
* start_sample | ||
* self._dt | ||
|
@@ -302,3 +310,59 @@ def _get_channel(self, channel_name: str): | |
raise QiskitError( | ||
f"Invalid channel name {channel_name} given to {self.__class__.__name__}." | ||
) from error | ||
|
||
|
||
def get_samples(pulse: SymbolicPulse) -> np.ndarray: | ||
"""Return samples filled according to the formula that the pulse | ||
represents and the parameter values it contains. | ||
|
||
Args: | ||
pulse: SymbolicPulse class. | ||
Returns: | ||
Samples of the pulse. | ||
Raises: | ||
PulseError: When parameters are not assigned. | ||
PulseError: When expression for pulse envelope is not assigned. | ||
PulseError: When a free symbol value is not defined in the pulse instance parameters. | ||
""" | ||
envelope = pulse.envelope | ||
pulse_params = pulse.parameters | ||
if pulse.is_parameterized(): | ||
raise PulseError("Unassigned parameter exists. All parameters must be assigned.") | ||
|
||
if envelope is None: | ||
raise PulseError("Pulse envelope expression is not assigned.") | ||
|
||
args = [] | ||
for symbol in sorted(envelope.free_symbols, key=lambda s: s.name): | ||
if symbol.name == "t": | ||
times = Array(np.arange(0, pulse_params["duration"]) + 1 / 2) | ||
args.insert(0, times.data) | ||
continue | ||
try: | ||
args.append(pulse_params[symbol.name]) | ||
except KeyError as ex: | ||
raise PulseError( | ||
f"Pulse parameter '{symbol.name}' is not defined for this instance. " | ||
"Please check your waveform expression is correct." | ||
) from ex | ||
return _lru_cache_expr(envelope, Array.default_backend())(*args) | ||
|
||
|
||
@functools.lru_cache(maxsize=None) | ||
def _lru_cache_expr(expr: sym.Expr, backend) -> Callable: | ||
"""A helper function to get lambdified expression. | ||
|
||
Args: | ||
expr: Symbolic expression to evaluate. | ||
backend: Array backend. | ||
Returns: | ||
lambdified expression. | ||
""" | ||
params = [] | ||
for param in sorted(expr.free_symbols, key=lambda s: s.name): | ||
if param.name == "t": | ||
params.insert(0, param) | ||
continue | ||
params.append(param) | ||
return sym.lambdify(params, expr, modules=backend) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -54,7 +54,8 @@ | |
|
||
|
||
try: | ||
from jax import jit | ||
from jax import core, jit | ||
import jax.numpy as jnp | ||
except ImportError: | ||
pass | ||
|
||
|
@@ -523,6 +524,8 @@ def solve( | |
Array.default_backend() == "jax" | ||
and (method == "jax_odeint" or _is_diffrax_method(method)) | ||
and all(isinstance(x, Schedule) for x in signals_list) | ||
# check if jit transformation is already performed. | ||
and not (isinstance(jnp.array(0), core.Tracer)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is related with the issue #175 |
||
): | ||
all_results = self._solve_schedule_list_jax( | ||
t_span_list=t_span_list, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
--- | ||
features: | ||
- | | ||
The logic of :class:`.InstructionToSignals` has been updated. | ||
This change allows you to jit compile a function including | ||
the converter with input pulse schedule that contains :class:`.SymbolicPulse`. | ||
|
||
.. code-block:: python | ||
|
||
from functools import partial | ||
import jax | ||
from qiskit import pulse | ||
from qiskit_dynamics.pulse import InstructionToSignals | ||
|
||
@partial(jax.jit, static_argnums=(0, 1)) | ||
def run_simulation(amp, sigma): | ||
converter = InstructionToSignals(dt=1, carriers=None) | ||
with pulse.build() as schedule: | ||
pulse.play(pulse.Gaussian(160, amp, sigma), pulse.DriveChannel(0)) | ||
signals = converter.get_signals(schedule) | ||
|
||
# continue with simulations |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think something's happened here - this is undoing some changes I had made in a previous PR (that I don't think you are intentionally trying to change).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am so sorry to confuse you. As you say, it is not my intention. I restored to that of main.