Skip to content

Commit

Permalink
add test_pulse_types_combination_with_jax
Browse files Browse the repository at this point in the history
  • Loading branch information
to24toro committed Feb 2, 2023
1 parent d892d26 commit 03592c1
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 10 deletions.
6 changes: 3 additions & 3 deletions qiskit_dynamics/pulse/pulse_to_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
Pulse schedule to Signals converter.
"""

from typing import Callable ,Dict, List, Optional
from typing import Callable, Dict, List, Optional
import functools

import numpy as np
Expand Down Expand Up @@ -266,7 +266,7 @@ def _get_channel(self, channel_name: str):
) from error


def get_samples(pulse: SymbolicPulse)-> np.ndarray:
def get_samples(pulse: SymbolicPulse) -> np.ndarray:
"""Return samples filled according to the formula that the pulse
represents and the parameter values it contains.
Expand Down Expand Up @@ -304,7 +304,7 @@ def get_samples(pulse: SymbolicPulse)-> np.ndarray:


@functools.lru_cache(maxsize=None)
def _lru_cache_expr(expr: sym.Expr, backend)-> Callable:
def _lru_cache_expr(expr: sym.Expr, backend) -> Callable:
"""A helper function to get lambdified expression.
Args:
Expand Down
47 changes: 40 additions & 7 deletions test/dynamics/pulse/test_pulse_to_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,16 @@
from qiskit.pulse.transforms.canonicalization import block_to_schedule
from qiskit import QiskitError

from qiskit_dynamics.array import Array
from qiskit_dynamics.pulse import InstructionToSignals
from qiskit_dynamics.pulse.pulse_to_signals import (
get_samples,
_lru_cache_expr,
)
from qiskit_dynamics.signals import DiscreteSignal

from ..common import QiskitDynamicsTestCase, TestJaxBase

try:
import jax
import jax.numpy as jnp
# pylint: disable=broad-except
except Exception:
pass
Expand Down Expand Up @@ -352,11 +349,12 @@ def setUp(self):
self.constant_get_waveform_samples = (
pulse.Constant(duration=5, amp=0.1).get_waveform().samples
)
self._dt = 0.222

def test_jit_get_samples(self):
"""Test compiling to get samples of Pulse."""

def jit_func(amp):
def jit_func_get_samples(amp):
parameters = {"amp": amp}
_time, _amp, _duration = sym.symbols("t, amp, duration")
envelope_expr = _amp * sym.Piecewise(
Expand All @@ -374,11 +372,46 @@ def jit_func(amp):
)
return get_samples(instance)

self.jit_wrap(jit_func)(0.1)
self.jit_grad_wrap(jit_func)(0.1)
jit_samples = jax.jit(jit_func)(0.1)
self.jit_wrap(jit_func_get_samples)(0.1)
self.jit_grad_wrap(jit_func_get_samples)(0.1)
jit_samples = jax.jit(jit_func_get_samples)(0.1)
self.assertAllClose(jit_samples, self.constant_get_waveform_samples, atol=1e-7, rtol=1e-7)

def test_pulse_types_combination_with_jax(self):
"""Test that converting schedule including some pulse types with Jax works well"""

def jit_func_symbolic_pulse(amp):
parameters = {"amp": amp}
_time, _amp, _duration = sym.symbols("t, amp, duration")
envelope_expr = _amp * sym.Piecewise(
(1, sym.And(_time >= 0, _time <= _duration)), (0, True)
)
valid_amp_conditions_expr = sym.Abs(_amp) <= 1.0
instance = SymbolicPulse(
pulse_type="Constant",
duration=5,
parameters=parameters,
envelope=envelope_expr,
valid_amp_conditions=valid_amp_conditions_expr,
)
# constrcut a pulse schedule with mixing some pulse types to test jax-jitting it
with pulse.build() as schedule:
pulse.play(instance, pulse.DriveChannel(0))
pulse.set_phase(0.1, pulse.DriveChannel(0))
pulse.set_frequency(0.1, pulse.DriveChannel(0))
pulse.shift_phase(0.1, pulse.DriveChannel(0))
pulse.set_phase(0.1, pulse.DriveChannel(0))
pulse.shift_frequency(0.1, pulse.DriveChannel(0))
pulse.shift_frequency(0.1, pulse.DriveChannel(0))
pulse.set_frequency(0.1, pulse.DriveChannel(0))
pulse.shift_phase(0.1, pulse.DriveChannel(0))
pulse.play(instance, pulse.DriveChannel(0))
converter = InstructionToSignals(self._dt, carriers={"d0": 5})
return converter.get_signals(schedule)[0].samples

self.jit_wrap(jit_func_symbolic_pulse)(0.1)
self.jit_grad_wrap(jit_func_symbolic_pulse)(0.1)


@ddt
class TestPulseToSignalsFiltering(QiskitDynamicsTestCase):
Expand Down

0 comments on commit 03592c1

Please sign in to comment.