diff --git a/cirq/contrib/acquaintance/permutation.py b/cirq/contrib/acquaintance/permutation.py index 4621e563cb0..7165529adba 100644 --- a/cirq/contrib/acquaintance/permutation.py +++ b/cirq/contrib/acquaintance/permutation.py @@ -13,8 +13,8 @@ # limitations under the License. import abc -from typing import (Any, cast, Dict, Iterable, Sequence, Tuple, TYPE_CHECKING, - TypeVar, Union, TYPE_CHECKING) +from typing import (Any, Dict, Iterable, Sequence, Tuple, TypeVar, Union, + TYPE_CHECKING) from cirq import circuits, ops, optimizers, protocols, value from cirq.type_workarounds import NotImplementedType @@ -230,9 +230,11 @@ def update_mapping(mapping: Dict[ops.Qid, LogicalIndex], operations: The operations to update according to. """ for op in ops.flatten_op_tree(operations): - if (isinstance(op, ops.GateOperation) and - isinstance(op.gate, PermutationGate)): - op.gate.update_mapping(mapping, op.qubits) + # Type check false positive (https://github.com/python/mypy/issues/5374) + gate = ops.op_gate_of_type(op, PermutationGate) # type: ignore + if gate is not None: + # Ignoring type warning about op.qubits not being a tuple. + gate.update_mapping(mapping, op.qubits) # type: ignore def get_logical_operations(operations: 'cirq.OP_TREE', @@ -250,10 +252,11 @@ def get_logical_operations(operations: 'cirq.OP_TREE', qubit. """ mapping = initial_mapping.copy() - for op in cast(Iterable['cirq.Operation'], ops.flatten_op_tree(operations)): - if (isinstance(op, ops.GateOperation) and - isinstance(op.gate, PermutationGate)): - op.gate.update_mapping(mapping, op.qubits) + for op in ops.flatten_to_ops(operations): + # Type check false positive (https://github.com/python/mypy/issues/5374) + gate = ops.op_gate_of_type(op, PermutationGate) # type: ignore + if gate is not None: + gate.update_mapping(mapping, op.qubits) else: for q in op.qubits: if mapping.get(q) is None: diff --git a/cirq/contrib/routing/router_test.py b/cirq/contrib/routing/router_test.py index 5e27e67ad87..93b32f98c13 100644 --- a/cirq/contrib/routing/router_test.py +++ b/cirq/contrib/routing/router_test.py @@ -155,3 +155,36 @@ def test_router_bad_args(): cirq.CZ(cirq.LineQubit(i), cirq.LineQubit(i + 1)) for i in range(5)) with pytest.raises(ValueError): ccr.route_circuit(circuit, device_graph, algo_name=algo_name) + + +def test_fail_when_operating_on_unmapped_qubits(): + a, b, t = cirq.LineQubit.range(3) + swap = cirq.contrib.acquaintance.SwapPermutationGate() + + # Works before swap. + swap_network = cirq.contrib.routing.SwapNetwork(cirq.Circuit(cirq.CZ(a, t)), + { + a: a, + b: b + }) + with pytest.raises(ValueError, match='acts on unmapped qubit'): + _ = list(swap_network.get_logical_operations()) + + # Works after swap. + swap_network = cirq.contrib.routing.SwapNetwork( + cirq.Circuit(swap(b, t), cirq.CZ(a, b)), { + a: a, + b: b + }) + with pytest.raises(ValueError, match='acts on unmapped qubit'): + _ = list(swap_network.get_logical_operations()) + + # This test case used to cause a CZ(a, a) to be created due to unmapped + # qubits being mapped to stale values. + swap_network = cirq.contrib.routing.SwapNetwork( + cirq.Circuit(swap(a, t), swap(b, t), cirq.CZ(a, b)), { + a: a, + b: b + }) + with pytest.raises(ValueError, match='acts on unmapped qubit'): + _ = list(swap_network.get_logical_operations()) diff --git a/cirq/contrib/routing/swap_network.py b/cirq/contrib/routing/swap_network.py index 178be1987ca..0b458d726f1 100644 --- a/cirq/contrib/routing/swap_network.py +++ b/cirq/contrib/routing/swap_network.py @@ -60,6 +60,10 @@ def __eq__(self, other) -> bool: return (self.circuit == other.circuit and self.initial_mapping == other.initial_mapping) + def __repr__(self): + return 'cirq.contrib.routing.SwapNetwork({!r}, {!r})'.format( + self.circuit, self.initial_mapping) + @property def device(self) -> 'cirq.Device': return self.circuit.device diff --git a/cirq/contrib/routing/swap_network_test.py b/cirq/contrib/routing/swap_network_test.py index 39715a0f9db..6c104a2c1d0 100644 --- a/cirq/contrib/routing/swap_network_test.py +++ b/cirq/contrib/routing/swap_network_test.py @@ -59,6 +59,15 @@ def test_swap_network_equality(circuits): et.add_equality_group(ccr.SwapNetwork(circuit, mapping)) +def test_repr(): + a, b = cirq.LineQubit.range(2) + cirq.testing.assert_equivalent_repr( + cirq.contrib.routing.SwapNetwork(cirq.Circuit(cirq.CZ(a, b)), { + a: a, + b: b + })) + + def test_swap_network_str(): n_qubits = 5 phys_qubits = cirq.GridQubit.rect(n_qubits, 1)