Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More improvements to the Controlled infrastructure #879

Merged
merged 5 commits into from
Apr 19, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions qualtran/_infra/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,15 @@ def decompose_from_registers(
return cirq.inverse(self.subbloq.decompose_from_registers(context=context, **quregs))
return super().decompose_from_registers(context=context, **quregs)

def _circuit_diagram_info_(
self, args: 'cirq.CircuitDiagramInfoArgs'
) -> cirq.CircuitDiagramInfo:
sub_info = cirq.circuit_diagram_info(self.subbloq, args, default=NotImplemented)
if sub_info is NotImplemented:
return NotImplemented
sub_info.exponent *= -1
return sub_info

def supports_decompose_bloq(self) -> bool:
"""Delegate to `subbloq.supports_decompose_bloq()`"""
return self.subbloq.supports_decompose_bloq()
Expand Down
19 changes: 18 additions & 1 deletion qualtran/_infra/controlled.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

from .bloq import Bloq
from .data_types import QBit, QDType
from .gate_with_registers import GateWithRegisters
from .registers import Register, Side, Signature

if TYPE_CHECKING:
Expand Down Expand Up @@ -278,7 +279,7 @@ def _get_nice_ctrl_reg_names(reg_names: List[str], n: int) -> Tuple[str, ...]:


@attrs.frozen
class Controlled(Bloq):
class Controlled(GateWithRegisters):
"""A controlled version of `subbloq`.

This meta-bloq is part of the 'controlled' protocol. As a default fallback,
Expand Down Expand Up @@ -410,6 +411,19 @@ def add_my_tensors(
# Add the data to the tensor network.
tn.add(qtn.Tensor(data=data, inds=out_ind + in_ind, tags=[self.short_name(), tag]))

def _unitary_(self):
if isinstance(self.subbloq, GateWithRegisters):
# subbloq is a cirq gate, use the cirq-style API to derive a unitary.
return cirq.unitary(
cirq.ControlledGate(self.subbloq, control_values=self.ctrl_spec.to_cirq_cv())
)
if all(reg.side == Side.THRU for reg in self.subbloq.signature):
# subbloq has only THRU registers, so the tensor contraction corresponds
# to a unitary matrix.
return self.tensor_contract()
# Unable to determine the unitary effect.
return NotImplemented

def wire_symbol(self, soq: 'Soquet') -> 'WireSymbol':
if soq.reg.name not in self.ctrl_reg_names:
# Delegate to subbloq
Expand All @@ -419,6 +433,9 @@ def wire_symbol(self, soq: 'Soquet') -> 'WireSymbol':
i = self.ctrl_reg_names.index(soq.reg.name)
return self.ctrl_spec.wire_symbol(i, soq)

def adjoint(self) -> 'Bloq':
return self.subbloq.adjoint().controlled(self.ctrl_spec)

def pretty_name(self) -> str:
return f'C[{self.subbloq.pretty_name()}]'

Expand Down
2 changes: 1 addition & 1 deletion qualtran/_infra/controlled_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ def test_notebook():

def _verify_ctrl_tensor_for_unitary(ctrl_spec: CtrlSpec, bloq: Bloq, gate: cirq.Gate):
cbloq = Controlled(bloq, ctrl_spec)
cgate = gate.controlled(control_values=ctrl_spec.to_cirq_cv())
cgate = cirq.ControlledGate(gate, control_values=ctrl_spec.to_cirq_cv())
np.testing.assert_array_equal(cbloq.tensor_contract(), cirq.unitary(cgate))


Expand Down
157 changes: 144 additions & 13 deletions qualtran/_infra/gate_with_registers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,30 @@
# limitations under the License.

import abc
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
from typing import (
Collection,
Dict,
Iterable,
List,
Optional,
overload,
Sequence,
Tuple,
TYPE_CHECKING,
Union,
)

import cirq
import numpy as np
from numpy.typing import NDArray

from qualtran._infra.bloq import Bloq, DecomposeNotImplementedError, DecomposeTypeError
from qualtran._infra.composite_bloq import CompositeBloq
from qualtran._infra.controlled import Controlled, CtrlSpec
from qualtran._infra.quantum_graph import Soquet
from qualtran._infra.registers import Register, Side

if TYPE_CHECKING:
from qualtran import CtrlSpec
from qualtran.cirq_interop import CirqQuregT
from qualtran.drawing import WireSymbol

Expand Down Expand Up @@ -311,23 +322,143 @@ def on_registers(
) -> cirq.Operation:
return self.on(*merge_qubits(self.signature, **qubit_regs))

def __pow__(self, power: int) -> 'GateWithRegisters':
bloq = self if power > 0 else self.adjoint()
if abs(power) == 1:
return bloq
if all(reg.side == Side.THRU for reg in self.signature):
from qualtran.bloqs.util_bloqs import Power

return Power(bloq, abs(power))
raise NotImplementedError(f"{self} does not implemented __pow__ for {power=}.")

def _get_ctrl_spec(
self,
num_controls: Union[Optional[int], 'CtrlSpec'] = None,
control_values=None,
control_qid_shape: Optional[Tuple[int, ...]] = None,
*,
ctrl_spec: Optional['CtrlSpec'] = None,
) -> 'CtrlSpec':
"""Helper method to support Cirq & Bloq style APIs for constructing controlled Bloqs.

This method can be used to construct a `CtrlSpec` from either the Bloq-style API that
already accepts a `CtrlSpec` and simply returns it OR a Cirq-style API which accepts
parameters expected by `cirq.Gate.controlled()` and converts them to a `CtrlSpec` object.

Users implementing custom `GateWithRegisters.controlled()` overrides can use this helper
to generate a CtrlSpec from the cirq-style API and thus easily support both Cirq & Bloq
APIs. For example

>>> class CustomGWR(GateWithRegisters):
>>> def controlled(self, *args, **kwargs) -> 'Bloq':
>>> ctrl_spec = self._get_ctrl_spec(*args, **kwargs)
>>> # Use ctrl_spec to construct a controlled version of `self`.

Args:
num_controls:
tanujkhattar marked this conversation as resolved.
Show resolved Hide resolved
"""
from qualtran._infra.controlled import CtrlSpec

ok = True
if ctrl_spec is not None:
# Bloq API invoked via kwargs - bloq.controlled(ctrl_spec=ctrl_spec)
ok &= control_values is None and control_qid_shape is None and num_controls is None
elif isinstance(num_controls, CtrlSpec):
# Bloq API invoked via args - bloq.controlled(ctrl_spec)
ok &= control_values is None and control_qid_shape is None
if not ok:
raise ValueError(
'GateWithRegisters.controlled() must be called with either cirq-style API'
f'or Bloq style API. Found arguments: {num_controls=}, '
f'{control_values=}, {control_qid_shape=}, {ctrl_spec=}'
)

if isinstance(num_controls, CtrlSpec):
ctrl_spec = num_controls
elif ctrl_spec is None:
controlled_gate = cirq.ControlledGate(
self,
num_controls=num_controls,
control_values=control_values,
control_qid_shape=control_qid_shape,
)
ctrl_spec = CtrlSpec.from_cirq_cv(controlled_gate.control_values)
return ctrl_spec

# pylint: disable=arguments-renamed
@overload
def controlled(
self,
num_controls: Optional[int] = None,
control_values=None,
control_values: Optional[
Union[cirq.ops.AbstractControlValues, Sequence[Union[int, Collection[int]]]]
] = None,
control_qid_shape: Optional[Tuple[int, ...]] = None,
) -> 'cirq.Gate':
from qualtran.cirq_interop import BloqAsCirqGate

controlled_gate = cirq.ControlledGate(
self,
num_controls=num_controls,
control_values=control_values,
control_qid_shape=control_qid_shape,
) -> 'GateWithRegisters':
"""Cirq-style API to construct a controlled gate. See `cirq.Gate.controlled()`"""

# pylint: disable=signature-differs
@overload
def controlled(self, ctrl_spec: Optional['CtrlSpec'] = None) -> 'GateWithRegisters':
"""Bloq-style API to construct a controlled Bloq. See `Bloq.controlled()`."""

def controlled(
self,
num_controls: Union[Optional[int], 'CtrlSpec'] = None,
control_values: Optional[
Union[cirq.ops.AbstractControlValues, Sequence[Union[int, Collection[int]]]]
] = None,
control_qid_shape: Optional[Tuple[int, ...]] = None,
*,
ctrl_spec: Optional['CtrlSpec'] = None,
) -> 'Bloq':
"""Return a controlled version of self. Controls can be specified via Cirq/Bloq-style APIs.

If no arguments are specified, defaults to a single qubit control.

Supports both Cirq-style API and Bloq-style API to construct controlled Bloqs. The cirq-style
API is supported by intercepting the Cirq-style way of specifying a control specification;
via arguments `num_controls`, `control_values` and `control_qid_shape`, and constructing a
`CtrlSpec` object from it before delegating to `self.get_ctrl_system`.

By default, the system will use the `qualtran.Controlled` meta-bloq to wrap this
bloq. Bloqs authors can declare their own, custom controlled versions by overriding
`Bloq.get_ctrl_system` in the bloq.

If overriding the `GWR.controlled()` method directly, Bloq authors can use the
`self._get_ctrl_spec` helper to construct a `CtrlSpec` object from the input parameters of
`GWR.controlled()` and use it to return a custom controlled version of this Bloq.


Args:
num_controls: Cirq style API to specify control specification -
Total number of control qubits.
control_values: Cirq style API to specify control specification -
Which control computational basis state to apply the
sub gate. A sequence of length `num_controls` where each
entry is an integer (or set of integers) corresponding to the
computational basis state (or set of possible values) where that
control is enabled. When all controls are enabled, the sub gate is
applied. If unspecified, control values default to 1.
control_qid_shape: Cirq style API to specify control specification -
The qid shape of the controls. A tuple of the
expected dimension of each control qid. Defaults to
`(2,) * num_controls`. Specify this argument when using qudits.
ctrl_spec: Bloq style API to specify a control specification -
An optional keyword argument `CtrlSpec`, which specifies how to control
the bloq. The default spec means the bloq will be active when one control qubit is
in the |1> state. See the CtrlSpec documentation for more possibilities including
negative controls, integer-equality control, and ndarrays of control values.

Returns:
A controlled version of the bloq.
"""
ctrl_spec = self._get_ctrl_spec(
num_controls, control_values, control_qid_shape, ctrl_spec=ctrl_spec
)
ctrl_spec = CtrlSpec.from_cirq_cv(controlled_gate.control_values)
return BloqAsCirqGate(Controlled(self, ctrl_spec))
controlled_bloq, _ = self.get_ctrl_system(ctrl_spec=ctrl_spec)
return controlled_bloq

def _unitary_(self):
return NotImplemented
Expand Down
45 changes: 44 additions & 1 deletion qualtran/_infra/gate_with_registers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,19 @@
import numpy as np
import pytest

from qualtran import GateWithRegisters, QAny, QBit, Register, Side, Signature, SoquetT
from qualtran import (
Controlled,
CtrlSpec,
GateWithRegisters,
QAny,
QBit,
Register,
Side,
Signature,
SoquetT,
)
from qualtran.bloqs.basic_gates import XGate, YGate, ZGate
from qualtran.bloqs.util_bloqs import Power
from qualtran.testing import execute_notebook


Expand Down Expand Up @@ -51,6 +62,38 @@ def test_gate_with_registers():

np.testing.assert_allclose(cirq.unitary(tg), tg.tensor_contract())

# Test GWR.controlled() works correctly with Bloq and Cirq style API
ctrl = cirq.q('ctrl')
cop1 = tg.controlled().on(ctrl, *qubits[:5], *qubits[6:], qubits[5])
cop2 = tg.controlled().on_registers(ctrl=ctrl, r1=qubits[:5], r2=qubits[6:], r3=qubits[5])
cop3 = op1.controlled_by(ctrl)
cop4 = op2.controlled_by(ctrl)
assert cop1 == cop2 == cop3 == cop4
assert cop1.gate == cop2.gate == cop3.gate == cop4.gate == Controlled(tg, CtrlSpec())

assert (
tg.controlled(num_controls=1, control_values=[0])
== tg.controlled(control_values=[0], control_qid_shape=(2,))
== tg.controlled(CtrlSpec(cvs=0))
== tg.controlled(ctrl_spec=CtrlSpec(cvs=0))
)

# Test GWR.controlled() raises with incorrect invocation.
with pytest.raises(ValueError):
tg.controlled(control_values=[0], ctrl_spec=CtrlSpec())

with pytest.raises(ValueError):
tg.controlled(CtrlSpec(), control_values=[0])

with pytest.raises(ValueError):
tg.controlled(CtrlSpec(), ctrl_spec=CtrlSpec())

# Test GWR**pow
assert tg**-1 == tg.adjoint()
assert tg**1 is tg
assert tg**-10 == Power(tg.adjoint(), 10)
assert tg**10 == Power(tg, 10)


class _TestGateAtomic(GateWithRegisters):
@property
Expand Down
8 changes: 3 additions & 5 deletions qualtran/bloqs/phase_estimation/lp_resource_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,19 +64,17 @@ def decompose_from_registers(
yield [OnEach(self.bitsize, Hadamard()).on(*q), Hadamard().on(*anc)]
for i in range(self.bitsize):
rz_angle = -2 * np.pi * (2**i) / (2**self.bitsize + 1)
yield cirq.Rz(rads=rz_angle).controlled().on(q[i], *anc)
yield Rz(angle=rz_angle).controlled().on(q[i], *anc)
yield Rz(angle=-2 * np.pi / (2**self.bitsize + 1)).on(*anc)
yield Hadamard().on(*anc)

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
rz_angle = -2 * pi(self.bitsize) / (2**self.bitsize + 1)
ret = {(Rz(angle=rz_angle), 1), (Hadamard(), 2 + self.bitsize)}
if is_symbolic(self.bitsize):
ret |= {(Rz(angle=rz_angle).controlled().bloq, self.bitsize)}
ret |= {(Rz(angle=rz_angle).controlled(), self.bitsize)}
else:
ret |= {
(Rz(angle=rz_angle * (2**i)).controlled().bloq, 1) for i in range(self.bitsize)
}
ret |= {(Rz(angle=rz_angle * (2**i)).controlled(), 1) for i in range(self.bitsize)}
return ret

def _t_complexity_(self) -> 'TComplexity':
Expand Down
15 changes: 3 additions & 12 deletions qualtran/bloqs/qsp/generalized_qsp.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,7 @@
from numpy.polynomial import Polynomial
from numpy.typing import NDArray

from qualtran import (
bloq_example,
BloqDocSpec,
Controlled,
CtrlSpec,
GateWithRegisters,
QBit,
Register,
Signature,
)
from qualtran import bloq_example, BloqDocSpec, GateWithRegisters, QBit, Register, Signature
from qualtran.bloqs.basic_gates.su2_rotation import SU2RotationGate

if TYPE_CHECKING:
Expand Down Expand Up @@ -401,12 +392,12 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
counts = set(Counter(self.signal_rotations).items())

if degree > self.negative_power:
counts.add((Controlled(self.U, CtrlSpec(cvs=0)), degree - self.negative_power))
counts.add((self.U.controlled(control_values=[0]), degree - self.negative_power))
elif self.negative_power > degree:
counts.add((self.U.adjoint(), self.negative_power - degree))

if self.negative_power > 0:
counts.add((Controlled(self.U.adjoint(), CtrlSpec()), min(degree, self.negative_power)))
counts.add((self.U.adjoint().controlled(), min(degree, self.negative_power)))

return counts

Expand Down
5 changes: 3 additions & 2 deletions qualtran/bloqs/qubitization_walk_operator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import numpy as np
import pytest

from qualtran import Adjoint
from qualtran._infra.gate_with_registers import get_named_qubits, total_bits
from qualtran.bloqs.chemistry.ising import get_1d_ising_hamiltonian
from qualtran.bloqs.mcmt.multi_control_multi_target_pauli import MultiControlPauli
Expand Down Expand Up @@ -161,8 +162,8 @@ def test_qubitization_walk_operator_diagrams():

def keep(op):
ret = op in gateset_to_keep
if op.gate is not None and isinstance(op.gate, cirq.ops.raw_types._InverseCompositeGate):
ret |= op.gate._original in gateset_to_keep
if op.gate is not None and isinstance(op.gate, Adjoint):
ret |= op.gate.subbloq in gateset_to_keep
return ret

greedy_mm = cirq.GreedyQubitManager(prefix="ancilla", maximize_reuse=True)
Expand Down
Loading
Loading