Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More improvements to the Controlled infrastructure #879

Merged
merged 5 commits into from
Apr 19, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions qualtran/_infra/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,13 @@ def decompose_from_registers(
return cirq.inverse(self.subbloq.decompose_from_registers(context=context, **quregs))
return super().decompose_from_registers(context=context, **quregs)

def _circuit_diagram_info_(self, args: 'cirq.CircuitDiagramInfoArgs'):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return type annotation

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

sub_info = cirq.circuit_diagram_info(self.subbloq, args, default=NotImplemented)
if sub_info is NotImplemented:
return NotImplemented
sub_info.exponent *= -1
return sub_info

def supports_decompose_bloq(self) -> bool:
"""Delegate to `subbloq.supports_decompose_bloq()`"""
return self.subbloq.supports_decompose_bloq()
Expand Down
18 changes: 17 additions & 1 deletion qualtran/_infra/controlled.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

from .bloq import Bloq
from .data_types import QBit, QDType
from .gate_with_registers import GateWithRegisters
from .registers import Register, Side, Signature

if TYPE_CHECKING:
Expand Down Expand Up @@ -278,7 +279,7 @@ def _get_nice_ctrl_reg_names(reg_names: List[str], n: int) -> Tuple[str, ...]:


@attrs.frozen
class Controlled(Bloq):
class Controlled(GateWithRegisters):
"""A controlled version of `subbloq`.

This meta-bloq is part of the 'controlled' protocol. As a default fallback,
Expand Down Expand Up @@ -410,6 +411,18 @@ def add_my_tensors(
# Add the data to the tensor network.
tn.add(qtn.Tensor(data=data, inds=out_ind + in_ind, tags=[self.short_name(), tag]))

def _unitary_(self):
if isinstance(self.subbloq, GateWithRegisters):
# subbloq is a cirq gate, use the cirq-style API to derive a unitary.
return cirq.unitary(
cirq.ControlledGate(self.subbloq, control_values=self.ctrl_spec.to_cirq_cv())
)
if all(reg.side == Side.THRU for reg in self.subbloq.signature):
# subbloq has only THRU registers, so the unitary is well defined.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pycharm tells me this is supposed to be "well-defined"

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

reworded

return self.tensor_contract()
# Unable to determine the unitary effect.
return NotImplemented

def wire_symbol(self, soq: 'Soquet') -> 'WireSymbol':
if soq.reg.name not in self.ctrl_reg_names:
# Delegate to subbloq
Expand All @@ -419,6 +432,9 @@ def wire_symbol(self, soq: 'Soquet') -> 'WireSymbol':
i = self.ctrl_reg_names.index(soq.reg.name)
return self.ctrl_spec.wire_symbol(i, soq)

def adjoint(self) -> 'Bloq':
return self.subbloq.adjoint().controlled(self.ctrl_spec)

def pretty_name(self) -> str:
return f'C[{self.subbloq.pretty_name()}]'

Expand Down
2 changes: 1 addition & 1 deletion qualtran/_infra/controlled_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ def test_notebook():

def _verify_ctrl_tensor_for_unitary(ctrl_spec: CtrlSpec, bloq: Bloq, gate: cirq.Gate):
cbloq = Controlled(bloq, ctrl_spec)
cgate = gate.controlled(control_values=ctrl_spec.to_cirq_cv())
cgate = cirq.ControlledGate(gate, control_values=ctrl_spec.to_cirq_cv())
np.testing.assert_array_equal(cbloq.tensor_contract(), cirq.unitary(cgate))


Expand Down
42 changes: 29 additions & 13 deletions qualtran/_infra/gate_with_registers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

from qualtran._infra.bloq import Bloq, DecomposeNotImplementedError, DecomposeTypeError
from qualtran._infra.composite_bloq import CompositeBloq
from qualtran._infra.controlled import Controlled, CtrlSpec
from qualtran._infra.quantum_graph import Soquet
from qualtran._infra.registers import Register, Side

Expand Down Expand Up @@ -311,23 +310,40 @@ def on_registers(
) -> cirq.Operation:
return self.on(*merge_qubits(self.signature, **qubit_regs))

def __pow__(self, power: int) -> 'GateWithRegisters':
if power == -1:
return self.adjoint()
if power == 1:
return self
raise NotImplementedError("pow is not implemented.")

# pylint: disable=arguments-renamed
def controlled(
self,
num_controls: Optional[int] = None,
num_controls: Union[Optional[int], 'CtrlSpec'] = None,
control_values=None,
control_qid_shape: Optional[Tuple[int, ...]] = None,
) -> 'cirq.Gate':
from qualtran.cirq_interop import BloqAsCirqGate

controlled_gate = cirq.ControlledGate(
self,
num_controls=num_controls,
control_values=control_values,
control_qid_shape=control_qid_shape,
)
ctrl_spec = CtrlSpec.from_cirq_cv(controlled_gate.control_values)
return BloqAsCirqGate(Controlled(self, ctrl_spec))
) -> 'GateWithRegisters':
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be useful to provide overloaded type hints?

    # pylint: disable=arguments-renamed
    @overload
    def controlled(
        self,
        num_controls: Optional[int] = None,
        control_values=None,
        control_qid_shape: Optional[Tuple[int, ...]] = None,
    ) -> 'GateWithRegisters':
        ...

    # pylint: disable=signature-differs
    @overload
    def controlled(self, ctrl_spec: 'CtrlSpec') -> 'GateWithRegisters':
        ...

    # pylint: disable=arguments-renamed
    def controlled(self, num_controls=None, control_values=None, control_qid_shape=None):

Right now it's difficult to understand the usage without looking at the function body.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, I think this needs a big docstring that describes the two modes of operation

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a detailed docstring and the overloads. Also added tests to verify the different cases raise errors.

from qualtran._infra.controlled import Controlled, CtrlSpec

if isinstance(num_controls, CtrlSpec):
if not (control_values is None and control_qid_shape is None):
raise ValueError(
'GateWithRegisters.controlled() must be called with either cirq-style API'
f'or Bloq style API. Found arguments: {num_controls=}, '
f'{control_values=}, {control_qid_shape=}'
)
ctrl_spec = num_controls
else:
controlled_gate = cirq.ControlledGate(
self,
num_controls=num_controls,
control_values=control_values,
control_qid_shape=control_qid_shape,
)
ctrl_spec = CtrlSpec.from_cirq_cv(controlled_gate.control_values)

return Controlled(self, ctrl_spec)

def _unitary_(self):
return NotImplemented
Expand Down
8 changes: 3 additions & 5 deletions qualtran/bloqs/phase_estimation/lp_resource_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,19 +64,17 @@ def decompose_from_registers(
yield [OnEach(self.bitsize, Hadamard()).on(*q), Hadamard().on(*anc)]
for i in range(self.bitsize):
rz_angle = -2 * np.pi * (2**i) / (2**self.bitsize + 1)
yield cirq.Rz(rads=rz_angle).controlled().on(q[i], *anc)
yield Rz(angle=rz_angle).controlled().on(q[i], *anc)
yield Rz(angle=-2 * np.pi / (2**self.bitsize + 1)).on(*anc)
yield Hadamard().on(*anc)

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
rz_angle = -2 * pi(self.bitsize) / (2**self.bitsize + 1)
ret = {(Rz(angle=rz_angle), 1), (Hadamard(), 2 + self.bitsize)}
if is_symbolic(self.bitsize):
ret |= {(Rz(angle=rz_angle).controlled().bloq, self.bitsize)}
ret |= {(Rz(angle=rz_angle).controlled(), self.bitsize)}
else:
ret |= {
(Rz(angle=rz_angle * (2**i)).controlled().bloq, 1) for i in range(self.bitsize)
}
ret |= {(Rz(angle=rz_angle * (2**i)).controlled(), 1) for i in range(self.bitsize)}
return ret

def _t_complexity_(self) -> 'TComplexity':
Expand Down
11 changes: 7 additions & 4 deletions qualtran/bloqs/phase_estimation/text_book_qpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import cached_property
from typing import Set, Tuple
from typing import Sequence, Set, Tuple

import attrs
import cirq

from qualtran import Bloq, bloq_example, BloqDocSpec, GateWithRegisters, QFxp, Register, Signature
from qualtran._infra.gate_with_registers import merge_qubits
from qualtran.bloqs.basic_gates import Hadamard, OnEach
from qualtran.bloqs.qft.qft_text_book import QFTTextBook
from qualtran.bloqs.util_bloqs import Power
from qualtran.resource_counting.symbolic_counting_utils import (
ceil,
is_symbolic,
Expand Down Expand Up @@ -220,13 +222,14 @@ def decompose_from_registers(
self, context: cirq.DecompositionContext, **quregs
) -> cirq.OP_TREE:
target_quregs = {reg.name: quregs[reg.name] for reg in self.target_registers}
unitary_op = self.unitary.on_registers(**target_quregs)
target_qubits = merge_qubits(self.target_registers, **target_quregs)

phase_qubits = quregs['qpe_reg']
phase_qubits: Sequence[cirq.Qid] = quregs['qpe_reg'].flatten()

yield self.ctrl_state_prep.on(*phase_qubits)
for i, qbit in enumerate(phase_qubits[::-1]):
yield cirq.pow(unitary_op.controlled_by(qbit), 2**i)
yield Power(self.unitary.controlled(), 2**i).on(qbit, *target_qubits)
# yield cirq.pow(unitary_op.controlled_by(qbit), 2**i)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove old code

Copy link
Collaborator Author

@tanujkhattar tanujkhattar Apr 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reverted the file so now there are no changes and the original implementation uses bloqs instead of cirq gates. See added tests in gate_with_registers_test.py for clarity.

yield self.qft_inv.on(*phase_qubits)

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
Expand Down
21 changes: 7 additions & 14 deletions qualtran/bloqs/qsp/generalized_qsp.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,8 @@
from numpy.polynomial import Polynomial
from numpy.typing import NDArray

from qualtran import (
bloq_example,
BloqDocSpec,
Controlled,
CtrlSpec,
GateWithRegisters,
QBit,
Register,
Signature,
)
from qualtran import bloq_example, BloqDocSpec, GateWithRegisters, QBit, Register, Signature
from qualtran._infra.gate_with_registers import merge_qubits
from qualtran.bloqs.basic_gates.su2_rotation import SU2RotationGate

if TYPE_CHECKING:
Expand Down Expand Up @@ -382,14 +374,15 @@ def decompose_from_registers(
num_inverse_applications = self.negative_power

yield self.signal_rotations[0].on(signal_qubit)
flat_qubits_for_u = merge_qubits(self.U.signature, **quregs)
for signal_rotation in self.signal_rotations[1:]:
if num_inverse_applications > 0:
# apply C-U^\dagger
yield self.U.adjoint().on_registers(**quregs).controlled_by(signal_qubit)
yield self.U.adjoint().controlled().on(signal_qubit, *flat_qubits_for_u)
num_inverse_applications -= 1
else:
# apply C[0]-U
yield self.U.on_registers(**quregs).controlled_by(signal_qubit, control_values=[0])
yield self.U.controlled(control_values=[0]).on(signal_qubit, *flat_qubits_for_u)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we default to using qualtran style controls CtrlSpec(cvs=0) instead of cirq style control_values=[0]?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reverted this change because self.U.on_registers(**quregs).controlled_by(signal_qubit) now uses the Controlled bloq instead of cirq gate. The op.controlled_by still expects the cirq-style API since we don't have a custom GWROperation where we can override the controlled_by to expect the bloq-style API as well. But I think it's fine right now to do things this way

yield signal_rotation.on(signal_qubit)

for _ in range(num_inverse_applications):
Expand All @@ -401,12 +394,12 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
counts = set(Counter(self.signal_rotations).items())

if degree > self.negative_power:
counts.add((Controlled(self.U, CtrlSpec(cvs=0)), degree - self.negative_power))
counts.add((self.U.controlled(control_values=[0]), degree - self.negative_power))
elif self.negative_power > degree:
counts.add((self.U.adjoint(), self.negative_power - degree))

if self.negative_power > 0:
counts.add((Controlled(self.U.adjoint(), CtrlSpec()), min(degree, self.negative_power)))
counts.add((self.U.adjoint().controlled(), min(degree, self.negative_power)))

return counts

Expand Down
5 changes: 3 additions & 2 deletions qualtran/bloqs/qubitization_walk_operator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import numpy as np
import pytest

from qualtran import Adjoint
from qualtran._infra.gate_with_registers import get_named_qubits, total_bits
from qualtran.bloqs.chemistry.ising import get_1d_ising_hamiltonian
from qualtran.bloqs.mcmt.multi_control_multi_target_pauli import MultiControlPauli
Expand Down Expand Up @@ -161,8 +162,8 @@ def test_qubitization_walk_operator_diagrams():

def keep(op):
ret = op in gateset_to_keep
if op.gate is not None and isinstance(op.gate, cirq.ops.raw_types._InverseCompositeGate):
ret |= op.gate._original in gateset_to_keep
if op.gate is not None and isinstance(op.gate, Adjoint):
ret |= op.gate.subbloq in gateset_to_keep
return ret

greedy_mm = cirq.GreedyQubitManager(prefix="ancilla", maximize_reuse=True)
Expand Down
7 changes: 7 additions & 0 deletions qualtran/cirq_interop/_cirq_to_bloq.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
Bloq,
BloqBuilder,
CompositeBloq,
Controlled,
CtrlSpec,
DecomposeNotImplementedError,
DecomposeTypeError,
GateWithRegisters,
Expand Down Expand Up @@ -350,6 +352,11 @@ def _cirq_gate_to_bloq(gate: cirq.Gate) -> Bloq:
# Inverse of a cirq gate, delegate to Adjoint
return Adjoint(_cirq_gate_to_bloq(gate._original))

if isinstance(gate, cirq.ControlledGate):
return Controlled(
_cirq_gate_to_bloq(gate.sub_gate), CtrlSpec.from_cirq_cv(gate.control_values)
)

# Check specific basic gates instances.
CIRQ_GATE_TO_BLOQ_MAP = {
cirq.T: TGate(),
Expand Down
14 changes: 13 additions & 1 deletion qualtran/cirq_interop/t_complexity_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import cachetools
import cirq

from qualtran import Bloq
from qualtran import Bloq, Controlled
from qualtran.cirq_interop.decompose_protocol import _decompose_once_considering_known_decomposition
from qualtran.resource_counting.symbolic_counting_utils import ceil, log2, SymbolicFloat

Expand Down Expand Up @@ -118,6 +118,18 @@ def _from_directly_countable(stc: Any) -> Optional[TComplexity]:
if stc in _ROTS_GATESET:
return TComplexity(rotations=1)

if isinstance(stc, Controlled) and cirq.num_qubits(stc) <= 2:
# We need this hack temporarily because we assume access to decomposition
# of a C-U gate where $U$ is a single qubit rotation. Cirq has this decomposition
# but the right thing to do in Qualtran is to add explicit bloqs and annotate
# them with costs. See https://github.com/quantumlib/Qualtran/issues/878
from qualtran._infra.gate_with_registers import get_named_qubits

quregs = get_named_qubits(stc.signature)
qm = cirq.SimpleQubitManager()
op, _ = stc.as_cirq_op(qubit_manager=qm, **quregs)
return t_complexity(cirq.decompose_once(op))

if cirq.num_qubits(stc) == 1 and cirq.has_unitary(stc):
# Single qubit rotation operation.
return TComplexity(rotations=1)
Expand Down
Loading