diff --git a/cirq_qubitization/__init__.py b/cirq_qubitization/__init__.py index 97fcab401f..a08ef54449 100644 --- a/cirq_qubitization/__init__.py +++ b/cirq_qubitization/__init__.py @@ -35,6 +35,7 @@ qfree, Register, Registers, + SelectionRegisters, SimpleQubitManager, ) from cirq_qubitization.generic_select import GenericSelect diff --git a/cirq_qubitization/cirq_algos/apply_gate_to_lth_target.ipynb b/cirq_qubitization/cirq_algos/apply_gate_to_lth_target.ipynb index 8206a82797..74b2e293ad 100644 --- a/cirq_qubitization/cirq_algos/apply_gate_to_lth_target.ipynb +++ b/cirq_qubitization/cirq_algos/apply_gate_to_lth_target.ipynb @@ -47,10 +47,9 @@ "`selection`-th qubit of `target` all controlled by the `control` register.\n", "\n", "#### Parameters\n", - " - `selection_bitsize`: The size of the indexing `select` register. This should be at most `log2(target_bitsize)`\n", - " - `target_bitsize`: The size of the `target` register. This also serves as the iteration length.\n", - " - `nth_gate`: A function mapping the selection index to a single-qubit gate.\n", - " - `control_bitsize`: The size of the control register.\n" + " - `selection_regs`: Indexing `select` registers of type `SelectionRegisters`. It also contains information about the iteration length of each selection register.\n", + " - `nth_gate`: A function mapping the composite selection index to a single-qubit gate.\n", + " - `control_regs`: Control registers for constructing a controlled version of the gate.\n" ] }, { @@ -63,6 +62,7 @@ "outputs": [], "source": [ "from cirq_qubitization.cirq_algos.apply_gate_to_lth_target import ApplyGateToLthQubit\n", + "from cirq_qubitization.cirq_infra.gate_with_registers import Registers, SelectionRegisters\n", "\n", "def _z_to_odd(n: int):\n", " if n % 2 == 1:\n", @@ -70,7 +70,9 @@ " return cirq.I\n", "\n", "apply_z_to_odd = ApplyGateToLthQubit(\n", - " selection_bitsize=3, target_bitsize=4, nth_gate=_z_to_odd, control_bitsize=2\n", + " SelectionRegisters.build(selection=(3, 4)),\n", + " nth_gate=_z_to_odd,\n", + " control_regs=Registers.build(control=2),\n", ")\n", "\n", "g = cq_testing.GateHelper(\n", diff --git a/cirq_qubitization/cirq_algos/apply_gate_to_lth_target.py b/cirq_qubitization/cirq_algos/apply_gate_to_lth_target.py index a9b9a20cce..c3eb9c0259 100644 --- a/cirq_qubitization/cirq_algos/apply_gate_to_lth_target.py +++ b/cirq_qubitization/cirq_algos/apply_gate_to_lth_target.py @@ -1,12 +1,15 @@ +import itertools from functools import cached_property from typing import Callable, Sequence, Tuple import cirq +from attrs import frozen from cirq_qubitization.cirq_algos.unary_iteration import UnaryIterationGate -from cirq_qubitization.cirq_infra.gate_with_registers import Registers +from cirq_qubitization.cirq_infra.gate_with_registers import Registers, SelectionRegisters +@frozen class ApplyGateToLthQubit(UnaryIterationGate): r"""A controlled SELECT operation for single-qubit gates. @@ -20,62 +23,52 @@ class ApplyGateToLthQubit(UnaryIterationGate): `selection`-th qubit of `target` all controlled by the `control` register. Args: - selection_bitsize: The size of the indexing `select` register. This should be at most - `log2(target_bitsize)` - target_bitsize: The size of the `target` register. This also serves as the iteration - length. - nth_gate: A function mapping the selection index to a single-qubit gate. - control_bitsize: The size of the control register. + selection_regs: Indexing `select` registers of type `SelectionRegisters`. It also contains + information about the iteration length of each selection register. + nth_gate: A function mapping the composite selection index to a single-qubit gate. + control_regs: Control registers for constructing a controlled version of the gate. """ - - def __init__( - self, - selection_bitsize: int, - target_bitsize: int, - nth_gate: Callable[[int], cirq.Gate], - *, - control_bitsize: int = 1, - ): - self._selection_bitsize = selection_bitsize - self._target_bitsize = target_bitsize - self._nth_gate = nth_gate - self._control_bitsize = control_bitsize + selection_regs: SelectionRegisters + nth_gate: Callable[[int, ...], cirq.Gate] + control_regs: Registers = Registers.build(control=1) @classmethod def make_on( - cls, *, nth_gate: Callable[[int], cirq.Gate], **quregs: Sequence[cirq.Qid] + cls, *, nth_gate: Callable[[int, ...], cirq.Gate], **quregs: Sequence[cirq.Qid] ) -> cirq.Operation: """Helper constructor to automatically deduce bitsize attributes.""" return cls( - selection_bitsize=len(quregs['selection']), - target_bitsize=len(quregs['target']), + SelectionRegisters.build(selection=(len(quregs['selection']), len(quregs['target']))), nth_gate=nth_gate, - control_bitsize=len(quregs['control']), + control_regs=Registers.build(control=len(quregs['control'])), ).on_registers(**quregs) @cached_property def control_registers(self) -> Registers: - return Registers.build(control=self._control_bitsize) + return self.control_regs @cached_property - def selection_registers(self) -> Registers: - return Registers.build(selection=self._selection_bitsize) + def selection_registers(self) -> SelectionRegisters: + return self.selection_regs @cached_property def target_registers(self) -> Registers: - return Registers.build(target=self._target_bitsize) + return Registers.build(target=self.selection_registers.total_iteration_size) @cached_property def iteration_lengths(self) -> Tuple[int, ...]: - return (self._target_bitsize,) + return self.selection_registers.iteration_lengths def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo: wire_symbols = ["@"] * self.control_registers.bitsize wire_symbols += ["In"] * self.selection_registers.bitsize - wire_symbols += [str(self._nth_gate(i)) for i in range(self._target_bitsize)] + for it in itertools.product(*[range(x) for x in self.selection_regs.iteration_lengths]): + wire_symbols += [str(self.nth_gate(*it))] return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols) def nth_operation( - self, selection: int, control: cirq.Qid, target: Sequence[cirq.Qid] + self, control: cirq.Qid, target: Sequence[cirq.Qid], **selection_indices: int ) -> cirq.OP_TREE: - return self._nth_gate(selection).on(target[-(selection + 1)]).controlled_by(control) + selection_idx = tuple(selection_indices[reg.name] for reg in self.selection_regs) + target_idx = self.selection_registers.to_flat_idx(*selection_idx) + return self.nth_gate(*selection_idx).on(target[target_idx]).controlled_by(control) diff --git a/cirq_qubitization/cirq_algos/apply_gate_to_lth_target_test.py b/cirq_qubitization/cirq_algos/apply_gate_to_lth_target_test.py index 6ef802d387..32aee19c62 100644 --- a/cirq_qubitization/cirq_algos/apply_gate_to_lth_target_test.py +++ b/cirq_qubitization/cirq_algos/apply_gate_to_lth_target_test.py @@ -11,7 +11,10 @@ def test_apply_gate_to_lth_qubit(selection_bitsize, target_bitsize): greedy_mm = cq.cirq_infra.GreedyQubitManager(prefix="_a", maximize_reuse=True) with cq.cirq_infra.memory_management_context(greedy_mm): - gate = cq.ApplyGateToLthQubit(selection_bitsize, target_bitsize, lambda _: cirq.X) + gate = cq.ApplyGateToLthQubit( + cq.SelectionRegisters.build(selection=(selection_bitsize, target_bitsize)), + lambda _: cirq.X, + ) g = cq_testing.GateHelper(gate) # Upper bounded because not all ancillas may be used as part of unary iteration. assert ( @@ -28,7 +31,7 @@ def test_apply_gate_to_lth_qubit(selection_bitsize, target_bitsize): qubit_vals |= zip(g.quregs['selection'], iter_bits(n, selection_bitsize)) initial_state = [qubit_vals[x] for x in g.all_qubits] - qubit_vals[g.quregs['target'][-(n + 1)]] = 1 + qubit_vals[g.quregs['target'][n]] = 1 final_state = [qubit_vals[x] for x in g.all_qubits] cq_testing.assert_circuit_inp_out_cirqsim( g.decomposed_circuit, g.all_qubits, initial_state, final_state @@ -37,7 +40,11 @@ def test_apply_gate_to_lth_qubit(selection_bitsize, target_bitsize): def test_apply_gate_to_lth_qubit_diagram(): # Apply Z gate to all odd targets and Identity to even targets. - gate = cq.ApplyGateToLthQubit(3, 5, lambda n: cirq.Z if n & 1 else cirq.I, control_bitsize=2) + gate = cq.ApplyGateToLthQubit( + cq.SelectionRegisters.build(selection=(3, 5)), + lambda n: cirq.Z if n & 1 else cirq.I, + control_regs=cq.Registers.build(control=2), + ) circuit = cirq.Circuit(gate.on_registers(**gate.registers.get_named_qubits())) qubits = list(q for v in gate.registers.get_named_qubits().values() for q in v) cirq.testing.assert_has_diagram( @@ -68,14 +75,19 @@ def test_apply_gate_to_lth_qubit_diagram(): def test_apply_gate_to_lth_qubit_make_on(): - gate = cq.ApplyGateToLthQubit(3, 5, lambda n: cirq.Z if n & 1 else cirq.I, control_bitsize=2) + gate = cq.ApplyGateToLthQubit( + cq.SelectionRegisters.build(selection=(3, 5)), + lambda n: cirq.Z if n & 1 else cirq.I, + control_regs=cq.Registers.build(control=2), + ) op = gate.on_registers(**gate.registers.get_named_qubits()) op2 = cq.ApplyGateToLthQubit.make_on( nth_gate=lambda n: cirq.Z if n & 1 else cirq.I, **gate.registers.get_named_qubits() ) # Note: ApplyGateToLthQubit doesn't support value equality. assert op.qubits == op2.qubits - assert op.gate._selection_bitsize == op2.gate._selection_bitsize + assert op.gate.selection_regs == op2.gate.selection_regs + assert op.gate.control_regs == op2.gate.control_regs def test_notebook(): diff --git a/cirq_qubitization/cirq_algos/selected_majorana_fermion.py b/cirq_qubitization/cirq_algos/selected_majorana_fermion.py index 61440cf0a3..449c3cb81c 100644 --- a/cirq_qubitization/cirq_algos/selected_majorana_fermion.py +++ b/cirq_qubitization/cirq_algos/selected_majorana_fermion.py @@ -2,81 +2,87 @@ from typing import Sequence, Tuple import cirq +from attrs import frozen +from cirq_qubitization import cirq_infra from cirq_qubitization.cirq_algos import unary_iteration -from cirq_qubitization.cirq_infra.gate_with_registers import Registers +from cirq_qubitization.cirq_infra.gate_with_registers import Registers, SelectionRegisters -@cirq.value_equality() +@frozen class SelectedMajoranaFermionGate(unary_iteration.UnaryIterationGate): """Implements U s.t. U|l>|Psi> -> |l> T_{l} . Z_{l - 1} ... Z_{0} |Psi> where T = single qubit target gate. Defaults to pauli Y. - Uses: - * 1 Control qubit. - * 1 Accumulator qubit. - * `selection_bitsize` number of selection qubits. - * `target_bitsize` number of target qubits. - See Fig 9 of https://arxiv.org/abs/1805.03662 for more details. + Args: + selection_regs: Indexing `select` registers of type `SelectionRegisters`. It also contains + information about the iteration length of each selection register. + control_regs: Control registers for constructing a controlled version of the gate. + target_gate: Single qubit gate to be applied to the target qubits. + + References: + See Fig 9 of https://arxiv.org/abs/1805.03662 for more details. """ - def __init__(self, selection_bitsize: int, target_bitsize: int, target_gate=cirq.Y): - self._selection_bitsize = selection_bitsize - self._target_bitsize = target_bitsize - self._target_gate = target_gate + selection_regs: SelectionRegisters + control_regs: Registers = Registers.build(control=1) + target_gate: cirq.Gate = cirq.Y @classmethod def make_on(cls, *, target_gate=cirq.Y, **quregs: Sequence[cirq.Qid]) -> cirq.Operation: - """Helper constructor to automatically deduce bitsize attributes.""" + """Helper constructor to automatically deduce selection_regs attribute.""" return cls( - selection_bitsize=len(quregs['selection']), - target_bitsize=len(quregs['target']), + selection_regs=SelectionRegisters.build( + selection=(len(quregs['selection']), len(quregs['target'])) + ), target_gate=target_gate, ).on_registers(**quregs) - def _value_equality_values_(self): - return self._selection_bitsize, self._target_bitsize, self._target_gate - @cached_property def control_registers(self) -> Registers: - return Registers.build(control=1) + return self.control_regs @cached_property - def selection_registers(self) -> Registers: - return Registers.build(selection=self._selection_bitsize) + def selection_registers(self) -> SelectionRegisters: + return self.selection_regs @cached_property def target_registers(self) -> Registers: - return Registers.build(target=self._target_bitsize) + return Registers.build(target=self.selection_regs.total_iteration_size) @cached_property def iteration_lengths(self) -> Tuple[int, ...]: - return (self._target_bitsize,) + return self.selection_registers.iteration_lengths @cached_property def extra_registers(self) -> Registers: return Registers.build(accumulator=1) def decompose_from_registers(self, **qubit_regs: Sequence[cirq.Qid]) -> cirq.OP_TREE: - yield cirq.CNOT(*qubit_regs['control'], *qubit_regs['accumulator']) + qubit_regs['accumulator'] = cirq_infra.qalloc(1) + yield cirq.X(*qubit_regs['accumulator']).controlled_by( + *qubit_regs[self.control_regs[0].name] + ) yield super().decompose_from_registers(**qubit_regs) + cirq_infra.qfree(qubit_regs['accumulator']) def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo: wire_symbols = ["@"] * self.control_registers.bitsize wire_symbols += ["In"] * self.selection_registers.bitsize - wire_symbols += [f"Z{self._target_gate}"] * self.target_registers.bitsize - wire_symbols += ["Acc"] + wire_symbols += [f"Z{self.target_gate}"] * self.target_registers.bitsize return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols) def nth_operation( self, - selection: int, control: cirq.Qid, target: Sequence[cirq.Qid], accumulator: Sequence[cirq.Qid], + **selection_indices: int, ) -> cirq.OP_TREE: - yield cirq.CNOT(control, accumulator[0]) - yield self._target_gate(target[selection]).controlled_by(control) - yield cirq.Z(target[selection]).controlled_by(accumulator[0]) + selection_idx = tuple(selection_indices[reg.name] for reg in self.selection_regs) + target_idx = self.selection_registers.to_flat_idx(*selection_idx) + yield cirq.CNOT(control, *accumulator) + yield self.target_gate(target[target_idx]).controlled_by(control) + yield cirq.CZ(*accumulator, target[target_idx]) diff --git a/cirq_qubitization/cirq_algos/selected_majorana_fermion_test.py b/cirq_qubitization/cirq_algos/selected_majorana_fermion_test.py index dc1434e483..df1ab78ccf 100644 --- a/cirq_qubitization/cirq_algos/selected_majorana_fermion_test.py +++ b/cirq_qubitization/cirq_algos/selected_majorana_fermion_test.py @@ -13,10 +13,11 @@ def test_selected_majorana_fermion_gate(selection_bitsize, target_bitsize, targe greedy_mm = cq.cirq_infra.GreedyQubitManager(prefix="_a", maximize_reuse=True) with cq.cirq_infra.memory_management_context(greedy_mm): gate = cq.SelectedMajoranaFermionGate( - selection_bitsize, target_bitsize, target_gate=target_gate + cq.SelectionRegisters.build(selection=(selection_bitsize, target_bitsize)), + target_gate=target_gate, ) g = cq_testing.GateHelper(gate) - assert len(g.all_qubits) <= gate.registers.bitsize + selection_bitsize + assert len(g.all_qubits) <= gate.registers.bitsize + selection_bitsize + 1 sim = cirq.Simulator(dtype=np.complex128) for n in range(target_bitsize): @@ -51,39 +52,43 @@ def test_selected_majorana_fermion_gate(selection_bitsize, target_bitsize, targe def test_selected_majorana_fermion_gate_diagram(): selection_bitsize, target_bitsize = 3, 5 - gate = cq.SelectedMajoranaFermionGate(selection_bitsize, target_bitsize, target_gate=cirq.X) + gate = cq.SelectedMajoranaFermionGate( + cq.SelectionRegisters.build(selection=(selection_bitsize, target_bitsize)), + target_gate=cirq.X, + ) circuit = cirq.Circuit(gate.on_registers(**gate.registers.get_named_qubits())) qubits = list(q for v in gate.registers.get_named_qubits().values() for q in v) cirq.testing.assert_has_diagram( circuit, """ -control: ───────@───── - │ -selection0: ────In──── - │ -selection1: ────In──── - │ -selection2: ────In──── - │ -target0: ───────ZX──── - │ -target1: ───────ZX──── - │ -target2: ───────ZX──── - │ -target3: ───────ZX──── - │ -target4: ───────ZX──── - │ -accumulator: ───Acc─── - """, +control: ──────@──── + │ +selection0: ───In─── + │ +selection1: ───In─── + │ +selection2: ───In─── + │ +target0: ──────ZX─── + │ +target1: ──────ZX─── + │ +target2: ──────ZX─── + │ +target3: ──────ZX─── + │ +target4: ──────ZX─── +""", qubit_order=qubits, ) def test_selected_majorana_fermion_gate_decomposed_diagram(): selection_bitsize, target_bitsize = 2, 3 - gate = cq.SelectedMajoranaFermionGate(selection_bitsize, target_bitsize, target_gate=cirq.X) + gate = cq.SelectedMajoranaFermionGate( + cq.SelectionRegisters.build(selection=(selection_bitsize, target_bitsize)), + target_gate=cirq.X, + ) greedy_mm = cq.cirq_infra.GreedyQubitManager(prefix="_a", maximize_reuse=True) with cq.cirq_infra.memory_management_context(greedy_mm): g = cq_testing.GateHelper(gate) @@ -91,38 +96,40 @@ def test_selected_majorana_fermion_gate_decomposed_diagram(): ancillas = sorted(set(circuit.all_qubits()) - set(g.operation.qubits)) qubits = ( g.quregs['control'] - + [q for qs in zip(g.quregs['selection'], ancillas) for q in qs] - + g.quregs['accumulator'] + + [q for qs in zip(g.quregs['selection'], ancillas[1:]) for q in qs] + + ancillas[0:1] + g.quregs['target'] ) cirq.testing.assert_has_diagram( circuit, """ -control: ───────@───@──────────────────────────────────────@───────────@────── - │ │ │ │ -selection0: ────┼───(0)────────────────────────────────────┼───────────@────── - │ │ │ │ -_a_0: ──────────┼───And───@─────────────@───────────@──────X───@───@───And†─── - │ │ │ │ │ │ -selection1: ────┼─────────(0)───────────┼───────────@──────────┼───┼────────── - │ │ │ │ │ │ -_a_1: ──────────┼─────────And───@───@───X───@───@───And†───────┼───┼────────── - │ │ │ │ │ │ │ -accumulator: ───X───────────────X───┼───@───X───┼───@──────────X───┼───@────── - │ │ │ │ │ │ -target0: ───────────────────────────X───@───────┼───┼──────────────┼───┼────── - │ │ │ │ -target1: ───────────────────────────────────────X───@──────────────┼───┼────── - │ │ -target2: ──────────────────────────────────────────────────────────X───@────── - """, +control: ──────@───@──────────────────────────────────────@───────────@────── + │ │ │ │ +selection0: ───┼───(0)────────────────────────────────────┼───────────@────── + │ │ │ │ +_a_1: ─────────┼───And───@─────────────@───────────@──────X───@───@───And†─── + │ │ │ │ │ │ +selection1: ───┼─────────(0)───────────┼───────────@──────────┼───┼────────── + │ │ │ │ │ │ +_a_2: ─────────┼─────────And───@───@───X───@───@───And†───────┼───┼────────── + │ │ │ │ │ │ │ +_a_0: ─────────X───────────────X───┼───@───X───┼───@──────────X───┼───@────── + │ │ │ │ │ │ +target0: ──────────────────────────X───@───────┼───┼──────────────┼───┼────── + │ │ │ │ +target1: ──────────────────────────────────────X───@──────────────┼───┼────── + │ │ +target2: ─────────────────────────────────────────────────────────X───@────── """, qubit_order=qubits, ) def test_selected_majorana_fermion_gate_make_on(): selection_bitsize, target_bitsize = 3, 5 - gate = cq.SelectedMajoranaFermionGate(selection_bitsize, target_bitsize, target_gate=cirq.X) + gate = cq.SelectedMajoranaFermionGate( + cq.SelectionRegisters.build(selection=(selection_bitsize, target_bitsize)), + target_gate=cirq.X, + ) op = gate.on_registers(**gate.registers.get_named_qubits()) op2 = cq.SelectedMajoranaFermionGate.make_on( target_gate=cirq.X, **gate.registers.get_named_qubits() diff --git a/cirq_qubitization/cirq_algos/unary_iteration.py b/cirq_qubitization/cirq_algos/unary_iteration.py index 81c1e7adcf..bae49763c7 100644 --- a/cirq_qubitization/cirq_algos/unary_iteration.py +++ b/cirq_qubitization/cirq_algos/unary_iteration.py @@ -175,12 +175,7 @@ def iteration_lengths(self) -> Tuple[int, ...]: @cached_property def registers(self) -> cirq_infra.Registers: return cirq_infra.Registers( - [ - *self.control_registers, - *self.selection_registers, - *self.target_registers, - *self.extra_registers, - ] + [*self.control_registers, *self.selection_registers, *self.target_registers] ) @cached_property diff --git a/cirq_qubitization/cirq_infra/__init__.py b/cirq_qubitization/cirq_infra/__init__.py index d2493e59c3..486e0ce890 100644 --- a/cirq_qubitization/cirq_infra/__init__.py +++ b/cirq_qubitization/cirq_infra/__init__.py @@ -1,5 +1,10 @@ from cirq_qubitization.cirq_infra.decompose_protocol import decompose_once_into_operations -from cirq_qubitization.cirq_infra.gate_with_registers import GateWithRegisters, Register, Registers +from cirq_qubitization.cirq_infra.gate_with_registers import ( + GateWithRegisters, + Register, + Registers, + SelectionRegisters, +) from cirq_qubitization.cirq_infra.qid_types import BorrowableQubit, CleanQubit from cirq_qubitization.cirq_infra.qubit_management_transformers import ( map_clean_and_borrowable_qubits, diff --git a/cirq_qubitization/cirq_infra/gate_with_registers.py b/cirq_qubitization/cirq_infra/gate_with_registers.py index 08c7c1e461..4fb79e3a6c 100644 --- a/cirq_qubitization/cirq_infra/gate_with_registers.py +++ b/cirq_qubitization/cirq_infra/gate_with_registers.py @@ -1,14 +1,15 @@ import abc -import dataclasses import sys -from typing import Dict, Iterable, List, overload, Sequence, Union +from typing import Dict, Iterable, List, overload, Sequence, Tuple, Union +import attrs import cirq +import numpy as np assert sys.version_info > (3, 6), "https://docs.python.org/3/whatsnew/3.6.html#whatsnew36-pep468" -@dataclasses.dataclass(frozen=True) +@attrs.frozen class Register: name: str bitsize: int @@ -112,6 +113,74 @@ def __eq__(self, other) -> bool: return self._registers == other._registers +@attrs.frozen +class SelectionRegister(Register): + iteration_length: int = attrs.field() + + @iteration_length.validator + def validate_iteration_length(self, attribute, value): + if not (0 <= value <= 2**self.bitsize): + raise ValueError(f'iteration length must be in range [0, 2^{self.bitsize}]') + + +class SelectionRegisters(Registers): + """Registers used to represent SELECT registers for various LCU methods. + + LCU methods often make use of coherent for-loops via UnaryIteration, iterating over a range + of values stored as a superposition over the `SELECT` register. The `SelectionRegisters` class + is used to represent such SELECT registers. In particular, it provides two additional features + on top of the regular `Registers` class: + + - For each selection register, we store the iteration length corresponding to that register + along with its size. + - We provide a default way of "flattening out" a composite index represented by a tuple of + values stored in multiple input selection registers to a single integer that can be used + to index a flat target register. + """ + + def __init__(self, registers: Iterable[SelectionRegister]): + super().__init__(registers) + self._registers = registers + self.iteration_lengths = tuple([reg.iteration_length for reg in registers]) + self._suffix_prod = np.multiply.accumulate(self.iteration_lengths[::-1])[::-1] + self._suffix_prod = np.append(self._suffix_prod, [1]) + + def to_flat_idx(self, *selection_vals: int) -> int: + """Flattens a composite index represented by a Tuple[int, ...] to a single output integer. + + For example: + + 1) We can flatten a 2D for-loop as follows + >>> for x in range(N): + >>> for y in range(M): + >>> flat_idx = x * M + y + + 2) Similarly, we can flatten a 3D for-loop as follows + >>> for x in range(N): + >>> for y in range(M): + >>> for z in range(L): + >>> flat_idx = x * M * L + y * L + z + + This is a general version of the mapping function described in Eq.45 of + https://arxiv.org/abs/1805.03662 + """ + assert len(selection_vals) == len(self) + return sum(v * self._suffix_prod[i + 1] for i, v in enumerate(selection_vals)) + + @property + def total_iteration_size(self) -> int: + return np.product(self.iteration_lengths) + + @classmethod + def build(cls, **registers: Tuple[int, int]) -> 'SelectionRegisters': + return cls( + [ + SelectionRegister(name=k, bitsize=v[0], iteration_length=v[1]) + for k, v in registers.items() + ] + ) + + class GateWithRegisters(cirq.Gate, metaclass=abc.ABCMeta): @property @abc.abstractmethod diff --git a/cirq_qubitization/cirq_infra/gate_with_registers_test.py b/cirq_qubitization/cirq_infra/gate_with_registers_test.py index 99473a4e07..e94e244137 100644 --- a/cirq_qubitization/cirq_infra/gate_with_registers_test.py +++ b/cirq_qubitization/cirq_infra/gate_with_registers_test.py @@ -3,7 +3,12 @@ import cirq import pytest -from cirq_qubitization.cirq_infra.gate_with_registers import GateWithRegisters, Register, Registers +from cirq_qubitization.cirq_infra.gate_with_registers import ( + GateWithRegisters, + Register, + Registers, + SelectionRegisters, +) from cirq_qubitization.jupyter_tools import execute_notebook @@ -57,6 +62,17 @@ def test_registers(): assert flat_named_qubits == expected_qubits +@pytest.mark.parametrize('n, N, m, M', [(4, 10, 5, 19), (4, 16, 5, 32)]) +def test_selection_registers_indexing(n, N, m, M): + reg = SelectionRegisters.build(x=(n, N), y=(m, M)) + assert reg.iteration_lengths == (N, M) + for x in range(N): + for y in range(M): + assert reg.to_flat_idx(x, y) == x * M + y + + assert reg.total_iteration_size == N * M + + def test_registers_getitem_raises(): g = Registers.build(a=4, b=3, c=2) with pytest.raises(IndexError, match="must be of the type"): diff --git a/cirq_qubitization/jupyter_autogen_factories.py b/cirq_qubitization/jupyter_autogen_factories.py index 7ccc53c0d7..2a563d5e82 100644 --- a/cirq_qubitization/jupyter_autogen_factories.py +++ b/cirq_qubitization/jupyter_autogen_factories.py @@ -25,6 +25,7 @@ def _make_ApplyGateToLthQubit(): from cirq_qubitization.cirq_algos.apply_gate_to_lth_target import ApplyGateToLthQubit + from cirq_qubitization.cirq_infra.gate_with_registers import Registers, SelectionRegisters def _z_to_odd(n: int): if n % 2 == 1: @@ -32,7 +33,9 @@ def _z_to_odd(n: int): return cirq.I apply_z_to_odd = ApplyGateToLthQubit( - selection_bitsize=3, target_bitsize=4, nth_gate=_z_to_odd, control_bitsize=2 + SelectionRegisters.build(selection=(3, 4)), + nth_gate=_z_to_odd, + control_regs=Registers.build(control=2), ) return apply_z_to_odd