Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove Array class from pulse_to_signals #327

Merged
merged 10 commits into from
Feb 23, 2024
2 changes: 1 addition & 1 deletion .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: '3.8'
python-version: '3.10'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
- uses: actions/setup-python@v4
name: Install Python
with:
python-version: '3.8'
python-version: '3.10'
- name: Install Deps
run: pip install -U wheel
- name: Build Artifacts
Expand Down
3 changes: 3 additions & 0 deletions docs/tutorials/optimizing_pulse_sequence.rst
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,9 @@ entry on :ref:`JAX-compatible pulse schedules <how-to use pulse schedules for ja
)
)

# we need to set disable_validation True to enable jax-jitting.
pulse.ScalableSymbolicPulse.disable_validation = True

return pulse.ScalableSymbolicPulse(
pulse_type="GaussianSquare",
duration=230,
Expand Down
3 changes: 3 additions & 0 deletions docs/userguide/how_to_use_pulse_schedule_for_jax_jit.rst
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@ JAX-compiled (or more generally, JAX-transformed).
_amp * sym.exp(sym.I * _angle) * lifted_gaussian(_t, _center, _duration + 1, _sigma)
)

# we need to set disable_validation True to enable jax-jitting.
pulse.ScalableSymbolicPulse.disable_validation = True

gaussian_pulse = pulse.ScalableSymbolicPulse(
pulse_type="Gaussian",
duration=160,
Expand Down
12 changes: 10 additions & 2 deletions qiskit_dynamics/pulse/pulse_to_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
from qiskit.pulse.library import SymbolicPulse
from qiskit import QiskitError

from qiskit_dynamics.array import Array
from qiskit_dynamics import DYNAMICS_NUMPY as unp
from qiskit_dynamics import ArrayLike

Expand Down Expand Up @@ -349,6 +348,12 @@ def get_samples(pulse: SymbolicPulse) -> ArrayLike:
raise PulseError("Pulse envelope expression is not assigned.")

args = []
try:
backend = (
"jax" if any(isinstance(v, jax.core.Tracer) for v in pulse_params.values()) else "numpy"
)
except (ImportError, NameError):
backend = "numpy"
for symbol in sorted(envelope.free_symbols, key=lambda s: s.name):
if symbol.name == "t":
times = unp.arange(0, pulse_params["duration"]) + 1 / 2
Expand All @@ -361,7 +366,10 @@ def get_samples(pulse: SymbolicPulse) -> ArrayLike:
f"Pulse parameter '{symbol.name}' is not defined for this instance. "
"Please check your waveform expression is correct."
) from ex
return _lru_cache_expr(envelope, Array.default_backend())(*args)
return _lru_cache_expr(
envelope,
backend,
)(*args)


@functools.lru_cache(maxsize=None)
Expand Down
44 changes: 35 additions & 9 deletions test/dynamics/pulse/test_pulse_to_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,15 @@

from qiskit_ibm_runtime.fake_provider import FakeQuito

try:
from jax import jit
except ImportError:
pass

from qiskit_dynamics.pulse import InstructionToSignals
from qiskit_dynamics.signals import DiscreteSignal

from ..common import QiskitDynamicsTestCase, TestJaxBase
from ..common import QiskitDynamicsTestCase, JAXTestBase


class TestPulseToSignals(QiskitDynamicsTestCase):
Expand Down Expand Up @@ -358,14 +363,17 @@ def test_barrier_instructions(self):
self.assertAllClose(sigs[1].samples, np.array([0.0, 0.0, 0.0, -0.5, -0.5, -0.5]))


class TestPulseToSignalsJAXTransformations(QiskitDynamicsTestCase, TestJaxBase):
class TestPulseToSignalsJAXTransformations(JAXTestBase):
"""Tests InstructionToSignals class by using Jax."""

def setUp(self):
"""Set up gaussian waveform samples for comparison."""
self.constant_get_waveform_samples = (
pulse.Constant(duration=5, amp=0.1).get_waveform().samples
)
self.gaussian_get_waveform_samples = (
pulse.Gaussian(duration=5, amp=0.983, sigma=2.0).get_waveform().samples
)
self._dt = 0.222

def test_InstructionToSignals_with_JAX(self):
Expand All @@ -378,8 +386,8 @@ def jit_func_instruction_to_signals(amp):
(1, sym.And(_time >= 0, _time <= _duration)), (0, True)
)
valid_amp_conditions_expr = sym.Abs(_amp) <= 1.0
# we can use only SymbolicPulse when jax-jitting
# bacause jax-jitting doesn't correspond to validate_parameters in qiskit.pulse.
# we need to set disable_validation True to enable jax-jitting.
pulse.SymbolicPulse.disable_validation = True
instance = pulse.SymbolicPulse(
pulse_type="Constant",
duration=5,
Expand All @@ -393,11 +401,27 @@ def jit_func_instruction_to_signals(amp):
converter = InstructionToSignals(self._dt, carriers={"d0": 5})
return converter.get_signals(schedule)[0].samples

self.jit_wrap(jit_func_instruction_to_signals)(0.1)
self.jit_grad_wrap(jit_func_instruction_to_signals)(0.1)
jit_samples = self.jit_wrap(jit_func_instruction_to_signals)(0.1)
def jit_func_gaussian_to_signals(amp):
pulse.Gaussian.disable_validation = True
instance = pulse.Gaussian(duration=5, amp=amp, sigma=2.0)
with pulse.build() as schedule:
pulse.play(instance, pulse.DriveChannel(0))

converter = InstructionToSignals(self._dt, carriers={"d0": 5})
return converter.get_signals(schedule)[0].samples

jit(jit_func_instruction_to_signals)(0.1)
self.jit_grad(jit_func_instruction_to_signals)(0.1)
jit_samples = jit(jit_func_instruction_to_signals)(0.1)
self.assertAllClose(jit_samples, self.constant_get_waveform_samples, atol=1e-7, rtol=1e-7)

jit(jit_func_gaussian_to_signals)(0.983)
self.jit_grad(jit_func_gaussian_to_signals)(0.983)
jit_gaussian_samples = jit(jit_func_gaussian_to_signals)(0.983)
self.assertAllClose(
jit_gaussian_samples, self.gaussian_get_waveform_samples, atol=1e-7, rtol=1e-7
)

def test_pulse_types_combination_with_jax(self):
"""Test that converting schedule including some pulse types with Jax works well"""

Expand All @@ -408,6 +432,8 @@ def jit_func_symbolic_pulse(amp):
(1, sym.And(_time >= 0, _time <= _duration)), (0, True)
)
valid_amp_conditions_expr = sym.Abs(_amp) <= 1.0
# we need to set disable_validation True to enable jax-jitting.
pulse.SymbolicPulse.disable_validation = True
instance = pulse.SymbolicPulse(
pulse_type="Constant",
duration=5,
Expand All @@ -430,8 +456,8 @@ def jit_func_symbolic_pulse(amp):
converter = InstructionToSignals(self._dt, carriers={"d0": 5})
return converter.get_signals(schedule)[0].samples

self.jit_wrap(jit_func_symbolic_pulse)(0.1)
self.jit_grad_wrap(jit_func_symbolic_pulse)(0.1)
jit(jit_func_symbolic_pulse)(0.1)
self.jit_grad(jit_func_symbolic_pulse)(0.1)


@ddt
Expand Down
2 changes: 2 additions & 0 deletions test/dynamics/solvers/test_solver_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1360,6 +1360,8 @@ def constant_pulse(amp):
(1, sym.And(_time >= 0, _time <= _duration)), (0, True)
)
valid_amp_conditions_expr = sym.Abs(_amp) <= 1.0
# we need to set disable_validation True to enable jax-jitting.
pulse.SymbolicPulse.disable_validation = True
return pulse.SymbolicPulse(
pulse_type="Constant",
duration=5,
Expand Down
Loading