Skip to content

Commit

Permalink
fix codes about TestPulseToSignalsJAXTransformations
Browse files Browse the repository at this point in the history
  • Loading branch information
to24toro committed Feb 14, 2023
1 parent 454f119 commit 2814cf0
Showing 1 changed file with 12 additions and 20 deletions.
32 changes: 12 additions & 20 deletions test/dynamics/pulse/test_pulse_to_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,19 +39,10 @@
from qiskit import QiskitError

from qiskit_dynamics.pulse import InstructionToSignals
from qiskit_dynamics.pulse.pulse_to_signals import (
get_samples,
)
from qiskit_dynamics.signals import DiscreteSignal

from ..common import QiskitDynamicsTestCase, TestJaxBase

try:
import jax
# pylint: disable=broad-except
except Exception:
pass


class TestPulseToSignals(QiskitDynamicsTestCase):
"""Tests the conversion between pulse schedules and signals."""
Expand Down Expand Up @@ -339,22 +330,19 @@ def test_InstructionToSignals(self):


class TestPulseToSignalsJAXTransformations(QiskitDynamicsTestCase, TestJaxBase):
"""Tests get_samples function by using Jax."""
"""Tests InstructionToSignals class by using Jax."""

def setUp(self):
"""Set up gaussian waveform samples for comparison."""
self.gauss_get_waveform_samples = (
pulse.Gaussian(duration=5, amp=0.983, sigma=2.0).get_waveform().samples
)
self.constant_get_waveform_samples = (
pulse.Constant(duration=5, amp=0.1).get_waveform().samples
)
self._dt = 0.222

def test_jit_get_samples(self):
"""Test compiling to get samples of Pulse."""
def test_InstructionToSignals_with_JAX(self):
"""Test InstructionToSignals with JAX jit."""

def jit_func_get_samples(amp):
def jit_func_instruction_to_signals(amp):
parameters = {"amp": amp}
_time, _amp, _duration = sym.symbols("t, amp, duration")
envelope_expr = _amp * sym.Piecewise(
Expand All @@ -370,11 +358,15 @@ def jit_func_get_samples(amp):
envelope=envelope_expr,
valid_amp_conditions=valid_amp_conditions_expr,
)
return get_samples(instance)
with pulse.build() as schedule:
pulse.play(instance, pulse.DriveChannel(0))

converter = InstructionToSignals(self._dt, carriers={"d0": 5})
return converter.get_signals(schedule)[0].samples

self.jit_wrap(jit_func_get_samples)(0.1)
self.jit_grad_wrap(jit_func_get_samples)(0.1)
jit_samples = jax.jit(jit_func_get_samples)(0.1)
self.jit_wrap(jit_func_instruction_to_signals)(0.1)
self.jit_grad_wrap(jit_func_instruction_to_signals)(0.1)
jit_samples = self.jit_wrap(jit_func_instruction_to_signals)(0.1)
self.assertAllClose(jit_samples, self.constant_get_waveform_samples, atol=1e-7, rtol=1e-7)

def test_pulse_types_combination_with_jax(self):
Expand Down

0 comments on commit 2814cf0

Please sign in to comment.