diff --git a/qualtran/_infra/adjoint.py b/qualtran/_infra/adjoint.py index ca790c005..159542254 100644 --- a/qualtran/_infra/adjoint.py +++ b/qualtran/_infra/adjoint.py @@ -150,6 +150,15 @@ def decompose_from_registers( return cirq.inverse(self.subbloq.decompose_from_registers(context=context, **quregs)) return super().decompose_from_registers(context=context, **quregs) + def _circuit_diagram_info_( + self, args: 'cirq.CircuitDiagramInfoArgs' + ) -> cirq.CircuitDiagramInfo: + sub_info = cirq.circuit_diagram_info(self.subbloq, args, default=NotImplemented) + if sub_info is NotImplemented: + return NotImplemented + sub_info.exponent *= -1 + return sub_info + def supports_decompose_bloq(self) -> bool: """Delegate to `subbloq.supports_decompose_bloq()`""" return self.subbloq.supports_decompose_bloq() diff --git a/qualtran/_infra/controlled.py b/qualtran/_infra/controlled.py index 3c9c83eef..fea521404 100644 --- a/qualtran/_infra/controlled.py +++ b/qualtran/_infra/controlled.py @@ -34,6 +34,7 @@ from .bloq import Bloq from .data_types import QBit, QDType +from .gate_with_registers import GateWithRegisters from .registers import Register, Side, Signature if TYPE_CHECKING: @@ -278,7 +279,7 @@ def _get_nice_ctrl_reg_names(reg_names: List[str], n: int) -> Tuple[str, ...]: @attrs.frozen -class Controlled(Bloq): +class Controlled(GateWithRegisters): """A controlled version of `subbloq`. This meta-bloq is part of the 'controlled' protocol. As a default fallback, @@ -410,6 +411,19 @@ def add_my_tensors( # Add the data to the tensor network. tn.add(qtn.Tensor(data=data, inds=out_ind + in_ind, tags=[self.short_name(), tag])) + def _unitary_(self): + if isinstance(self.subbloq, GateWithRegisters): + # subbloq is a cirq gate, use the cirq-style API to derive a unitary. + return cirq.unitary( + cirq.ControlledGate(self.subbloq, control_values=self.ctrl_spec.to_cirq_cv()) + ) + if all(reg.side == Side.THRU for reg in self.subbloq.signature): + # subbloq has only THRU registers, so the tensor contraction corresponds + # to a unitary matrix. + return self.tensor_contract() + # Unable to determine the unitary effect. + return NotImplemented + def wire_symbol(self, soq: 'Soquet') -> 'WireSymbol': if soq.reg.name not in self.ctrl_reg_names: # Delegate to subbloq @@ -419,6 +433,9 @@ def wire_symbol(self, soq: 'Soquet') -> 'WireSymbol': i = self.ctrl_reg_names.index(soq.reg.name) return self.ctrl_spec.wire_symbol(i, soq) + def adjoint(self) -> 'Bloq': + return self.subbloq.adjoint().controlled(self.ctrl_spec) + def pretty_name(self) -> str: return f'C[{self.subbloq.pretty_name()}]' diff --git a/qualtran/_infra/controlled_test.py b/qualtran/_infra/controlled_test.py index 701f6fec5..9975e9bf2 100644 --- a/qualtran/_infra/controlled_test.py +++ b/qualtran/_infra/controlled_test.py @@ -322,7 +322,7 @@ def test_notebook(): def _verify_ctrl_tensor_for_unitary(ctrl_spec: CtrlSpec, bloq: Bloq, gate: cirq.Gate): cbloq = Controlled(bloq, ctrl_spec) - cgate = gate.controlled(control_values=ctrl_spec.to_cirq_cv()) + cgate = cirq.ControlledGate(gate, control_values=ctrl_spec.to_cirq_cv()) np.testing.assert_array_equal(cbloq.tensor_contract(), cirq.unitary(cgate)) diff --git a/qualtran/_infra/gate_with_registers.py b/qualtran/_infra/gate_with_registers.py index da5e661e1..429a2988d 100644 --- a/qualtran/_infra/gate_with_registers.py +++ b/qualtran/_infra/gate_with_registers.py @@ -13,7 +13,18 @@ # limitations under the License. import abc -from typing import Dict, Iterable, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union +from typing import ( + Collection, + Dict, + Iterable, + List, + Optional, + overload, + Sequence, + Tuple, + TYPE_CHECKING, + Union, +) import cirq import numpy as np @@ -21,11 +32,11 @@ from qualtran._infra.bloq import Bloq, DecomposeNotImplementedError, DecomposeTypeError from qualtran._infra.composite_bloq import CompositeBloq -from qualtran._infra.controlled import Controlled, CtrlSpec from qualtran._infra.quantum_graph import Soquet from qualtran._infra.registers import Register, Side if TYPE_CHECKING: + from qualtran import CtrlSpec from qualtran.cirq_interop import CirqQuregT from qualtran.drawing import WireSymbol @@ -311,23 +322,160 @@ def on_registers( ) -> cirq.Operation: return self.on(*merge_qubits(self.signature, **qubit_regs)) + def __pow__(self, power: int) -> 'GateWithRegisters': + bloq = self if power > 0 else self.adjoint() + if abs(power) == 1: + return bloq + if all(reg.side == Side.THRU for reg in self.signature): + from qualtran.bloqs.util_bloqs import Power + + return Power(bloq, abs(power)) + raise NotImplementedError(f"{self} does not implemented __pow__ for {power=}.") + + def _get_ctrl_spec( + self, + num_controls: Union[Optional[int], 'CtrlSpec'] = None, + control_values=None, + control_qid_shape: Optional[Tuple[int, ...]] = None, + *, + ctrl_spec: Optional['CtrlSpec'] = None, + ) -> 'CtrlSpec': + """Helper method to support Cirq & Bloq style APIs for constructing controlled Bloqs. + + This method can be used to construct a `CtrlSpec` from either the Bloq-style API that + already accepts a `CtrlSpec` and simply returns it OR a Cirq-style API which accepts + parameters expected by `cirq.Gate.controlled()` and converts them to a `CtrlSpec` object. + + Users implementing custom `GateWithRegisters.controlled()` overrides can use this helper + to generate a CtrlSpec from the cirq-style API and thus easily support both Cirq & Bloq + APIs. For example + + >>> class CustomGWR(GateWithRegisters): + >>> def controlled(self, *args, **kwargs) -> 'Bloq': + >>> ctrl_spec = self._get_ctrl_spec(*args, **kwargs) + >>> # Use ctrl_spec to construct a controlled version of `self`. + + Args: + num_controls: Cirq style API to specify control specification - + Total number of control qubits. + control_values: Cirq style API to specify control specification - + Which control computational basis state to apply the + sub gate. A sequence of length `num_controls` where each + entry is an integer (or set of integers) corresponding to the + computational basis state (or set of possible values) where that + control is enabled. When all controls are enabled, the sub gate is + applied. If unspecified, control values default to 1. + control_qid_shape: Cirq style API to specify control specification - + The qid shape of the controls. A tuple of the + expected dimension of each control qid. Defaults to + `(2,) * num_controls`. Specify this argument when using qudits. + ctrl_spec: Bloq style API to specify a control specification - + An optional keyword argument `CtrlSpec`, which specifies how to control + the bloq. The default spec means the bloq will be active when one control qubit is + in the |1> state. See the CtrlSpec documentation for more possibilities including + negative controls, integer-equality control, and ndarrays of control values. + """ + from qualtran._infra.controlled import CtrlSpec + + ok = True + if ctrl_spec is not None: + # Bloq API invoked via kwargs - bloq.controlled(ctrl_spec=ctrl_spec) + ok &= control_values is None and control_qid_shape is None and num_controls is None + elif isinstance(num_controls, CtrlSpec): + # Bloq API invoked via args - bloq.controlled(ctrl_spec) + ok &= control_values is None and control_qid_shape is None + if not ok: + raise ValueError( + 'GateWithRegisters.controlled() must be called with either cirq-style API' + f'or Bloq style API. Found arguments: {num_controls=}, ' + f'{control_values=}, {control_qid_shape=}, {ctrl_spec=}' + ) + + if isinstance(num_controls, CtrlSpec): + ctrl_spec = num_controls + elif ctrl_spec is None: + controlled_gate = cirq.ControlledGate( + self, + num_controls=num_controls, + control_values=control_values, + control_qid_shape=control_qid_shape, + ) + ctrl_spec = CtrlSpec.from_cirq_cv(controlled_gate.control_values) + return ctrl_spec + # pylint: disable=arguments-renamed + @overload def controlled( self, num_controls: Optional[int] = None, - control_values=None, + control_values: Optional[ + Union[cirq.ops.AbstractControlValues, Sequence[Union[int, Collection[int]]]] + ] = None, control_qid_shape: Optional[Tuple[int, ...]] = None, - ) -> 'cirq.Gate': - from qualtran.cirq_interop import BloqAsCirqGate - - controlled_gate = cirq.ControlledGate( - self, - num_controls=num_controls, - control_values=control_values, - control_qid_shape=control_qid_shape, + ) -> 'GateWithRegisters': + """Cirq-style API to construct a controlled gate. See `cirq.Gate.controlled()`""" + + # pylint: disable=signature-differs + @overload + def controlled(self, ctrl_spec: Optional['CtrlSpec'] = None) -> 'GateWithRegisters': + """Bloq-style API to construct a controlled Bloq. See `Bloq.controlled()`.""" + + def controlled( + self, + num_controls: Union[Optional[int], 'CtrlSpec'] = None, + control_values: Optional[ + Union[cirq.ops.AbstractControlValues, Sequence[Union[int, Collection[int]]]] + ] = None, + control_qid_shape: Optional[Tuple[int, ...]] = None, + *, + ctrl_spec: Optional['CtrlSpec'] = None, + ) -> 'Bloq': + """Return a controlled version of self. Controls can be specified via Cirq/Bloq-style APIs. + + If no arguments are specified, defaults to a single qubit control. + + Supports both Cirq-style API and Bloq-style API to construct controlled Bloqs. The cirq-style + API is supported by intercepting the Cirq-style way of specifying a control specification; + via arguments `num_controls`, `control_values` and `control_qid_shape`, and constructing a + `CtrlSpec` object from it before delegating to `self.get_ctrl_system`. + + By default, the system will use the `qualtran.Controlled` meta-bloq to wrap this + bloq. Bloqs authors can declare their own, custom controlled versions by overriding + `Bloq.get_ctrl_system` in the bloq. + + If overriding the `GWR.controlled()` method directly, Bloq authors can use the + `self._get_ctrl_spec` helper to construct a `CtrlSpec` object from the input parameters of + `GWR.controlled()` and use it to return a custom controlled version of this Bloq. + + + Args: + num_controls: Cirq style API to specify control specification - + Total number of control qubits. + control_values: Cirq style API to specify control specification - + Which control computational basis state to apply the + sub gate. A sequence of length `num_controls` where each + entry is an integer (or set of integers) corresponding to the + computational basis state (or set of possible values) where that + control is enabled. When all controls are enabled, the sub gate is + applied. If unspecified, control values default to 1. + control_qid_shape: Cirq style API to specify control specification - + The qid shape of the controls. A tuple of the + expected dimension of each control qid. Defaults to + `(2,) * num_controls`. Specify this argument when using qudits. + ctrl_spec: Bloq style API to specify a control specification - + An optional keyword argument `CtrlSpec`, which specifies how to control + the bloq. The default spec means the bloq will be active when one control qubit is + in the |1> state. See the CtrlSpec documentation for more possibilities including + negative controls, integer-equality control, and ndarrays of control values. + + Returns: + A controlled version of the bloq. + """ + ctrl_spec = self._get_ctrl_spec( + num_controls, control_values, control_qid_shape, ctrl_spec=ctrl_spec ) - ctrl_spec = CtrlSpec.from_cirq_cv(controlled_gate.control_values) - return BloqAsCirqGate(Controlled(self, ctrl_spec)) + controlled_bloq, _ = self.get_ctrl_system(ctrl_spec=ctrl_spec) + return controlled_bloq def _unitary_(self): return NotImplemented diff --git a/qualtran/_infra/gate_with_registers_test.py b/qualtran/_infra/gate_with_registers_test.py index 40df3be3c..a9b1f38f3 100644 --- a/qualtran/_infra/gate_with_registers_test.py +++ b/qualtran/_infra/gate_with_registers_test.py @@ -18,8 +18,19 @@ import numpy as np import pytest -from qualtran import GateWithRegisters, QAny, QBit, Register, Side, Signature, SoquetT +from qualtran import ( + Controlled, + CtrlSpec, + GateWithRegisters, + QAny, + QBit, + Register, + Side, + Signature, + SoquetT, +) from qualtran.bloqs.basic_gates import XGate, YGate, ZGate +from qualtran.bloqs.util_bloqs import Power from qualtran.testing import execute_notebook @@ -51,6 +62,38 @@ def test_gate_with_registers(): np.testing.assert_allclose(cirq.unitary(tg), tg.tensor_contract()) + # Test GWR.controlled() works correctly with Bloq and Cirq style API + ctrl = cirq.q('ctrl') + cop1 = tg.controlled().on(ctrl, *qubits[:5], *qubits[6:], qubits[5]) + cop2 = tg.controlled().on_registers(ctrl=ctrl, r1=qubits[:5], r2=qubits[6:], r3=qubits[5]) + cop3 = op1.controlled_by(ctrl) + cop4 = op2.controlled_by(ctrl) + assert cop1 == cop2 == cop3 == cop4 + assert cop1.gate == cop2.gate == cop3.gate == cop4.gate == Controlled(tg, CtrlSpec()) + + assert ( + tg.controlled(num_controls=1, control_values=[0]) + == tg.controlled(control_values=[0], control_qid_shape=(2,)) + == tg.controlled(CtrlSpec(cvs=0)) + == tg.controlled(ctrl_spec=CtrlSpec(cvs=0)) + ) + + # Test GWR.controlled() raises with incorrect invocation. + with pytest.raises(ValueError): + tg.controlled(control_values=[0], ctrl_spec=CtrlSpec()) + + with pytest.raises(ValueError): + tg.controlled(CtrlSpec(), control_values=[0]) + + with pytest.raises(ValueError): + tg.controlled(CtrlSpec(), ctrl_spec=CtrlSpec()) + + # Test GWR**pow + assert tg**-1 == tg.adjoint() + assert tg**1 is tg + assert tg**-10 == Power(tg.adjoint(), 10) + assert tg**10 == Power(tg, 10) + class _TestGateAtomic(GateWithRegisters): @property diff --git a/qualtran/bloqs/phase_estimation/lp_resource_state.py b/qualtran/bloqs/phase_estimation/lp_resource_state.py index 4bec20589..52dd6e8a0 100644 --- a/qualtran/bloqs/phase_estimation/lp_resource_state.py +++ b/qualtran/bloqs/phase_estimation/lp_resource_state.py @@ -64,7 +64,7 @@ def decompose_from_registers( yield [OnEach(self.bitsize, Hadamard()).on(*q), Hadamard().on(*anc)] for i in range(self.bitsize): rz_angle = -2 * np.pi * (2**i) / (2**self.bitsize + 1) - yield cirq.Rz(rads=rz_angle).controlled().on(q[i], *anc) + yield Rz(angle=rz_angle).controlled().on(q[i], *anc) yield Rz(angle=-2 * np.pi / (2**self.bitsize + 1)).on(*anc) yield Hadamard().on(*anc) @@ -72,11 +72,9 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']: rz_angle = -2 * pi(self.bitsize) / (2**self.bitsize + 1) ret = {(Rz(angle=rz_angle), 1), (Hadamard(), 2 + self.bitsize)} if is_symbolic(self.bitsize): - ret |= {(Rz(angle=rz_angle).controlled().bloq, self.bitsize)} + ret |= {(Rz(angle=rz_angle).controlled(), self.bitsize)} else: - ret |= { - (Rz(angle=rz_angle * (2**i)).controlled().bloq, 1) for i in range(self.bitsize) - } + ret |= {(Rz(angle=rz_angle * (2**i)).controlled(), 1) for i in range(self.bitsize)} return ret def _t_complexity_(self) -> 'TComplexity': diff --git a/qualtran/bloqs/qsp/generalized_qsp.py b/qualtran/bloqs/qsp/generalized_qsp.py index f7c2a393d..d9e328677 100644 --- a/qualtran/bloqs/qsp/generalized_qsp.py +++ b/qualtran/bloqs/qsp/generalized_qsp.py @@ -20,16 +20,7 @@ from numpy.polynomial import Polynomial from numpy.typing import NDArray -from qualtran import ( - bloq_example, - BloqDocSpec, - Controlled, - CtrlSpec, - GateWithRegisters, - QBit, - Register, - Signature, -) +from qualtran import bloq_example, BloqDocSpec, GateWithRegisters, QBit, Register, Signature from qualtran.bloqs.basic_gates.su2_rotation import SU2RotationGate if TYPE_CHECKING: @@ -401,12 +392,12 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']: counts = set(Counter(self.signal_rotations).items()) if degree > self.negative_power: - counts.add((Controlled(self.U, CtrlSpec(cvs=0)), degree - self.negative_power)) + counts.add((self.U.controlled(control_values=[0]), degree - self.negative_power)) elif self.negative_power > degree: counts.add((self.U.adjoint(), self.negative_power - degree)) if self.negative_power > 0: - counts.add((Controlled(self.U.adjoint(), CtrlSpec()), min(degree, self.negative_power))) + counts.add((self.U.adjoint().controlled(), min(degree, self.negative_power))) return counts diff --git a/qualtran/bloqs/qubitization_walk_operator_test.py b/qualtran/bloqs/qubitization_walk_operator_test.py index 63b7ff1d6..9c8cacf30 100644 --- a/qualtran/bloqs/qubitization_walk_operator_test.py +++ b/qualtran/bloqs/qubitization_walk_operator_test.py @@ -16,6 +16,7 @@ import numpy as np import pytest +from qualtran import Adjoint from qualtran._infra.gate_with_registers import get_named_qubits, total_bits from qualtran.bloqs.chemistry.ising import get_1d_ising_hamiltonian from qualtran.bloqs.mcmt.multi_control_multi_target_pauli import MultiControlPauli @@ -161,8 +162,8 @@ def test_qubitization_walk_operator_diagrams(): def keep(op): ret = op in gateset_to_keep - if op.gate is not None and isinstance(op.gate, cirq.ops.raw_types._InverseCompositeGate): - ret |= op.gate._original in gateset_to_keep + if op.gate is not None and isinstance(op.gate, Adjoint): + ret |= op.gate.subbloq in gateset_to_keep return ret greedy_mm = cirq.GreedyQubitManager(prefix="ancilla", maximize_reuse=True) diff --git a/qualtran/bloqs/util_bloqs.py b/qualtran/bloqs/util_bloqs.py index e545740eb..294012941 100644 --- a/qualtran/bloqs/util_bloqs.py +++ b/qualtran/bloqs/util_bloqs.py @@ -26,6 +26,7 @@ from qualtran import ( Bloq, BloqBuilder, + GateWithRegisters, QAny, QBit, QDType, @@ -478,7 +479,7 @@ def _t_complexity_(self) -> 'TComplexity': @frozen -class Power(Bloq): +class Power(GateWithRegisters): """Wrapper that repeats the given `bloq` `power` times. `Bloq` must have only THRU registers. diff --git a/qualtran/cirq_interop/_cirq_to_bloq.py b/qualtran/cirq_interop/_cirq_to_bloq.py index efc515ab7..46cae89f8 100644 --- a/qualtran/cirq_interop/_cirq_to_bloq.py +++ b/qualtran/cirq_interop/_cirq_to_bloq.py @@ -28,6 +28,8 @@ Bloq, BloqBuilder, CompositeBloq, + Controlled, + CtrlSpec, DecomposeNotImplementedError, DecomposeTypeError, GateWithRegisters, @@ -45,7 +47,6 @@ get_named_qubits, split_qubits, ) -from qualtran.bloqs.util_bloqs import Cast from qualtran.cirq_interop._interop_qubit_manager import InteropQubitManager from qualtran.cirq_interop.t_complexity_protocol import t_complexity, TComplexity from qualtran.simulation.tensor._tensor_data_manipulation import ( @@ -245,6 +246,8 @@ def _ensure_in_reg_exists( bb: BloqBuilder, in_reg: _QReg, qreg_to_qvar: Dict[_QReg, Soquet] ) -> None: """Takes care of qubit allocations, split and joins to ensure `qreg_to_qvar[in_reg]` exists.""" + from qualtran.bloqs.util_bloqs import Cast + all_mapped_qubits = {q for qreg in qreg_to_qvar for q in qreg.qubits} qubits_to_allocate: List[cirq.Qid] = [q for q in in_reg.qubits if q not in all_mapped_qubits] if qubits_to_allocate: @@ -350,6 +353,11 @@ def _cirq_gate_to_bloq(gate: cirq.Gate) -> Bloq: # Inverse of a cirq gate, delegate to Adjoint return Adjoint(_cirq_gate_to_bloq(gate._original)) + if isinstance(gate, cirq.ControlledGate): + return Controlled( + _cirq_gate_to_bloq(gate.sub_gate), CtrlSpec.from_cirq_cv(gate.control_values) + ) + # Check specific basic gates instances. CIRQ_GATE_TO_BLOQ_MAP = { cirq.T: TGate(), diff --git a/qualtran/cirq_interop/t_complexity_protocol.py b/qualtran/cirq_interop/t_complexity_protocol.py index a7f49a2eb..ee97c65f4 100644 --- a/qualtran/cirq_interop/t_complexity_protocol.py +++ b/qualtran/cirq_interop/t_complexity_protocol.py @@ -17,7 +17,7 @@ import cachetools import cirq -from qualtran import Bloq +from qualtran import Bloq, Controlled from qualtran.cirq_interop.decompose_protocol import _decompose_once_considering_known_decomposition from qualtran.resource_counting.symbolic_counting_utils import ceil, log2, SymbolicFloat @@ -118,6 +118,18 @@ def _from_directly_countable(stc: Any) -> Optional[TComplexity]: if stc in _ROTS_GATESET: return TComplexity(rotations=1) + if isinstance(stc, Controlled) and cirq.num_qubits(stc) <= 2: + # We need this hack temporarily because we assume access to decomposition + # of a C-U gate where $U$ is a single qubit rotation. Cirq has this decomposition + # but the right thing to do in Qualtran is to add explicit bloqs and annotate + # them with costs. See https://github.com/quantumlib/Qualtran/issues/878 + from qualtran._infra.gate_with_registers import get_named_qubits + + quregs = get_named_qubits(stc.signature) + qm = cirq.SimpleQubitManager() + op, _ = stc.as_cirq_op(qubit_manager=qm, **quregs) + return t_complexity(cirq.decompose_once(op)) + if cirq.num_qubits(stc) == 1 and cirq.has_unitary(stc): # Single qubit rotation operation. return TComplexity(rotations=1)