Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add get samples function to InstructionToSignals for JAX-jit usage #149

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
b5e2fad
add sampling pulse for jax jit
to24toro Nov 9, 2022
7e2c039
add a release note
to24toro Nov 11, 2022
9c87d08
comment change of tolerance
to24toro Nov 11, 2022
8adbc79
change test_jit_get_samples
to24toro Nov 25, 2022
6ad5815
add jax_grad test
to24toro Dec 2, 2022
c309427
Fix pulse to signal converter (#164)
to24toro Dec 15, 2022
5713e13
wrap signal samples as Array
to24toro Jan 13, 2023
9a3acaa
format
to24toro Jan 13, 2023
3c6a78a
modify test_pulse_to_signals
to24toro Jan 13, 2023
ccb9bdc
add test_jit_solve_with_internal_jit
to24toro Jan 13, 2023
42d4b8c
impl jax-jit for pulse simulation
to24toro Jan 13, 2023
a1637e0
format by black
to24toro Feb 1, 2023
ef502d3
add type hint
to24toro Feb 2, 2023
a73d315
add test_SymbolicPulse
to24toro Feb 2, 2023
32ad385
modify test_SymbolicPulse
to24toro Feb 2, 2023
8839c78
modify test_SymbolicPulse
to24toro Feb 2, 2023
bf66412
add test_pulse_types_combination_with_jax
to24toro Feb 2, 2023
9bc4972
remove redefined fucntion
to24toro Feb 2, 2023
39a1e90
restore previous comments
to24toro Feb 8, 2023
aa10fbe
add removed sentence
to24toro Feb 8, 2023
454f119
rename class name from TestJaxGetSamples to TestPulseToSignalsJAXTran…
to24toro Feb 14, 2023
2814cf0
fix codes about TestPulseToSignalsJAXTransformations
to24toro Feb 14, 2023
4044ebd
Merge branch 'main' into jax_implementation_to_pulse
to24toro Feb 14, 2023
f8b055e
modify importing and naming of pulse function
to24toro Feb 15, 2023
c3338c5
insert spaces in InstructionToSignals docs
to24toro Feb 15, 2023
5a5d2ac
Merge branch 'main' into jax_implementation_to_pulse
DanPuzzuoli Feb 15, 2023
c56c9ca
Merge branch 'main' into jax_implementation_to_pulse
DanPuzzuoli Feb 15, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion qiskit_dynamics/pulse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,5 +81,4 @@

InstructionToSignals
"""

from .pulse_to_signals import InstructionToSignals
82 changes: 74 additions & 8 deletions qiskit_dynamics/pulse/pulse_to_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@
Pulse schedule to Signals converter.
"""

from typing import Dict, List, Optional
from typing import Callable, Dict, List, Optional
import functools

import numpy as np
import sympy as sym

from qiskit.pulse import (
Schedule,
Expand All @@ -30,8 +33,11 @@
ControlChannel,
AcquireChannel,
)
from qiskit.pulse.exceptions import PulseError
from qiskit.pulse.library import SymbolicPulse
from qiskit import QiskitError

from qiskit_dynamics.array import Array
from qiskit_dynamics.signals import DiscreteSignal


Expand Down Expand Up @@ -166,29 +172,33 @@ def get_signals(self, schedule: Schedule) -> List[DiscreteSignal]:
if isinstance(inst.pulse, Waveform):
inst_samples = inst.pulse.samples
else:
DanPuzzuoli marked this conversation as resolved.
Show resolved Hide resolved
inst_samples = inst.pulse.get_waveform().samples
inst_samples = get_samples(inst.pulse)

# build sample array to append to signal
times = self._dt * (start_sample + np.arange(len(inst_samples)))
samples = inst_samples * np.exp(
2.0j * np.pi * freq * times
+ 1.0j * phi
+ 2.0j * np.pi * phase_accumulations[chan]
Array(
2.0j * np.pi * freq * times
+ 1.0j * phi
+ 2.0j * np.pi * phase_accumulations[chan]
)
)
signals[chan].add_samples(start_sample, samples)

if isinstance(inst, ShiftPhase):
phases[chan] += inst.phase

if isinstance(inst, ShiftFrequency):
frequency_shifts[chan] += inst.frequency
phase_accumulations[chan] -= inst.frequency * start_sample * self._dt
frequency_shifts[chan] = frequency_shifts[chan] + Array(inst.frequency)
phase_accumulations[chan] = (
phase_accumulations[chan] - inst.frequency * start_sample * self._dt
)

if isinstance(inst, SetPhase):
phases[chan] = inst.phase

if isinstance(inst, SetFrequency):
phase_accumulations[chan] -= (
phase_accumulations[chan] = phase_accumulations[chan] - (
(inst.frequency - (frequency_shifts[chan] + signals[chan].carrier_freq))
* start_sample
* self._dt
Expand Down Expand Up @@ -301,3 +311,59 @@ def _get_channel(self, channel_name: str):
raise QiskitError(
f"Invalid channel name {channel_name} given to {self.__class__.__name__}."
) from error


def get_samples(pulse: SymbolicPulse) -> np.ndarray:
"""Return samples filled according to the formula that the pulse
represents and the parameter values it contains.

Args:
pulse: SymbolicPulse class.
Returns:
Samples of the pulse.
Raises:
PulseError: When parameters are not assigned.
PulseError: When expression for pulse envelope is not assigned.
PulseError: When a free symbol value is not defined in the pulse instance parameters.
"""
envelope = pulse.envelope
pulse_params = pulse.parameters
if pulse.is_parameterized():
raise PulseError("Unassigned parameter exists. All parameters must be assigned.")

if envelope is None:
raise PulseError("Pulse envelope expression is not assigned.")

args = []
for symbol in sorted(envelope.free_symbols, key=lambda s: s.name):
if symbol.name == "t":
times = Array(np.arange(0, pulse_params["duration"]) + 1 / 2)
args.insert(0, times.data)
continue
try:
args.append(pulse_params[symbol.name])
except KeyError as ex:
raise PulseError(
f"Pulse parameter '{symbol.name}' is not defined for this instance. "
"Please check your waveform expression is correct."
) from ex
return _lru_cache_expr(envelope, Array.default_backend())(*args)


@functools.lru_cache(maxsize=None)
def _lru_cache_expr(expr: sym.Expr, backend) -> Callable:
"""A helper function to get lambdified expression.

Args:
expr: Symbolic expression to evaluate.
backend: Array backend.
Returns:
lambdified expression.
"""
params = []
for param in sorted(expr.free_symbols, key=lambda s: s.name):
if param.name == "t":
params.insert(0, param)
continue
params.append(param)
return sym.lambdify(params, expr, modules=backend)
5 changes: 4 additions & 1 deletion qiskit_dynamics/solvers/solver_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@


try:
from jax import jit
from jax import core, jit
import jax.numpy as jnp
except ImportError:
pass

Expand Down Expand Up @@ -522,6 +523,8 @@ def solve(
Array.default_backend() == "jax"
and (method == "jax_odeint" or _is_diffrax_method(method))
and all(isinstance(x, Schedule) for x in signals_list)
# check if jit transformation is already performed.
and not (isinstance(jnp.array(0), core.Tracer))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is related with the issue #175

):
all_results = self._solve_schedule_list_jax(
t_span_list=t_span_list,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
---
features:
- |
The logic of :class:`.InstructionToSignals` has been updated.
This change allows you to jit compile a function including
the converter with input pulse schedule that contains :class:`.SymbolicPulse`.

.. code-block:: python

from functools import partial
import jax
from qiskit import pulse
from qiskit_dynamics.pulse import InstructionToSignals

@partial(jax.jit, static_argnums=(0, 1))
def run_simulation(amp, sigma):
converter = InstructionToSignals(dt=1, carriers=None)
with pulse.build() as schedule:
pulse.play(pulse.Gaussian(160, amp, sigma), pulse.DriveChannel(0))
signals = converter.get_signals(schedule)

# continue with simulations
Loading