Skip to content

Commit

Permalink
add jax_grad test
Browse files Browse the repository at this point in the history
  • Loading branch information
Kento Ueda authored and Kento Ueda committed Dec 10, 2022
1 parent fbd5d75 commit f36cb66
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 61 deletions.
59 changes: 0 additions & 59 deletions example.py

This file was deleted.

20 changes: 18 additions & 2 deletions test/dynamics/signals/test_pulse_to_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from ddt import ddt, data, unpack
import numpy as np
import sympy as sym

from qiskit import pulse
from qiskit.pulse import (
Expand All @@ -31,6 +32,7 @@
Gaussian,
Constant,
Waveform,
SymbolicPulse,
)
from qiskit.pulse.transforms.canonicalization import block_to_schedule
from qiskit import QiskitError
Expand Down Expand Up @@ -266,8 +268,22 @@ def test_jit_get_samples(self):
"""Test compiling to get samples of Pulse."""

def jit_func(amp):
return get_samples(Constant(100, amp))

parameters = {"amp": amp}
_t, _amp, _duration = sym.symbols("t, amp, duration")
envelope_expr = _amp * sym.Piecewise((1, sym.And(_t >= 0, _t <= _duration)), (0, True))
valid_amp_conditions_expr = sym.Abs(_amp) <= 1.0
# we can use only SymbolicPulse when jax-jitting bacause jax-jitting doesn't correspond to validate_parameters in qiskit.pulse.
instance = SymbolicPulse(
pulse_type="Constant",
duration=100,
parameters=parameters,
envelope=envelope_expr,
valid_amp_conditions=valid_amp_conditions_expr,
)
return get_samples(instance)

self.jit_wrap(jit_func)(0.1)
self.jit_grad_wrap(jit_func)(0.1)
jit_samples = jax.jit(jit_func)(0.1)
self.assertAllClose(jit_samples, self.gauss_get_waveform_samples, atol=1e-7, rtol=1e-7)

Expand Down

0 comments on commit f36cb66

Please sign in to comment.