Skip to content

Commit

Permalink
conditional operation
Browse files Browse the repository at this point in the history
  • Loading branch information
daxfohl committed Nov 5, 2021
1 parent a1eced2 commit 06883d5
Show file tree
Hide file tree
Showing 10 changed files with 519 additions and 5 deletions.
1 change: 1 addition & 0 deletions cirq-core/cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@
CCNotPowGate,
CNOT,
CNotPowGate,
ConditionalOperation,
ControlledGate,
ControlledOperation,
cphase,
Expand Down
1 change: 1 addition & 0 deletions cirq-core/cirq/json_resolver_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def _parallel_gate_op(gate, qubits):
'CCXPowGate': cirq.CCXPowGate,
'CCZPowGate': cirq.CCZPowGate,
'CNotPowGate': cirq.CNotPowGate,
'ConditionalOperation': cirq.ConditionalOperation,
'ControlledGate': cirq.ControlledGate,
'ControlledOperation': cirq.ControlledOperation,
'CSwapGate': cirq.CSwapGate,
Expand Down
4 changes: 4 additions & 0 deletions cirq-core/cirq/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@
ParallelGateFamily,
)

from cirq.ops.conditional_operation import (
ConditionalOperation,
)

from cirq.ops.controlled_gate import (
ControlledGate,
)
Expand Down
155 changes: 155 additions & 0 deletions cirq-core/cirq/ops/conditional_operation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
# Copyright 2021 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import (
AbstractSet,
Any,
cast,
Dict,
Optional,
Sequence,
TYPE_CHECKING,
Tuple,
Union,
)

from cirq import protocols, value
from cirq.ops import raw_types

if TYPE_CHECKING:
import cirq


@value.value_equality
class ConditionalOperation(raw_types.Operation):
"""Augments existing operations to be conditionally executed."""

def __init__(
self,
sub_operation: 'cirq.Operation',
controls: Sequence[Union[str, 'cirq.MeasurementKey']],
):
controls = tuple(value.MeasurementKey(k) if isinstance(k, str) else k for k in controls)
if isinstance(sub_operation, ConditionalOperation):
# Auto-flatten nested controlled operations.
sub_operation = cast(ConditionalOperation, sub_operation)
self._controls = controls + sub_operation.controls
self._sub_operation = sub_operation.sub_operation
else:
self._controls = controls
self._sub_operation = sub_operation

@property
def controls(self):
return self._controls

@property
def sub_operation(self):
return self._sub_operation

@property
def qubits(self):
return self.sub_operation.qubits

def with_qubits(self, *new_qubits):
return ConditionalOperation(self._sub_operation.with_qubits(*new_qubits), self._controls)

def _decompose_(self):
result = protocols.decompose_once(self._sub_operation, NotImplemented)
if result is NotImplemented:
return NotImplemented

return [ConditionalOperation(op, self._controls) for op in result]

def _value_equality_values_(self):
return (frozenset(self._controls), self._sub_operation)

def __str__(self) -> str:
return repr(self)

def __repr__(self):
return f'ConditionalOperation({self._sub_operation!r}, {list(self._controls)!r})'

def _is_parameterized_(self) -> bool:
return protocols.is_parameterized(self._sub_operation)

def _parameter_names_(self) -> AbstractSet[str]:
return protocols.parameter_names(self._sub_operation)

def _resolve_parameters_(
self, resolver: 'cirq.ParamResolver', recursive: bool
) -> 'ConditionalOperation':
new_sub_op = protocols.resolve_parameters(self._sub_operation, resolver, recursive)
return ConditionalOperation(new_sub_op, self._controls)

def __pow__(self, exponent: Any) -> 'ConditionalOperation':
new_sub_op = protocols.pow(self._sub_operation, exponent, NotImplemented)
if new_sub_op is NotImplemented:
return NotImplemented # coverage: ignore
return ConditionalOperation(new_sub_op, self._controls)

def _circuit_diagram_info_(
self, args: 'cirq.CircuitDiagramInfoArgs'
) -> Optional['protocols.CircuitDiagramInfo']:
sub_args = protocols.CircuitDiagramInfoArgs(
known_qubit_count=args.known_qubit_count,
known_qubits=args.known_qubits,
use_unicode_characters=args.use_unicode_characters,
precision=args.precision,
qubit_map=args.qubit_map,
)
sub_info = protocols.circuit_diagram_info(self._sub_operation, sub_args, None)
if sub_info is None:
return NotImplemented # coverage: ignore

wire_symbols = sub_info.wire_symbols + ('^',) * len(self._controls)
exponent_qubit_index = None
if sub_info.exponent_qubit_index is not None:
exponent_qubit_index = sub_info.exponent_qubit_index + len(self._controls)
elif sub_info.exponent is not None:
exponent_qubit_index = len(self._controls)
return protocols.CircuitDiagramInfo(
wire_symbols=wire_symbols,
exponent=sub_info.exponent,
exponent_qubit_index=exponent_qubit_index,
)

def _json_dict_(self) -> Dict[str, Any]:
return {
'cirq_type': self.__class__.__name__,
'controls': self._controls,
'sub_operation': self._sub_operation,
}

def _act_on_(self, args: 'cirq.ActOnArgs') -> bool:
def not_zero(measurement):
return any(i != 0 for i in measurement)

measurements = [args.log_of_measurement_results[str(key)] for key in self._controls]
if all(not_zero(measurement) for measurement in measurements):
protocols.act_on(self._sub_operation, args)
return True

def _with_key_path_(self, path: Tuple[str, ...]) -> 'ConditionalOperation':
return ConditionalOperation(
self._sub_operation, [protocols.with_key_path(k, path) for k in self._controls]
)

def _with_measurement_key_mapping_(self, key_map: Dict[str, str]) -> 'ConditionalOperation':
return ConditionalOperation(
self._sub_operation,
[protocols.with_measurement_key_mapping(k, key_map) for k in self._controls],
)

def _control_keys_(self) -> Tuple[value.MeasurementKey, ...]:
return self._controls
Loading

0 comments on commit 06883d5

Please sign in to comment.