diff --git a/qualtran/bloqs/for_testing/random_select_and_prepare.py b/qualtran/bloqs/for_testing/random_select_and_prepare.py index 80a8b5ef7c..6481ef5e47 100644 --- a/qualtran/bloqs/for_testing/random_select_and_prepare.py +++ b/qualtran/bloqs/for_testing/random_select_and_prepare.py @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. from functools import cached_property -from typing import Tuple +from typing import Optional, Tuple import cirq import numpy as np from attrs import frozen from numpy.typing import NDArray -from qualtran import BloqBuilder, BoundedQUInt, Register, SoquetT +from qualtran import BloqBuilder, BoundedQUInt, QBit, Register, SoquetT from qualtran.bloqs.for_testing.matrix_gate import MatrixGate from qualtran.bloqs.qubitization_walk_operator import QubitizationWalkOperator from qualtran.bloqs.select_and_prepare import PrepareOracle, SelectOracle @@ -103,6 +103,7 @@ class TestPauliSelectOracle(SelectOracle): select_bitsize: int target_bitsize: int select_unitaries: tuple[cirq.DensePauliString, ...] + control_val: Optional[int] = None @classmethod def random( @@ -116,7 +117,7 @@ def random( @property def control_registers(self) -> Tuple[Register, ...]: - return () + return () if self.control_val is None else (Register('control', QBit()),) @property def selection_registers(self) -> Tuple[Register, ...]: @@ -136,7 +137,17 @@ def decompose_from_registers( ) -> cirq.OP_TREE: for cv, U in enumerate(self.select_unitaries): bits = tuple(map(int, bin(cv)[2:].zfill(self.select_bitsize)))[::-1] - yield U.on(*target).controlled_by(*selection, control_values=bits) + op = U.on(*target).controlled_by(*selection, control_values=bits) + if self.control_val is not None: + op = op.controlled_by(*quregs['control'], control_values=[self.control_val]) + yield op + + def get_ctrl_system( + self, ctrl_spec: Optional['CtrlSpec'] = None + ) -> Tuple['Bloq', 'AddControlledT']: + from qualtran._infra.gate_with_registers import get_ctrl_system_for_single_qubit_controlled + + return get_ctrl_system_for_single_qubit_controlled(self, ctrl_spec) def random_qubitization_walk_operator( diff --git a/qualtran/bloqs/qubitization_walk_operator.py b/qualtran/bloqs/qubitization_walk_operator.py index 656f92173e..d5376995e2 100644 --- a/qualtran/bloqs/qubitization_walk_operator.py +++ b/qualtran/bloqs/qubitization_walk_operator.py @@ -13,13 +13,13 @@ # limitations under the License. from functools import cached_property -from typing import Optional, Tuple +from typing import Iterable, Optional, Sequence, Tuple import attrs import cirq from numpy.typing import NDArray -from qualtran import bloq_example, BloqDocSpec, GateWithRegisters, Register, Signature +from qualtran import bloq_example, BloqDocSpec, CtrlSpec, GateWithRegisters, Register, Signature from qualtran._infra.gate_with_registers import total_bits from qualtran.bloqs.reflection_using_prepare import ReflectionUsingPrepare from qualtran.bloqs.select_and_prepare import PrepareOracle, SelectOracle @@ -61,14 +61,11 @@ class QubitizationWalkOperator(GateWithRegisters): select: SelectOracle prepare: PrepareOracle + control_val: Optional[int] = None def __attrs_post_init__(self): assert self.select.control_registers == self.reflect.control_registers - @cached_property - def control_val(self) -> Optional[int]: - return self.select.control_val - @cached_property def control_registers(self) -> Tuple[Register, ...]: return self.select.control_registers @@ -102,6 +99,32 @@ def decompose_from_registers( reflect_reg = {reg.name: quregs[reg.name] for reg in self.reflect.signature} yield self.reflect.on_registers(**reflect_reg) + def get_ctrl_system( + self, ctrl_spec: Optional['CtrlSpec'] = None + ) -> Tuple['Bloq', 'AddControlledT']: + if ctrl_spec is None: + ctrl_spec = CtrlSpec() + + if self.control_val is None and ctrl_spec.shapes in [((),), ((1,),)]: + c_select = self.select.controlled(ctrl_spec) + assert isinstance(c_select, SelectOracle) + cbloq = attrs.evolve(self, select=c_select, control_val=(int(ctrl_spec.cvs[0].item()))) + (ctrl_reg,) = cbloq.control_registers + + def adder( + bb: 'BloqBuilder', ctrl_soqs: Sequence['SoquetT'], in_soqs: dict[str, 'SoquetT'] + ) -> tuple[Iterable['SoquetT'], Iterable['SoquetT']]: + soqs = {ctrl_reg.name: ctrl_soqs[0]} | in_soqs + soqs = bb.add_d(cbloq, **soqs) + ctrl_soqs = [soqs.pop(ctrl_reg.name)] + return ctrl_soqs, soqs.values() + + return cbloq, adder + + raise NotImplementedError( + f'Cannot create a controlled version of {self} with {ctrl_spec=}.' + ) + def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo: wire_symbols = ['@' if self.control_val else '@(0)'] * total_bits(self.control_registers) wire_symbols += ['W'] * (total_bits(self.signature) - total_bits(self.control_registers)) diff --git a/qualtran/bloqs/qubitization_walk_operator_test.py b/qualtran/bloqs/qubitization_walk_operator_test.py index e168895fdc..16f38a63d7 100644 --- a/qualtran/bloqs/qubitization_walk_operator_test.py +++ b/qualtran/bloqs/qubitization_walk_operator_test.py @@ -140,7 +140,7 @@ def decompose_twice(op): ''', ) # 3. Diagram for $Ctrl-W = Ctrl-SELECT.Ctrl-R_{L}$ - controlled_walk_op = walk.controlled().on_registers(**g.quregs, ctrl=cirq.q('control')) + controlled_walk_op = walk.controlled().on_registers(**g.quregs, control=cirq.q('control')) circuit = cirq.Circuit(cirq.decompose_once(controlled_walk_op)) cirq.testing.assert_has_diagram( circuit,