From 2814cf0587ceabc6354cbf628560536bcad24577 Mon Sep 17 00:00:00 2001 From: to24toro Date: Wed, 15 Feb 2023 08:34:21 +0900 Subject: [PATCH] fix codes about TestPulseToSignalsJAXTransformations --- test/dynamics/pulse/test_pulse_to_signals.py | 32 ++++++++------------ 1 file changed, 12 insertions(+), 20 deletions(-) diff --git a/test/dynamics/pulse/test_pulse_to_signals.py b/test/dynamics/pulse/test_pulse_to_signals.py index 77a31823a..03df1fc47 100644 --- a/test/dynamics/pulse/test_pulse_to_signals.py +++ b/test/dynamics/pulse/test_pulse_to_signals.py @@ -39,19 +39,10 @@ from qiskit import QiskitError from qiskit_dynamics.pulse import InstructionToSignals -from qiskit_dynamics.pulse.pulse_to_signals import ( - get_samples, -) from qiskit_dynamics.signals import DiscreteSignal from ..common import QiskitDynamicsTestCase, TestJaxBase -try: - import jax -# pylint: disable=broad-except -except Exception: - pass - class TestPulseToSignals(QiskitDynamicsTestCase): """Tests the conversion between pulse schedules and signals.""" @@ -339,22 +330,19 @@ def test_InstructionToSignals(self): class TestPulseToSignalsJAXTransformations(QiskitDynamicsTestCase, TestJaxBase): - """Tests get_samples function by using Jax.""" + """Tests InstructionToSignals class by using Jax.""" def setUp(self): """Set up gaussian waveform samples for comparison.""" - self.gauss_get_waveform_samples = ( - pulse.Gaussian(duration=5, amp=0.983, sigma=2.0).get_waveform().samples - ) self.constant_get_waveform_samples = ( pulse.Constant(duration=5, amp=0.1).get_waveform().samples ) self._dt = 0.222 - def test_jit_get_samples(self): - """Test compiling to get samples of Pulse.""" + def test_InstructionToSignals_with_JAX(self): + """Test InstructionToSignals with JAX jit.""" - def jit_func_get_samples(amp): + def jit_func_instruction_to_signals(amp): parameters = {"amp": amp} _time, _amp, _duration = sym.symbols("t, amp, duration") envelope_expr = _amp * sym.Piecewise( @@ -370,11 +358,15 @@ def jit_func_get_samples(amp): envelope=envelope_expr, valid_amp_conditions=valid_amp_conditions_expr, ) - return get_samples(instance) + with pulse.build() as schedule: + pulse.play(instance, pulse.DriveChannel(0)) + + converter = InstructionToSignals(self._dt, carriers={"d0": 5}) + return converter.get_signals(schedule)[0].samples - self.jit_wrap(jit_func_get_samples)(0.1) - self.jit_grad_wrap(jit_func_get_samples)(0.1) - jit_samples = jax.jit(jit_func_get_samples)(0.1) + self.jit_wrap(jit_func_instruction_to_signals)(0.1) + self.jit_grad_wrap(jit_func_instruction_to_signals)(0.1) + jit_samples = self.jit_wrap(jit_func_instruction_to_signals)(0.1) self.assertAllClose(jit_samples, self.constant_get_waveform_samples, atol=1e-7, rtol=1e-7) def test_pulse_types_combination_with_jax(self):