diff --git a/cirq-core/cirq/ops/common_channels.py b/cirq-core/cirq/ops/common_channels.py index 7e917a41449..47ad788ee2f 100644 --- a/cirq-core/cirq/ops/common_channels.py +++ b/cirq-core/cirq/ops/common_channels.py @@ -246,7 +246,7 @@ def asymmetric_depolarize( @value.value_equality -class DepolarizingChannel(gate_features.SupportsOnEachGate, raw_types.Gate): +class DepolarizingChannel(raw_types.Gate): """A channel that depolarizes one or several qubits.""" def __init__(self, p: float, n_qubits: int = 1) -> None: diff --git a/cirq-core/cirq/ops/common_channels_test.py b/cirq-core/cirq/ops/common_channels_test.py index f5e429def99..c21e2de70b6 100644 --- a/cirq-core/cirq/ops/common_channels_test.py +++ b/cirq-core/cirq/ops/common_channels_test.py @@ -241,11 +241,24 @@ def test_deprecated_on_each_for_depolarizing_channel_one_qubit(): def test_deprecated_on_each_for_depolarizing_channel_two_qubits(): - q0, q1 = cirq.LineQubit.range(2) + q0, q1, q2, q3, q4, q5 = cirq.LineQubit.range(6) op = cirq.DepolarizingChannel(p=0.1, n_qubits=2) - with pytest.raises(ValueError, match="one qubit"): + op.on_each([(q0, q1)]) + op.on_each([(q0, q1), (q2, q3)]) + op.on_each(zip([q0, q2, q4], [q1, q3, q5])) + with pytest.raises(ValueError, match='cannot be in varargs form'): op.on_each(q0, q1) + with pytest.raises(ValueError, match='Inputs to multi-qubit gates must be Sequence'): + op.on_each((q0, q1)) + with pytest.raises(ValueError, match='Inputs to multi-qubit gates must be Sequence'): + op.on_each([q0, q1]) + with pytest.raises(ValueError, match='Qid'): + op.on_each([('bogus object 0', 'bogus object 1')]) + with pytest.raises(ValueError, match='Qid'): + op.on_each(['01']) + with pytest.raises(ValueError, match='Qid'): + op.on_each([(False, None)]) def test_depolarizing_channel_apply_two_qubits(): diff --git a/cirq-core/cirq/ops/gate_features.py b/cirq-core/cirq/ops/gate_features.py index e977d770ce3..ba7527a605a 100644 --- a/cirq-core/cirq/ops/gate_features.py +++ b/cirq-core/cirq/ops/gate_features.py @@ -18,8 +18,10 @@ """ import abc -from typing import Union, Iterable, Any, List +from typing import Union, Iterable, Any, List, Sequence +from cirq import value +from cirq._compat import deprecated_class from cirq.ops import raw_types @@ -31,38 +33,19 @@ def qubit_index_to_equivalence_group_key(self, index: int) -> int: return 0 -class SupportsOnEachGate(raw_types.Gate, metaclass=abc.ABCMeta): - """A gate that can be applied to exactly one qubit.""" +class _SupportsOnEachGateMeta(value.ABCMetaImplementAnyOneOf): + def __instancecheck__(cls, instance): + return isinstance(instance, SingleQubitGate) or issubclass( + type(instance), SupportsOnEachGate + ) - def on_each(self, *targets: Union[raw_types.Qid, Iterable[Any]]) -> List[raw_types.Operation]: - """Returns a list of operations applying the gate to all targets. + +@deprecated_class(deadline='v0.14', fix='None, this feature is in `Gate` now.') +class SupportsOnEachGate(raw_types.Gate, metaclass=_SupportsOnEachGateMeta): + pass - Args: - *targets: The qubits to apply this gate to. - Returns: - Operations applying this gate to the target qubits. - - Raises: - ValueError if targets are not instances of Qid or List[Qid]. - ValueError if the gate operates on two or more Qids. - """ - if self._num_qubits_() > 1: - raise ValueError('This gate only supports on_each when it is a one qubit gate.') - operations = [] # type: List[raw_types.Operation] - for target in targets: - if isinstance(target, raw_types.Qid): - operations.append(self.on(target)) - elif isinstance(target, Iterable) and not isinstance(target, str): - operations.extend(self.on_each(*target)) - else: - raise ValueError( - f'Gate was called with type different than Qid. Type: {type(target)}' - ) - return operations - - -class SingleQubitGate(SupportsOnEachGate, metaclass=abc.ABCMeta): +class SingleQubitGate(raw_types.Gate, metaclass=abc.ABCMeta): """A gate that must be applied to exactly one qubit.""" def _num_qubits_(self) -> int: diff --git a/cirq-core/cirq/ops/gate_features_test.py b/cirq-core/cirq/ops/gate_features_test.py index ffe074cccdf..d3a4c54806d 100644 --- a/cirq-core/cirq/ops/gate_features_test.py +++ b/cirq-core/cirq/ops/gate_features_test.py @@ -12,12 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Iterator -from typing import Any - import pytest import cirq +from cirq.testing import assert_deprecated def test_single_qubit_gate_validate_args(): @@ -38,26 +36,6 @@ def matrix(self): g.validate_args([q1, q2]) -def test_single_qubit_gate_validates_on_each(): - class Dummy(cirq.SingleQubitGate): - def matrix(self): - pass - - g = Dummy() - assert g.num_qubits() == 1 - - test_qubits = [cirq.NamedQubit(str(i)) for i in range(3)] - - _ = g.on_each(*test_qubits) - _ = g.on_each(test_qubits) - - test_non_qubits = [str(i) for i in range(3)] - with pytest.raises(ValueError): - _ = g.on_each(*test_non_qubits) - with pytest.raises(ValueError): - _ = g.on_each(*test_non_qubits) - - def test_single_qubit_validates_on(): class Dummy(cirq.SingleQubitGate): def matrix(self): @@ -137,41 +115,6 @@ def matrix(self): g.validate_args([a, b, c, d]) -def test_on_each(): - class CustomGate(cirq.SingleQubitGate): - pass - - a = cirq.NamedQubit('a') - b = cirq.NamedQubit('b') - c = CustomGate() - - assert c.on_each() == [] - assert c.on_each(a) == [c(a)] - assert c.on_each(a, b) == [c(a), c(b)] - assert c.on_each(b, a) == [c(b), c(a)] - - assert c.on_each([]) == [] - assert c.on_each([a]) == [c(a)] - assert c.on_each([a, b]) == [c(a), c(b)] - assert c.on_each([b, a]) == [c(b), c(a)] - assert c.on_each([a, [b, a], b]) == [c(a), c(b), c(a), c(b)] - - with pytest.raises(ValueError): - c.on_each('abcd') - with pytest.raises(ValueError): - c.on_each(['abcd']) - with pytest.raises(ValueError): - c.on_each([a, 'abcd']) - - def iterator(qubits): - for i in range(len(qubits)): - yield qubits[i] - - qubit_iterator = iterator([a, b, a, b]) - assert isinstance(qubit_iterator, Iterator) - assert c.on_each(qubit_iterator) == [c(a), c(b), c(a), c(b)] - - def test_qasm_output_args_validate(): args = cirq.QasmArgs(version='2.0') args.validate_version('2.0') @@ -231,16 +174,39 @@ def __init__(self, num_qubits): g.validate_args([a, b, c, d]) -def test_on_each_iterable_qid(): - class QidIter(cirq.Qid): - @property - def dimension(self) -> int: - return 2 +def test_supports_on_each_inheritance_shim(): + class NotOnEach(cirq.Gate): + def num_qubits(self): + return 1 # coverage: ignore + + class OnEach(cirq.ops.gate_features.SupportsOnEachGate): + def num_qubits(self): + return 1 # coverage: ignore + + class SingleQ(cirq.SingleQubitGate): + pass + + class TwoQ(cirq.TwoQubitGate): + pass + + not_on_each = NotOnEach() + single_q = SingleQ() + two_q = TwoQ() + with assert_deprecated(deadline="v0.14"): + on_each = OnEach() + + assert not isinstance(not_on_each, cirq.ops.gate_features.SupportsOnEachGate) + assert isinstance(on_each, cirq.ops.gate_features.SupportsOnEachGate) + assert isinstance(single_q, cirq.ops.gate_features.SupportsOnEachGate) + assert not isinstance(two_q, cirq.ops.gate_features.SupportsOnEachGate) + assert isinstance(cirq.X, cirq.ops.gate_features.SupportsOnEachGate) + assert not isinstance(cirq.CX, cirq.ops.gate_features.SupportsOnEachGate) - def _comparison_key(self) -> Any: - return 1 - def __iter__(self): - raise NotImplementedError() +def test_supports_on_each_deprecation(): + class CustomGate(cirq.ops.gate_features.SupportsOnEachGate): + def num_qubits(self): + return 1 # coverage: ignore - assert cirq.H.on_each(QidIter())[0] == cirq.H.on(QidIter()) + with assert_deprecated(deadline="v0.14"): + assert isinstance(CustomGate(), cirq.ops.gate_features.SupportsOnEachGate) diff --git a/cirq-core/cirq/ops/identity.py b/cirq-core/cirq/ops/identity.py index c705b2fc9aa..d642e3b7fa1 100644 --- a/cirq-core/cirq/ops/identity.py +++ b/cirq-core/cirq/ops/identity.py @@ -20,14 +20,14 @@ from cirq import protocols, value from cirq._doc import document -from cirq.ops import gate_features, raw_types +from cirq.ops import raw_types if TYPE_CHECKING: import cirq @value.value_equality -class IdentityGate(gate_features.SupportsOnEachGate, raw_types.Gate): +class IdentityGate(raw_types.Gate): """A Gate that perform no operation on qubits. The unitary matrix of this gate is a diagonal matrix with all 1s on the diff --git a/cirq-core/cirq/ops/identity_test.py b/cirq-core/cirq/ops/identity_test.py index 2ec9191b741..f9e3ee49e71 100644 --- a/cirq-core/cirq/ops/identity_test.py +++ b/cirq-core/cirq/ops/identity_test.py @@ -66,8 +66,29 @@ def test_identity_on_each_only_single_qubit(): cirq.IdentityGate(1, (3,)).on(q0_3), cirq.IdentityGate(1, (3,)).on(q1_3), ] - with pytest.raises(ValueError, match='one qubit'): - cirq.IdentityGate(num_qubits=2).on_each(q0, q1) + + +def test_identity_on_each_two_qubits(): + q0, q1, q2, q3 = cirq.LineQubit.range(4) + q0_3, q1_3 = q0.with_dimension(3), q1.with_dimension(3) + assert cirq.IdentityGate(2).on_each([(q0, q1)]) == [cirq.IdentityGate(2)(q0, q1)] + assert cirq.IdentityGate(2).on_each([(q0, q1), (q2, q3)]) == [ + cirq.IdentityGate(2)(q0, q1), + cirq.IdentityGate(2)(q2, q3), + ] + assert cirq.IdentityGate(2, (3, 3)).on_each([(q0_3, q1_3)]) == [ + cirq.IdentityGate(2, (3, 3))(q0_3, q1_3), + ] + with pytest.raises(ValueError, match='cannot be in varargs form'): + cirq.IdentityGate(2).on_each(q0, q1) + with pytest.raises(ValueError, match='Inputs to multi-qubit gates must be Sequence'): + cirq.IdentityGate(2).on_each((q0, q1)) + with pytest.raises(ValueError, match='All values in sequence should be Qids'): + cirq.IdentityGate(2).on_each([[(q0, q1)]]) + with pytest.raises(ValueError, match='Expected 2 qubits'): + cirq.IdentityGate(2).on_each([(q0,)]) + with pytest.raises(ValueError, match='Expected 2 qubits'): + cirq.IdentityGate(2).on_each([(q0, q1, q2)]) @pytest.mark.parametrize('num_qubits', [1, 2, 4]) diff --git a/cirq-core/cirq/ops/raw_types.py b/cirq-core/cirq/ops/raw_types.py index 36afcc7624b..25adfa494c7 100644 --- a/cirq-core/cirq/ops/raw_types.py +++ b/cirq-core/cirq/ops/raw_types.py @@ -29,6 +29,8 @@ TypeVar, TYPE_CHECKING, Union, + Iterable, + List, ) import numpy as np @@ -375,6 +377,55 @@ def _rmul_with_qubits(self, qubits: Tuple['cirq.Qid', ...], other): def _json_dict_(self) -> Dict[str, Any]: return protocols.obj_to_dict_helper(self, attribute_names=[]) + def on_each(self, *targets: Union[Qid, Iterable[Any]]) -> List['Operation']: + """Returns a list of operations applying the gate to all targets. + + Args: + *targets: The qubits to apply this gate to. For single-qubit gates + this can be provided as varargs or a combination of nested + iterables. For multi-qubit gates this must be provided as an + `Iterable[Sequence[Qid]]`, where each sequence has `num_qubits` + qubits. + + Returns: + Operations applying this gate to the target qubits. + + Raises: + ValueError if targets are not instances of Qid or Iterable[Qid]. + ValueError if the gate qubit number is incompatible. + """ + operations: List['Operation'] = [] + if self._num_qubits_() > 1: + if len(targets) != 1 or not isinstance(targets[0], Iterable): + raise ValueError(f'The inputs for multi-qubit gates cannot be in varargs form.') + for target in targets[0]: + if not isinstance(target, Sequence): + if isinstance(target, Qid): + raise ValueError( + f'The inputs for multi-qubit gates cannot be in varargs form.' + ) + else: + raise ValueError( + f'Inputs to multi-qubit gates must be Sequence[Qid].' + f' Type: {type(target)}' + ) + if not all(isinstance(x, Qid) for x in target): + raise ValueError(f'All values in sequence should be Qids, but got {target}') + if len(target) != self._num_qubits_(): + raise ValueError(f'Expected {self._num_qubits_()} qubits, got {target}') + operations.append(self.on(*target)) + else: + for target in targets: + if isinstance(target, Qid): + operations.append(self.on(target)) + elif isinstance(target, Iterable) and not isinstance(target, str): + operations.extend(self.on_each(*target)) + else: + raise ValueError( + f'Gate was called with type different than Qid. Type: {type(target)}' + ) + return operations + TSelf = TypeVar('TSelf', bound='Operation') diff --git a/cirq-core/cirq/ops/raw_types_test.py b/cirq-core/cirq/ops/raw_types_test.py index afb85664751..03248af92a1 100644 --- a/cirq-core/cirq/ops/raw_types_test.py +++ b/cirq-core/cirq/ops/raw_types_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import AbstractSet +from typing import AbstractSet, Iterator, Any import pytest import numpy as np @@ -739,3 +739,192 @@ def qubits(self): cirq.act_on(NoActOn()(q).with_tags("test"), args) with pytest.raises(TypeError, match="Failed to act"): cirq.act_on(MissingActOn().with_tags("test"), args) + + +def test_single_qubit_gate_validates_on_each(): + class Dummy(cirq.SingleQubitGate): + def matrix(self): + pass + + g = Dummy() + assert g.num_qubits() == 1 + + test_qubits = [cirq.NamedQubit(str(i)) for i in range(3)] + + _ = g.on_each(*test_qubits) + _ = g.on_each(test_qubits) + + test_non_qubits = [str(i) for i in range(3)] + with pytest.raises(ValueError): + _ = g.on_each(*test_non_qubits) + with pytest.raises(ValueError): + _ = g.on_each(*test_non_qubits) + + +def test_on_each(): + class CustomGate(cirq.SingleQubitGate): + pass + + a = cirq.NamedQubit('a') + b = cirq.NamedQubit('b') + c = CustomGate() + + assert c.on_each() == [] + assert c.on_each(a) == [c(a)] + assert c.on_each(a, b) == [c(a), c(b)] + assert c.on_each(b, a) == [c(b), c(a)] + + assert c.on_each([]) == [] + assert c.on_each([a]) == [c(a)] + assert c.on_each([a, b]) == [c(a), c(b)] + assert c.on_each([b, a]) == [c(b), c(a)] + assert c.on_each([a, [b, a], b]) == [c(a), c(b), c(a), c(b)] + + with pytest.raises(ValueError): + c.on_each('abcd') + with pytest.raises(ValueError): + c.on_each(['abcd']) + with pytest.raises(ValueError): + c.on_each([a, 'abcd']) + + def iterator(qubits): + for i in range(len(qubits)): + yield qubits[i] + + qubit_iterator = iterator([a, b, a, b]) + assert isinstance(qubit_iterator, Iterator) + assert c.on_each(qubit_iterator) == [c(a), c(b), c(a), c(b)] + + +def test_on_each_two_qubits(): + class CustomGate(cirq.TwoQubitGate): + pass + + a = cirq.NamedQubit('a') + b = cirq.NamedQubit('b') + g = CustomGate() + + assert g.on_each([]) == [] + assert g.on_each([(a, b)]) == [g(a, b)] + assert g.on_each([[a, b]]) == [g(a, b)] + assert g.on_each([(b, a)]) == [g(b, a)] + assert g.on_each([(a, b), (b, a)]) == [g(a, b), g(b, a)] + assert g.on_each(zip([a, b], [b, a])) == [g(a, b), g(b, a)] + with pytest.raises(ValueError, match='cannot be in varargs form'): + g.on_each() + with pytest.raises(ValueError, match='cannot be in varargs form'): + g.on_each(a, b) + with pytest.raises(ValueError, match='Inputs to multi-qubit gates must be Sequence'): + g.on_each((b, a)) + with pytest.raises(ValueError, match='cannot be in varargs form'): + g.on_each((a, b), (a, b)) + with pytest.raises(ValueError, match='cannot be in varargs form'): + g.on_each(*zip([a, b], [b, a])) + with pytest.raises(ValueError, match='Inputs to multi-qubit gates must be Sequence'): + g.on_each([12]) + with pytest.raises(ValueError, match='Inputs to multi-qubit gates must be Sequence'): + g.on_each([(a, b), 12]) + with pytest.raises(ValueError, match='All values in sequence should be Qids'): + g.on_each([(a, b), [(a, b)]]) + with pytest.raises(ValueError, match='Expected 2 qubits'): + g.on_each([()]) + with pytest.raises(ValueError, match='Expected 2 qubits'): + g.on_each([(a,)]) + with pytest.raises(ValueError, match='Expected 2 qubits'): + g.on_each([(a, b, a)]) + with pytest.raises(ValueError, match='Expected 2 qubits'): + g.on_each(zip([a, a])) + with pytest.raises(ValueError, match='Expected 2 qubits'): + g.on_each(zip([a, a], [b, b], [a, a])) + with pytest.raises(ValueError, match='All values in sequence should be Qids'): + g.on_each('ab') + with pytest.raises(ValueError, match='All values in sequence should be Qids'): + g.on_each(('ab',)) + with pytest.raises(ValueError, match='All values in sequence should be Qids'): + g.on_each([('ab',)]) + with pytest.raises(ValueError, match='All values in sequence should be Qids'): + g.on_each([(a, 'ab')]) + with pytest.raises(ValueError, match='All values in sequence should be Qids'): + g.on_each([(a, 'b')]) + + def iterator(qubits): + for i in range(len(qubits)): + yield qubits[i] + + qubit_iterator = iterator([[a, b], [a, b]]) + assert isinstance(qubit_iterator, Iterator) + assert g.on_each(qubit_iterator) == [g(a, b), g(a, b)] + + +def test_on_each_three_qubits(): + class CustomGate(cirq.ThreeQubitGate): + pass + + a = cirq.NamedQubit('a') + b = cirq.NamedQubit('b') + c = cirq.NamedQubit('c') + g = CustomGate() + + assert g.on_each([]) == [] + assert g.on_each([(a, b, c)]) == [g(a, b, c)] + assert g.on_each([[a, b, c]]) == [g(a, b, c)] + assert g.on_each([(c, b, a)]) == [g(c, b, a)] + assert g.on_each([(a, b, c), (c, b, a)]) == [g(a, b, c), g(c, b, a)] + assert g.on_each(zip([a, c], [b, b], [c, a])) == [g(a, b, c), g(c, b, a)] + with pytest.raises(ValueError, match='cannot be in varargs form'): + g.on_each() + with pytest.raises(ValueError, match='cannot be in varargs form'): + g.on_each(a, b, c) + with pytest.raises(ValueError, match='Inputs to multi-qubit gates must be Sequence'): + g.on_each((c, b, a)) + with pytest.raises(ValueError, match='cannot be in varargs form'): + g.on_each((a, b, c), (a, b, c)) + with pytest.raises(ValueError, match='cannot be in varargs form'): + g.on_each(*zip([a, c], [b, b], [c, a])) + with pytest.raises(ValueError, match='Inputs to multi-qubit gates must be Sequence'): + g.on_each([12]) + with pytest.raises(ValueError, match='Inputs to multi-qubit gates must be Sequence'): + g.on_each([(a, b, c), 12]) + with pytest.raises(ValueError, match='All values in sequence should be Qids'): + g.on_each([(a, b, c), [(a, b, c)]]) + with pytest.raises(ValueError, match='Expected 3 qubits'): + g.on_each([(a,)]) + with pytest.raises(ValueError, match='Expected 3 qubits'): + g.on_each([(a, b)]) + with pytest.raises(ValueError, match='Expected 3 qubits'): + g.on_each([(a, b, c, a)]) + with pytest.raises(ValueError, match='Expected 3 qubits'): + g.on_each(zip([a, a], [b, b])) + with pytest.raises(ValueError, match='All values in sequence should be Qids'): + g.on_each('abc') + with pytest.raises(ValueError, match='All values in sequence should be Qids'): + g.on_each(('abc',)) + with pytest.raises(ValueError, match='All values in sequence should be Qids'): + g.on_each([('abc',)]) + with pytest.raises(ValueError, match='All values in sequence should be Qids'): + g.on_each([(a, 'abc')]) + with pytest.raises(ValueError, match='All values in sequence should be Qids'): + g.on_each([(a, 'bc')]) + + def iterator(qubits): + for i in range(len(qubits)): + yield qubits[i] + + qubit_iterator = iterator([[a, b, c], [a, b, c]]) + assert isinstance(qubit_iterator, Iterator) + assert g.on_each(qubit_iterator) == [g(a, b, c), g(a, b, c)] + + +def test_on_each_iterable_qid(): + class QidIter(cirq.Qid): + @property + def dimension(self) -> int: + return 2 + + def _comparison_key(self) -> Any: + return 1 + + def __iter__(self): + raise NotImplementedError() + + assert cirq.H.on_each(QidIter())[0] == cirq.H.on(QidIter())