From 059ad54b1ff6881f78e963050f551dbcc301b91b Mon Sep 17 00:00:00 2001 From: albi3ro Date: Fri, 11 Oct 2024 14:27:35 -0400 Subject: [PATCH 1/5] MeasurementValue raises an error when used as boolean --- doc/releases/changelog-dev.md | 2 ++ pennylane/measurements/mid_measure.py | 8 +++++++- tests/measurements/test_mid_measure.py | 9 +++++++++ 3 files changed, 18 insertions(+), 1 deletion(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index e077b96306b..e5d9d675fb9 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -258,6 +258,8 @@

Bug fixes 🐛

+* `MeasurementValue` now raises an error when it is used as a boolean. + * `adjoint_metric_tensor` now works with circuits containing state preparation operations. [(#6358)](https://github.com/PennyLaneAI/pennylane/pull/6358) diff --git a/pennylane/measurements/mid_measure.py b/pennylane/measurements/mid_measure.py index fd971dc2192..abf03757f51 100644 --- a/pennylane/measurements/mid_measure.py +++ b/pennylane/measurements/mid_measure.py @@ -15,8 +15,9 @@ This module contains the qml.measure measurement. """ import uuid +from collections.abc import Hashable from functools import lru_cache -from typing import Generic, Hashable, Optional, TypeVar, Union +from typing import Generic, Optional, TypeVar, Union import pennylane as qml from pennylane.wires import Wires @@ -456,6 +457,11 @@ def __invert__(self): value.""" return self._apply(qml.math.logical_not) + def __bool__(self) -> bool: + raise ValueError( + "The truth value of a MeasurementValue is undefined. To condition on a MeasurementValue, please use qml.cond instead." + ) + def __eq__(self, other): return self._transform_bin_op(lambda a, b: a == b, other) diff --git a/tests/measurements/test_mid_measure.py b/tests/measurements/test_mid_measure.py index 379524be83d..622673f1fc2 100644 --- a/tests/measurements/test_mid_measure.py +++ b/tests/measurements/test_mid_measure.py @@ -81,6 +81,15 @@ def test_label(self, postselect, reset, expected): class TestMeasurementValueManipulation: """Test all the dunder methods associated with the MeasurementValue class""" + def test_error_on_boolean_conversion(self): + """Test that an error is raised if a measurement value if used as a boolean.""" + + m = MeasurementValue([mp1], lambda v: v) + + with pytest.raises(ValueError, match="The truth value of a MeasurementValue"): + if m: + return + def test_apply_function_to_measurement(self): """Test the general _apply method that can apply an arbitrary function to a measurement.""" From ffd709b4c1358cd2e0586d5918eeb088fa0649f1 Mon Sep 17 00:00:00 2001 From: Christina Lee Date: Fri, 11 Oct 2024 14:33:27 -0400 Subject: [PATCH 2/5] Update doc/releases/changelog-dev.md --- doc/releases/changelog-dev.md | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 7849e4228e6..91641f01a6a 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -262,6 +262,7 @@

Bug fixes 🐛

* `MeasurementValue` now raises an error when it is used as a boolean. + [(#6386)](https://github.com/PennyLaneAI/pennylane/pull/6386) * `adjoint_metric_tensor` now works with circuits containing state preparation operations. [(#6358)](https://github.com/PennyLaneAI/pennylane/pull/6358) From 6f3ce2183a65aa87d13a0967232a5a785e86e420 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Fri, 11 Oct 2024 15:07:28 -0400 Subject: [PATCH 3/5] fix test failures --- pennylane/devices/qubit/simulate.py | 23 ++++++++++++---------- pennylane/measurements/mid_measure.py | 2 +- pennylane/transforms/defer_measurements.py | 2 +- pennylane/transforms/dynamic_one_shot.py | 8 ++++---- 4 files changed, 19 insertions(+), 16 deletions(-) diff --git a/pennylane/devices/qubit/simulate.py b/pennylane/devices/qubit/simulate.py index f535f5428e6..2105aa52641 100644 --- a/pennylane/devices/qubit/simulate.py +++ b/pennylane/devices/qubit/simulate.py @@ -663,7 +663,7 @@ def split_circuit_at_mcms(circuit): last_circuit_measurements = [] for m in circuit.measurements: - if not m.mv: + if m.mv is None: last_circuit_measurements.append(m) circuits.append( @@ -698,16 +698,17 @@ def prepend_state_prep(circuit, state, interface, wires): def insert_mcms(circuit, results, mid_measurements): """Inserts terminal measurements of MCMs if the circuit is evaluated in analytic mode.""" - if circuit.shots or not any(m.mv for m in circuit.measurements): + if circuit.shots or all(m.mv is None for m in circuit.measurements): return results results = list(results) new_results = [] mid_measurements = {k: qml.math.array([[v]]) for k, v in mid_measurements.items()} for m in circuit.measurements: - if m.mv: - new_results.append(gather_mcm(m, mid_measurements, qml.math.array([[True]]))) - else: + if m.mv is None: new_results.append(results.pop(0)) + else: + new_results.append(gather_mcm(m, mid_measurements, qml.math.array([[True]]))) + return new_results @@ -841,9 +842,9 @@ def variance_post_processing(results): extra_measurements = [] for m in circuit.measurements: if isinstance(m, VarianceMP): - obs2 = m.mv * m.mv if m.mv else m.obs @ m.obs + obs2 = m.mv * m.mv if m.mv is not None else m.obs @ m.obs new_measurements.append(ExpectationMP(obs=obs2)) - extra_measurements.append(ExpectationMP(obs=m.mv if m.mv else m.obs)) + extra_measurements.append(ExpectationMP(obs=m.mv if m.mv is not None else m.obs)) else: new_measurements.append(m) new_measurements.extend(extra_measurements) @@ -870,16 +871,18 @@ def combine_measurements(terminal_measurements, results, mcm_samples): """Returns combined measurement values of various types.""" empty_mcm_samples = False need_mcm_samples = not all(v is None for v in mcm_samples.values()) - need_mcm_samples = need_mcm_samples and any(circ_meas.mv for circ_meas in terminal_measurements) + need_mcm_samples = need_mcm_samples and any( + circ_meas.mv is not None for circ_meas in terminal_measurements + ) if need_mcm_samples: empty_mcm_samples = len(next(iter(mcm_samples.values()))) == 0 if empty_mcm_samples and any(len(m) != 0 for m in mcm_samples.values()): # pragma: no cover raise ValueError("mcm_samples have inconsistent shapes.") final_measurements = [] for circ_meas in terminal_measurements: - if need_mcm_samples and circ_meas.mv and empty_mcm_samples: + if need_mcm_samples and circ_meas.mv is not None and empty_mcm_samples: comb_meas = measurement_with_no_shots(circ_meas) - elif need_mcm_samples and circ_meas.mv: + elif need_mcm_samples and circ_meas.mv is not None: mcm_samples = {k: v.reshape((-1, 1)) for k, v in mcm_samples.items()} is_valid = qml.math.ones(list(mcm_samples.values())[0].shape[0], dtype=bool) comb_meas = gather_mcm(circ_meas, mcm_samples, is_valid) diff --git a/pennylane/measurements/mid_measure.py b/pennylane/measurements/mid_measure.py index abf03757f51..afe89e680da 100644 --- a/pennylane/measurements/mid_measure.py +++ b/pennylane/measurements/mid_measure.py @@ -574,6 +574,6 @@ def find_post_processed_mcms(circuit): if isinstance(m.mv, list): for mv in m.mv: post_processed_mcms = post_processed_mcms | set(mv.measurements) - elif m.mv: + elif m.mv is not None: post_processed_mcms = post_processed_mcms | set(m.mv.measurements) return post_processed_mcms diff --git a/pennylane/transforms/defer_measurements.py b/pennylane/transforms/defer_measurements.py index 60f1446b111..79bf95cf2ad 100644 --- a/pennylane/transforms/defer_measurements.py +++ b/pennylane/transforms/defer_measurements.py @@ -39,7 +39,7 @@ def _check_tape_validity(tape: QuantumScript): for mp in tape.measurements: if isinstance(mp, (CountsMP, ProbabilityMP, SampleMP)) and not ( - mp.obs or mp._wires or mp.mv + mp.obs or mp._wires or mp.mv is not None ): raise ValueError( f"Cannot use {mp.__class__.__name__} as a measurement without specifying wires " diff --git a/pennylane/transforms/dynamic_one_shot.py b/pennylane/transforms/dynamic_one_shot.py index 56103205bae..d5b25a3072f 100644 --- a/pennylane/transforms/dynamic_one_shot.py +++ b/pennylane/transforms/dynamic_one_shot.py @@ -204,7 +204,7 @@ def init_auxiliary_tape(circuit: qml.tape.QuantumScript): """ new_measurements = [] for m in circuit.measurements: - if not m.mv: + if m.mv is None: if isinstance(m, VarianceMP): new_measurements.append(SampleMP(obs=m.obs)) else: @@ -281,13 +281,13 @@ def measurement_with_no_shots(measurement): raise TypeError( f"Native mid-circuit measurement mode does not support {type(m).__name__} measurements." ) - if interface != "jax" and m.mv and not has_valid: + if interface != "jax" and m.mv is not None and not has_valid: meas = measurement_with_no_shots(m) - elif m.mv and active_qjit: + elif m.mv is not None and active_qjit: meas = gather_mcm_qjit( m, mcm_samples, is_valid, postselect_mode=postselect_mode ) # pragma: no cover - elif m.mv: + elif m.mv is not None: meas = gather_mcm(m, mcm_samples, is_valid, postselect_mode=postselect_mode) elif interface != "jax" and not has_valid: meas = measurement_with_no_shots(m) From ec0a47d3060c54e5ecf4f1f8fa8b6ebb1444a812 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Fri, 11 Oct 2024 16:41:14 -0400 Subject: [PATCH 4/5] fix more test failures --- pennylane/drawer/tape_mpl.py | 4 ++-- pennylane/measurements/counts.py | 2 +- pennylane/ops/op_math/condition.py | 30 ++++++++++++++--------------- tests/ops/op_math/test_condition.py | 2 +- 4 files changed, 19 insertions(+), 19 deletions(-) diff --git a/pennylane/drawer/tape_mpl.py b/pennylane/drawer/tape_mpl.py index 2626cde639d..4155ff3123f 100644 --- a/pennylane/drawer/tape_mpl.py +++ b/pennylane/drawer/tape_mpl.py @@ -180,7 +180,7 @@ def _(op: qml.ops.op_math.Conditional, drawer, layer, config) -> None: def _get_measured_wires(measurements, wires) -> set: measured_wires = set() for m in measurements: - if not m.mv: + if m.mv is None: # state and probs if len(m.wires) == 0: return wires @@ -210,7 +210,7 @@ def _get_measured_bits(measurements, bit_map, offset): if isinstance(m.mv, list): for mv in m.mv: measured_bits += [bit_map[mcm] + offset for mcm in mv.measurements] - elif m.mv: + elif m.mv is not None: measured_bits += [bit_map[mcm] + offset for mcm in m.mv.measurements] return measured_bits diff --git a/pennylane/measurements/counts.py b/pennylane/measurements/counts.py index 0a1669d7403..565ed3b96d1 100644 --- a/pennylane/measurements/counts.py +++ b/pennylane/measurements/counts.py @@ -197,7 +197,7 @@ def _flatten(self): return (self.obs or self.mv, self._eigvals), metadata def __repr__(self): - if self.mv: + if self.mv is not None: return f"CountsMP({repr(self.mv)}, all_outcomes={self.all_outcomes})" if self.obs: return f"CountsMP({self.obs}, all_outcomes={self.all_outcomes})" diff --git a/pennylane/ops/op_math/condition.py b/pennylane/ops/op_math/condition.py index dabdffb4bb5..4e5db25c154 100644 --- a/pennylane/ops/op_math/condition.py +++ b/pennylane/ops/op_math/condition.py @@ -715,21 +715,21 @@ def _(*all_args, jaxpr_branches, n_consts_per_branch, n_args): for pred, jaxpr, n_consts in zip(conditions, jaxpr_branches, n_consts_per_branch): consts = consts_flat[start : start + n_consts] start += n_consts - if pred and jaxpr is not None: - if isinstance(pred, qml.measurements.MeasurementValue): - with qml.queuing.AnnotatedQueue() as q: - out = jax.core.eval_jaxpr(jaxpr, consts, *args) - - if len(out) != 0: - raise ConditionalTransformError( - "Only quantum functions without return values can be applied " - "conditionally with mid-circuit measurement predicates." - ) - for wrapped_op in q: - Conditional(pred, wrapped_op.obj) - - else: - return jax.core.eval_jaxpr(jaxpr, consts, *args) + if jaxpr is None: + continue + if isinstance(pred, qml.measurements.MeasurementValue): + with qml.queuing.AnnotatedQueue() as q: + out = jax.core.eval_jaxpr(jaxpr, consts, *args) + + if len(out) != 0: + raise ConditionalTransformError( + "Only quantum functions without return values can be applied " + "conditionally with mid-circuit measurement predicates." + ) + for wrapped_op in q: + Conditional(pred, wrapped_op.obj) + elif pred: + return jax.core.eval_jaxpr(jaxpr, consts, *args) return () diff --git a/tests/ops/op_math/test_condition.py b/tests/ops/op_math/test_condition.py index 2db00e49fc5..f5b9426d276 100644 --- a/tests/ops/op_math/test_condition.py +++ b/tests/ops/op_math/test_condition.py @@ -493,7 +493,7 @@ def test_adjoint(self): adj_op = op.adjoint() assert isinstance(adj_op, Conditional) - assert adj_op.meas_val == op.meas_val + assert adj_op.meas_val is op.meas_val assert adj_op.base == base.adjoint() From 12fd29cc183fe8fd5534778dc88fb07a20d2772d Mon Sep 17 00:00:00 2001 From: albi3ro Date: Tue, 15 Oct 2024 11:02:32 -0400 Subject: [PATCH 5/5] fix some tests --- pennylane/devices/_qubit_device.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pennylane/devices/_qubit_device.py b/pennylane/devices/_qubit_device.py index 8498e5d66ef..f56c001dc01 100644 --- a/pennylane/devices/_qubit_device.py +++ b/pennylane/devices/_qubit_device.py @@ -646,7 +646,8 @@ def statistics( # uses a list of mid-circuit measurement values obs = m # pragma: no cover else: - obs = m.obs or m.mv or m + obs = m.obs or m.mv + obs = m if obs is None else obs # Check if there is an overriden version of the measurement process if method := getattr(self, self.measurement_map[type(m)], False): if isinstance(m, MeasurementTransform):