Skip to content

Commit

Permalink
fix limit_amplitude bug (#8308)
Browse files Browse the repository at this point in the history
  • Loading branch information
nkanazawa1989 authored Jul 7, 2022
1 parent f134aeb commit d50d99a
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 4 deletions.
4 changes: 3 additions & 1 deletion qiskit/pulse/library/pulse.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,12 @@ def __init__(
by default but may be set by the user to disable amplitude
checks globally.
"""
if limit_amplitude is None:
limit_amplitude = self.__class__.limit_amplitude

self.duration = duration
self.name = name
self._limit_amplitude = limit_amplitude or self.__class__.limit_amplitude
self._limit_amplitude = limit_amplitude

@property
def id(self) -> int: # pylint: disable=invalid-name
Expand Down
2 changes: 1 addition & 1 deletion qiskit/pulse/library/symbolic_pulses.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ def validate_parameters(self) -> None:
f"Assigned parameters {param_repr} violate following constraint: {const_repr}."
)

if self.limit_amplitude:
if self._limit_amplitude:
if self._valid_amp_conditions is not None:
fargs = _get_expression_args(self._valid_amp_conditions, self.parameters)
check_full_waveform = not bool(self._valid_amp_conditions_lam(*fargs))
Expand Down
4 changes: 2 additions & 2 deletions qiskit/qpy/binary_io/schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def _write_waveform(file_obj, data):
formats.WAVEFORM_PACK,
data.epsilon,
len(samples_bytes),
data.limit_amplitude,
data._limit_amplitude,
)
file_obj.write(header)
file_obj.write(samples_bytes)
Expand Down Expand Up @@ -210,7 +210,7 @@ def _write_symbolic_pulse(file_obj, data):
len(envelope_bytes),
len(constraints_bytes),
len(valid_amp_conditions_bytes),
data.limit_amplitude,
data._limit_amplitude,
)
file_obj.write(header_bytes)
file_obj.write(pulse_type_bytes)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
fixes:
- |
A bug that ``limit_amplitude`` set to an individual symbolic pulse or waveform instance
is not properly reflected to the parameter validation has been fixed.
In addition, QPY schedule ``dump`` has been fixed to save ``limit_amplitude`` value
tied to the instance, rather than saving the global class variable.
34 changes: 34 additions & 0 deletions test/python/pulse/test_pulse_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,14 @@ def test_gaussian_limit_amplitude(self):
waveform = Gaussian(duration=100, sigma=1.0, amp=1.1 + 0.8j)
self.assertGreater(np.abs(waveform.amp), 1.0)

def test_gaussian_limit_amplitude_per_instance(self):
"""Test that the check for amplitude per instance."""
with self.assertRaises(PulseError):
Gaussian(duration=100, sigma=1.0, amp=1.1 + 0.8j)

waveform = Gaussian(duration=100, sigma=1.0, amp=1.1 + 0.8j, limit_amplitude=False)
self.assertGreater(np.abs(waveform.amp), 1.0)

def test_gaussian_square_limit_amplitude(self):
"""Test that the check for amplitude less than or equal to 1 can be disabled."""
with self.assertRaises(PulseError):
Expand All @@ -299,6 +307,16 @@ def test_gaussian_square_limit_amplitude(self):
waveform = GaussianSquare(duration=100, sigma=1.0, amp=1.1 + 0.8j, width=10)
self.assertGreater(np.abs(waveform.amp), 1.0)

def test_gaussian_square_limit_amplitude_per_instance(self):
"""Test that the check for amplitude per instance."""
with self.assertRaises(PulseError):
GaussianSquare(duration=100, sigma=1.0, amp=1.1 + 0.8j, width=10)

waveform = GaussianSquare(
duration=100, sigma=1.0, amp=1.1 + 0.8j, width=10, limit_amplitude=False
)
self.assertGreater(np.abs(waveform.amp), 1.0)

def test_drag_limit_amplitude(self):
"""Test that the check for amplitude less than or equal to 1 can be disabled."""
with self.assertRaises(PulseError):
Expand All @@ -308,6 +326,14 @@ def test_drag_limit_amplitude(self):
waveform = Drag(duration=100, sigma=1.0, beta=1.0, amp=1.1 + 0.8j)
self.assertGreater(np.abs(waveform.amp), 1.0)

def test_drag_limit_amplitude_per_instance(self):
"""Test that the check for amplitude per instance."""
with self.assertRaises(PulseError):
Drag(duration=100, sigma=1.0, beta=1.0, amp=1.1 + 0.8j)

waveform = Drag(duration=100, sigma=1.0, beta=1.0, amp=1.1 + 0.8j, limit_amplitude=False)
self.assertGreater(np.abs(waveform.amp), 1.0)

def test_constant_limit_amplitude(self):
"""Test that the check for amplitude less than or equal to 1 can be disabled."""
with self.assertRaises(PulseError):
Expand All @@ -317,6 +343,14 @@ def test_constant_limit_amplitude(self):
waveform = Constant(duration=100, amp=1.1 + 0.8j)
self.assertGreater(np.abs(waveform.amp), 1.0)

def test_constant_limit_amplitude_per_instance(self):
"""Test that the check for amplitude per instance."""
with self.assertRaises(PulseError):
Constant(duration=100, amp=1.1 + 0.8j)

waveform = Constant(duration=100, amp=1.1 + 0.8j, limit_amplitude=False)
self.assertGreater(np.abs(waveform.amp), 1.0)

def test_get_parameters(self):
"""Test getting pulse parameters as attribute."""
drag_pulse = Drag(duration=100, amp=0.1, sigma=40, beta=3)
Expand Down
19 changes: 19 additions & 0 deletions test/python/qpy/test_block_load_from_qpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
GaussianSquare,
Drag,
Constant,
Waveform,
)
from qiskit.pulse.channels import (
DriveChannel,
Expand Down Expand Up @@ -97,6 +98,24 @@ def test_playing_custom_symbolic_pulse(self):
builder.play(my_pulse, DriveChannel(0))
self.assert_roundtrip_equal(test_sched)

def test_symbolic_amplitude_limit(self):
"""Test applying amplitude limit to symbolic pulse."""
with builder.build() as test_sched:
builder.play(
Gaussian(160, 20, 40, limit_amplitude=False),
DriveChannel(0),
)
self.assert_roundtrip_equal(test_sched)

def test_waveform_amplitude_limit(self):
"""Test applying amplitude limit to waveform."""
with builder.build() as test_sched:
builder.play(
Waveform([1, 2, 3, 4, 5], limit_amplitude=False),
DriveChannel(0),
)
self.assert_roundtrip_equal(test_sched)

def test_playing_waveform(self):
"""Test playing waveform."""
# pylint: disable=invalid-name
Expand Down

0 comments on commit d50d99a

Please sign in to comment.