From 06883d5d3d767eb08fe77bc0e4a0cf6ae7e0d64c Mon Sep 17 00:00:00 2001 From: Dax Fohl Date: Fri, 5 Nov 2021 12:47:49 -0700 Subject: [PATCH] conditional operation --- cirq-core/cirq/__init__.py | 1 + cirq-core/cirq/json_resolver_cache.py | 1 + cirq-core/cirq/ops/__init__.py | 4 + cirq-core/cirq/ops/conditional_operation.py | 155 +++++++++ .../cirq/ops/conditional_operation_test.py | 327 ++++++++++++++++++ cirq-core/cirq/ops/moment.py | 2 +- .../json_test_data/ConditionalOperation.json | 25 ++ .../json_test_data/ConditionalOperation.repr | 1 + cirq-core/cirq/sim/act_on_args.py | 4 +- .../cirq/sim/clifford/stabilizer_sampler.py | 4 +- 10 files changed, 519 insertions(+), 5 deletions(-) create mode 100644 cirq-core/cirq/ops/conditional_operation.py create mode 100644 cirq-core/cirq/ops/conditional_operation_test.py create mode 100644 cirq-core/cirq/protocols/json_test_data/ConditionalOperation.json create mode 100644 cirq-core/cirq/protocols/json_test_data/ConditionalOperation.repr diff --git a/cirq-core/cirq/__init__.py b/cirq-core/cirq/__init__.py index a0eaf08db05..631ce3ac104 100644 --- a/cirq-core/cirq/__init__.py +++ b/cirq-core/cirq/__init__.py @@ -192,6 +192,7 @@ CCNotPowGate, CNOT, CNotPowGate, + ConditionalOperation, ControlledGate, ControlledOperation, cphase, diff --git a/cirq-core/cirq/json_resolver_cache.py b/cirq-core/cirq/json_resolver_cache.py index 04e1e00487a..048c9cf11cb 100644 --- a/cirq-core/cirq/json_resolver_cache.py +++ b/cirq-core/cirq/json_resolver_cache.py @@ -61,6 +61,7 @@ def _parallel_gate_op(gate, qubits): 'CCXPowGate': cirq.CCXPowGate, 'CCZPowGate': cirq.CCZPowGate, 'CNotPowGate': cirq.CNotPowGate, + 'ConditionalOperation': cirq.ConditionalOperation, 'ControlledGate': cirq.ControlledGate, 'ControlledOperation': cirq.ControlledOperation, 'CSwapGate': cirq.CSwapGate, diff --git a/cirq-core/cirq/ops/__init__.py b/cirq-core/cirq/ops/__init__.py index 0b4041e5641..54312060f2f 100644 --- a/cirq-core/cirq/ops/__init__.py +++ b/cirq-core/cirq/ops/__init__.py @@ -82,6 +82,10 @@ ParallelGateFamily, ) +from cirq.ops.conditional_operation import ( + ConditionalOperation, +) + from cirq.ops.controlled_gate import ( ControlledGate, ) diff --git a/cirq-core/cirq/ops/conditional_operation.py b/cirq-core/cirq/ops/conditional_operation.py new file mode 100644 index 00000000000..a5620b78ffc --- /dev/null +++ b/cirq-core/cirq/ops/conditional_operation.py @@ -0,0 +1,155 @@ +# Copyright 2021 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import ( + AbstractSet, + Any, + cast, + Dict, + Optional, + Sequence, + TYPE_CHECKING, + Tuple, + Union, +) + +from cirq import protocols, value +from cirq.ops import raw_types + +if TYPE_CHECKING: + import cirq + + +@value.value_equality +class ConditionalOperation(raw_types.Operation): + """Augments existing operations to be conditionally executed.""" + + def __init__( + self, + sub_operation: 'cirq.Operation', + controls: Sequence[Union[str, 'cirq.MeasurementKey']], + ): + controls = tuple(value.MeasurementKey(k) if isinstance(k, str) else k for k in controls) + if isinstance(sub_operation, ConditionalOperation): + # Auto-flatten nested controlled operations. + sub_operation = cast(ConditionalOperation, sub_operation) + self._controls = controls + sub_operation.controls + self._sub_operation = sub_operation.sub_operation + else: + self._controls = controls + self._sub_operation = sub_operation + + @property + def controls(self): + return self._controls + + @property + def sub_operation(self): + return self._sub_operation + + @property + def qubits(self): + return self.sub_operation.qubits + + def with_qubits(self, *new_qubits): + return ConditionalOperation(self._sub_operation.with_qubits(*new_qubits), self._controls) + + def _decompose_(self): + result = protocols.decompose_once(self._sub_operation, NotImplemented) + if result is NotImplemented: + return NotImplemented + + return [ConditionalOperation(op, self._controls) for op in result] + + def _value_equality_values_(self): + return (frozenset(self._controls), self._sub_operation) + + def __str__(self) -> str: + return repr(self) + + def __repr__(self): + return f'ConditionalOperation({self._sub_operation!r}, {list(self._controls)!r})' + + def _is_parameterized_(self) -> bool: + return protocols.is_parameterized(self._sub_operation) + + def _parameter_names_(self) -> AbstractSet[str]: + return protocols.parameter_names(self._sub_operation) + + def _resolve_parameters_( + self, resolver: 'cirq.ParamResolver', recursive: bool + ) -> 'ConditionalOperation': + new_sub_op = protocols.resolve_parameters(self._sub_operation, resolver, recursive) + return ConditionalOperation(new_sub_op, self._controls) + + def __pow__(self, exponent: Any) -> 'ConditionalOperation': + new_sub_op = protocols.pow(self._sub_operation, exponent, NotImplemented) + if new_sub_op is NotImplemented: + return NotImplemented # coverage: ignore + return ConditionalOperation(new_sub_op, self._controls) + + def _circuit_diagram_info_( + self, args: 'cirq.CircuitDiagramInfoArgs' + ) -> Optional['protocols.CircuitDiagramInfo']: + sub_args = protocols.CircuitDiagramInfoArgs( + known_qubit_count=args.known_qubit_count, + known_qubits=args.known_qubits, + use_unicode_characters=args.use_unicode_characters, + precision=args.precision, + qubit_map=args.qubit_map, + ) + sub_info = protocols.circuit_diagram_info(self._sub_operation, sub_args, None) + if sub_info is None: + return NotImplemented # coverage: ignore + + wire_symbols = sub_info.wire_symbols + ('^',) * len(self._controls) + exponent_qubit_index = None + if sub_info.exponent_qubit_index is not None: + exponent_qubit_index = sub_info.exponent_qubit_index + len(self._controls) + elif sub_info.exponent is not None: + exponent_qubit_index = len(self._controls) + return protocols.CircuitDiagramInfo( + wire_symbols=wire_symbols, + exponent=sub_info.exponent, + exponent_qubit_index=exponent_qubit_index, + ) + + def _json_dict_(self) -> Dict[str, Any]: + return { + 'cirq_type': self.__class__.__name__, + 'controls': self._controls, + 'sub_operation': self._sub_operation, + } + + def _act_on_(self, args: 'cirq.ActOnArgs') -> bool: + def not_zero(measurement): + return any(i != 0 for i in measurement) + + measurements = [args.log_of_measurement_results[str(key)] for key in self._controls] + if all(not_zero(measurement) for measurement in measurements): + protocols.act_on(self._sub_operation, args) + return True + + def _with_key_path_(self, path: Tuple[str, ...]) -> 'ConditionalOperation': + return ConditionalOperation( + self._sub_operation, [protocols.with_key_path(k, path) for k in self._controls] + ) + + def _with_measurement_key_mapping_(self, key_map: Dict[str, str]) -> 'ConditionalOperation': + return ConditionalOperation( + self._sub_operation, + [protocols.with_measurement_key_mapping(k, key_map) for k in self._controls], + ) + + def _control_keys_(self) -> Tuple[value.MeasurementKey, ...]: + return self._controls diff --git a/cirq-core/cirq/ops/conditional_operation_test.py b/cirq-core/cirq/ops/conditional_operation_test.py new file mode 100644 index 00000000000..1629e06ccc2 --- /dev/null +++ b/cirq-core/cirq/ops/conditional_operation_test.py @@ -0,0 +1,327 @@ +# Copyright 2021 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sympy + +import cirq +from cirq.ops.conditional_operation import ConditionalOperation + + +def test_diagram(): + q0, q1 = cirq.LineQubit.range(2) + c = cirq.Circuit(cirq.measure(q0, key='a'), ConditionalOperation(cirq.X(q1), ['a'])) + + cirq.testing.assert_has_diagram( + c, + """ +0: ───M─────── + ║ +1: ───╫───X─── + ║ ║ +a: ═══@═══^═══ +""", + use_unicode_characters=True, + ) + + +def test_diagram_pauli(): + q0, q1 = cirq.LineQubit.range(2) + c = cirq.Circuit( + cirq.measure_single_paulistring(cirq.X(q0), key='a'), + ConditionalOperation(cirq.X(q1), ['a']), + ) + + cirq.testing.assert_has_diagram( + c, + """ +0: ───M(X)─────── + ║ +1: ───╫──────X─── + ║ ║ +a: ═══@══════^═══ +""", + use_unicode_characters=True, + ) + + +def test_diagram_extra_measurements(): + q0, q1 = cirq.LineQubit.range(2) + c = cirq.Circuit( + cirq.measure(q0, key='a'), + cirq.measure(q0, key='b'), + ConditionalOperation(cirq.X(q1), ['a']), + ) + + cirq.testing.assert_has_diagram( + c, + """ +0: ───M───M('b')─── + ║ +1: ───╫───X──────── + ║ ║ +a: ═══@═══^════════ +""", + use_unicode_characters=True, + ) + + +def test_diagram_extra_controlled_bits(): + q0, q1 = cirq.LineQubit.range(2) + c = cirq.Circuit( + cirq.measure(q0, key='a'), + ConditionalOperation(cirq.CX(q0, q1), ['a']), + ) + + cirq.testing.assert_has_diagram( + c, + """ +0: ───M───@─── + ║ ║ +1: ───╫───X─── + ║ ║ +a: ═══@═══^═══ +""", + use_unicode_characters=True, + ) + + +def test_diagram_extra_control_bits(): + q0, q1 = cirq.LineQubit.range(2) + c = cirq.Circuit( + cirq.measure(q0, key='a'), + cirq.measure(q0, key='b'), + ConditionalOperation(cirq.X(q1), ['a', 'b']), + ) + + cirq.testing.assert_has_diagram( + c, + """ +0: ───M───M─────── + ║ ║ +1: ───╫───╫───X─── + ║ ║ ║ +a: ═══@═══╬═══^═══ + ║ ║ +b: ═══════@═══^═══ +""", + use_unicode_characters=True, + ) + + +def test_diagram_multiple_ops_single_moment(): + q0, q1 = cirq.LineQubit.range(2) + c = cirq.Circuit( + cirq.measure(q0, key='a'), + cirq.measure(q1, key='b'), + ConditionalOperation(cirq.X(q0), ['a']), + ConditionalOperation(cirq.X(q1), ['b']), + ) + + cirq.testing.assert_has_diagram( + c, + """ + ┌──┐ ┌──┐ +0: ────M──────X───── + ║ ║ +1: ────╫M─────╫X──── + ║║ ║║ +a: ════@╬═════^╬════ + ║ ║ +b: ═════@══════^════ + └──┘ └──┘ +""", + use_unicode_characters=True, + ) + + +def test_diagram_subcircuit(): + q0, q1 = cirq.LineQubit.range(2) + c = cirq.Circuit( + cirq.CircuitOperation( + cirq.FrozenCircuit( + cirq.measure(q0, key='a'), + ConditionalOperation(cirq.X(q1), ['a']), + ) + ) + ) + + cirq.testing.assert_has_diagram( + c, + """ + Circuit_0x0000000000000000: + [ 0: ───M─────── ] +0: ───[ ║ ]─── + [ 1: ───╫───X─── ] + [ ║ ║ ] + [ a: ═══@═══^═══ ] + │ +1: ───#2──────────────────────────── +""", + use_unicode_characters=True, + ) + + +def test_diagram_subcircuit_layered(): + q0, q1 = cirq.LineQubit.range(2) + c = cirq.Circuit( + cirq.measure(q0, key='a'), + cirq.CircuitOperation( + cirq.FrozenCircuit( + cirq.measure(q0, key='a'), + ConditionalOperation(cirq.X(q1), ['a']), + ), + ), + ConditionalOperation(cirq.X(q1), ['a']), + ) + + cirq.testing.assert_has_diagram( + c, + """ + Circuit_0x0000000000000000: + [ 0: ───M─────── ] +0: ───M───[ ║ ]─────── + ║ [ 1: ───╫───X─── ] + ║ [ ║ ║ ] + ║ [ a: ═══@═══^═══ ] + ║ ║ +1: ───╫───#2────────────────────────────X─── + ║ ║ ║ +a: ═══@═══╩═════════════════════════════^═══ +""", + use_unicode_characters=True, + ) + + +def test_key_unset(): + q0, q1 = cirq.LineQubit.range(2) + c = cirq.Circuit( + cirq.measure(q0, key='a'), + ConditionalOperation(cirq.X(q1), ['a']), + cirq.measure(q1, key='b'), + ) + s = cirq.Simulator() + result = s.run(c) + assert result.measurements['a'] == 0 + assert result.measurements['b'] == 0 + + +def test_key_set(): + q0, q1 = cirq.LineQubit.range(2) + c = cirq.Circuit( + cirq.X(q0), + cirq.measure(q0, key='a'), + ConditionalOperation(cirq.X(q1), ['a']), + cirq.measure(q1, key='b'), + ) + s = cirq.Simulator() + result = s.run(c) + assert result.measurements['a'] == 1 + assert result.measurements['b'] == 1 + + +def test_subcircuit_key_unset(): + q0, q1 = cirq.LineQubit.range(2) + inner = cirq.Circuit( + cirq.measure(q0, key='c'), + ConditionalOperation(cirq.X(q1), ['c']), + cirq.measure(q1, key='b'), + ) + c = cirq.Circuit( + cirq.CircuitOperation(inner.freeze(), repetitions=2, measurement_key_map={'c': 'a'}) + ) + s = cirq.Simulator() + result = s.run(c) + assert result.measurements['0:a'] == 0 + assert result.measurements['0:b'] == 0 + assert result.measurements['1:a'] == 0 + assert result.measurements['1:b'] == 0 + + +def test_subcircuit_key_set(): + q0, q1 = cirq.LineQubit.range(2) + inner = cirq.Circuit( + cirq.X(q0), + cirq.measure(q0, key='c'), + ConditionalOperation(cirq.X(q1), ['c']), + cirq.measure(q1, key='b'), + ) + c = cirq.Circuit( + cirq.CircuitOperation(inner.freeze(), repetitions=8, measurement_key_map={'c': 'a'}) + ) + s = cirq.Simulator() + result = s.run(c) + assert result.measurements['0:a'] == 1 + assert result.measurements['0:b'] == 1 + assert result.measurements['1:a'] == 0 + assert result.measurements['1:b'] == 1 + assert result.measurements['2:a'] == 1 + assert result.measurements['2:b'] == 0 + assert result.measurements['3:a'] == 0 + assert result.measurements['3:b'] == 0 + assert result.measurements['4:a'] == 1 + assert result.measurements['4:b'] == 1 + assert result.measurements['5:a'] == 0 + assert result.measurements['5:b'] == 1 + assert result.measurements['6:a'] == 1 + assert result.measurements['6:b'] == 0 + assert result.measurements['7:a'] == 0 + assert result.measurements['7:b'] == 0 + + +def test_key_stacking(): + q0 = cirq.LineQubit(0) + inner = cirq.X(q0) + op = ConditionalOperation(ConditionalOperation(inner, ['a']), ['b']) + assert op.sub_operation is inner + assert set(map(str, op.controls)) == {'a', 'b'} + + +def test_qubit_mapping(): + q0, q1 = cirq.LineQubit.range(2) + op = ConditionalOperation(cirq.X(q0), ['a']) + assert op.with_qubits(q1).qubits == (q1,) + + +def test_parameterizable(): + a = sympy.Symbol('S') + q0 = cirq.LineQubit(0) + op = ConditionalOperation(cirq.X(q0), ['a']) + opa = ConditionalOperation(cirq.XPowGate(exponent=a).on(q0), ['a']) + assert cirq.is_parameterized(opa) + assert not cirq.is_parameterized(op) + assert cirq.resolve_parameters(opa, cirq.ParamResolver({'S': 1})) == op + + +def test_decompose(): + q0 = cirq.LineQubit(0) + op = ConditionalOperation(cirq.H(q0), ['a']) + assert cirq.decompose(op) == [ + ConditionalOperation(cirq.Y(q0) ** 0.5, ['a']), + ConditionalOperation(cirq.XPowGate(exponent=1.0, global_shift=-0.25).on(q0), ['a']), + ] + + +def test_str(): + q0 = cirq.LineQubit(0) + op = ConditionalOperation(cirq.X(q0), ['a']) + assert ( + str(op) + == "ConditionalOperation(cirq.X(cirq.LineQubit(0)), [cirq.MeasurementKey(name='a')])" + ) + + +def test_pow(): + q0 = cirq.LineQubit(0) + inner = cirq.X(q0) + op = ConditionalOperation(inner, ['a']) ** 2 + assert op.sub_operation == inner ** 2 diff --git a/cirq-core/cirq/ops/moment.py b/cirq-core/cirq/ops/moment.py index 483f3ec018f..f9e473fc108 100644 --- a/cirq-core/cirq/ops/moment.py +++ b/cirq-core/cirq/ops/moment.py @@ -214,7 +214,7 @@ def without_operations_touching(self, qubits: Iterable['cirq.Qid']) -> 'cirq.Mom def _with_measurement_key_mapping_(self, key_map: Dict[str, str]): return Moment( protocols.with_measurement_key_mapping(op, key_map) - if protocols.is_measurement(op) + if protocols.is_measurement(op) or protocols.control_keys(op) else op for op in self.operations ) diff --git a/cirq-core/cirq/protocols/json_test_data/ConditionalOperation.json b/cirq-core/cirq/protocols/json_test_data/ConditionalOperation.json new file mode 100644 index 00000000000..0edb449a9e6 --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/ConditionalOperation.json @@ -0,0 +1,25 @@ +{ + "cirq_type": "ConditionalOperation", + "controls": [ + { + "cirq_type": "LineQubit", + "x": 0 + }, + { + "cirq_type": "LineQubit", + "x": 1 + } + ], + "sub_operation": { + "cirq_type": "SingleQubitPauliStringGateOperation", + "pauli": { + "cirq_type": "_PauliY", + "exponent": 1, + "global_shift": 0.0 + }, + "qubit": { + "cirq_type": "NamedQubit", + "name": "target" + } + } +} \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/ConditionalOperation.repr b/cirq-core/cirq/protocols/json_test_data/ConditionalOperation.repr new file mode 100644 index 00000000000..8c36428bb10 --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/ConditionalOperation.repr @@ -0,0 +1 @@ +cirq.ConditionalOperation(cirq.Y.on(cirq.NamedQubit('target')), [cirq.LineQubit(0), cirq.LineQubit(1)]) \ No newline at end of file diff --git a/cirq-core/cirq/sim/act_on_args.py b/cirq-core/cirq/sim/act_on_args.py index d3dc9416103..6c87c0f4e62 100644 --- a/cirq-core/cirq/sim/act_on_args.py +++ b/cirq-core/cirq/sim/act_on_args.py @@ -46,7 +46,7 @@ def __init__( self, prng: np.random.RandomState = None, qubits: Sequence['cirq.Qid'] = None, - log_of_measurement_results: Dict[str, Any] = None, + log_of_measurement_results: Dict[str, List[int]] = None, ): """Inits ActOnArgs. @@ -181,7 +181,7 @@ def _on_transpose_to_qubit_order(self: TSelf, qubits: Sequence['cirq.Qid'], targ functionality, if supported.""" @property - def log_of_measurement_results(self) -> Dict[str, Any]: + def log_of_measurement_results(self) -> Dict[str, List[int]]: return self._log_of_measurement_results @property diff --git a/cirq-core/cirq/sim/clifford/stabilizer_sampler.py b/cirq-core/cirq/sim/clifford/stabilizer_sampler.py index bb0e32da8f5..e2ef5ce514e 100644 --- a/cirq-core/cirq/sim/clifford/stabilizer_sampler.py +++ b/cirq-core/cirq/sim/clifford/stabilizer_sampler.py @@ -53,7 +53,7 @@ def run_sweep( def _run(self, circuit: circuits.AbstractCircuit, repetitions: int) -> Dict[str, np.ndarray]: - measurements: Dict[str, List[int]] = { + measurements: Dict[str, List[np.ndarray]] = { key: [] for key in protocols.measurement_key_names(circuit) } qubits = circuit.all_qubits() @@ -69,6 +69,6 @@ def _run(self, circuit: circuits.AbstractCircuit, repetitions: int) -> Dict[str, protocols.act_on(op, state) for k, v in state.log_of_measurement_results.items(): - measurements[k].append(v) + measurements[k].append(np.array(v, dtype=np.uint8)) return {k: np.array(v) for k, v in measurements.items()}