Skip to content

Commit

Permalink
More improvements to the Controlled infrastructure (#879)
Browse files Browse the repository at this point in the history
* More improvemetns to the Controlled infrastructure

* Fix pylint

* Add docstrings, simplify usage of cirq-style API and address comments

* Add docstring
  • Loading branch information
tanujkhattar authored Apr 19, 2024
1 parent ee23b48 commit 7bf19b3
Show file tree
Hide file tree
Showing 11 changed files with 266 additions and 38 deletions.
9 changes: 9 additions & 0 deletions qualtran/_infra/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,15 @@ def decompose_from_registers(
return cirq.inverse(self.subbloq.decompose_from_registers(context=context, **quregs))
return super().decompose_from_registers(context=context, **quregs)

def _circuit_diagram_info_(
self, args: 'cirq.CircuitDiagramInfoArgs'
) -> cirq.CircuitDiagramInfo:
sub_info = cirq.circuit_diagram_info(self.subbloq, args, default=NotImplemented)
if sub_info is NotImplemented:
return NotImplemented
sub_info.exponent *= -1
return sub_info

def supports_decompose_bloq(self) -> bool:
"""Delegate to `subbloq.supports_decompose_bloq()`"""
return self.subbloq.supports_decompose_bloq()
Expand Down
19 changes: 18 additions & 1 deletion qualtran/_infra/controlled.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

from .bloq import Bloq
from .data_types import QBit, QDType
from .gate_with_registers import GateWithRegisters
from .registers import Register, Side, Signature

if TYPE_CHECKING:
Expand Down Expand Up @@ -278,7 +279,7 @@ def _get_nice_ctrl_reg_names(reg_names: List[str], n: int) -> Tuple[str, ...]:


@attrs.frozen
class Controlled(Bloq):
class Controlled(GateWithRegisters):
"""A controlled version of `subbloq`.
This meta-bloq is part of the 'controlled' protocol. As a default fallback,
Expand Down Expand Up @@ -410,6 +411,19 @@ def add_my_tensors(
# Add the data to the tensor network.
tn.add(qtn.Tensor(data=data, inds=out_ind + in_ind, tags=[self.short_name(), tag]))

def _unitary_(self):
if isinstance(self.subbloq, GateWithRegisters):
# subbloq is a cirq gate, use the cirq-style API to derive a unitary.
return cirq.unitary(
cirq.ControlledGate(self.subbloq, control_values=self.ctrl_spec.to_cirq_cv())
)
if all(reg.side == Side.THRU for reg in self.subbloq.signature):
# subbloq has only THRU registers, so the tensor contraction corresponds
# to a unitary matrix.
return self.tensor_contract()
# Unable to determine the unitary effect.
return NotImplemented

def wire_symbol(self, soq: 'Soquet') -> 'WireSymbol':
if soq.reg.name not in self.ctrl_reg_names:
# Delegate to subbloq
Expand All @@ -419,6 +433,9 @@ def wire_symbol(self, soq: 'Soquet') -> 'WireSymbol':
i = self.ctrl_reg_names.index(soq.reg.name)
return self.ctrl_spec.wire_symbol(i, soq)

def adjoint(self) -> 'Bloq':
return self.subbloq.adjoint().controlled(self.ctrl_spec)

def pretty_name(self) -> str:
return f'C[{self.subbloq.pretty_name()}]'

Expand Down
2 changes: 1 addition & 1 deletion qualtran/_infra/controlled_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ def test_notebook():

def _verify_ctrl_tensor_for_unitary(ctrl_spec: CtrlSpec, bloq: Bloq, gate: cirq.Gate):
cbloq = Controlled(bloq, ctrl_spec)
cgate = gate.controlled(control_values=ctrl_spec.to_cirq_cv())
cgate = cirq.ControlledGate(gate, control_values=ctrl_spec.to_cirq_cv())
np.testing.assert_array_equal(cbloq.tensor_contract(), cirq.unitary(cgate))


Expand Down
174 changes: 161 additions & 13 deletions qualtran/_infra/gate_with_registers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,30 @@
# limitations under the License.

import abc
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
from typing import (
Collection,
Dict,
Iterable,
List,
Optional,
overload,
Sequence,
Tuple,
TYPE_CHECKING,
Union,
)

import cirq
import numpy as np
from numpy.typing import NDArray

from qualtran._infra.bloq import Bloq, DecomposeNotImplementedError, DecomposeTypeError
from qualtran._infra.composite_bloq import CompositeBloq
from qualtran._infra.controlled import Controlled, CtrlSpec
from qualtran._infra.quantum_graph import Soquet
from qualtran._infra.registers import Register, Side

if TYPE_CHECKING:
from qualtran import CtrlSpec
from qualtran.cirq_interop import CirqQuregT
from qualtran.drawing import WireSymbol

Expand Down Expand Up @@ -311,23 +322,160 @@ def on_registers(
) -> cirq.Operation:
return self.on(*merge_qubits(self.signature, **qubit_regs))

def __pow__(self, power: int) -> 'GateWithRegisters':
bloq = self if power > 0 else self.adjoint()
if abs(power) == 1:
return bloq
if all(reg.side == Side.THRU for reg in self.signature):
from qualtran.bloqs.util_bloqs import Power

return Power(bloq, abs(power))
raise NotImplementedError(f"{self} does not implemented __pow__ for {power=}.")

def _get_ctrl_spec(
self,
num_controls: Union[Optional[int], 'CtrlSpec'] = None,
control_values=None,
control_qid_shape: Optional[Tuple[int, ...]] = None,
*,
ctrl_spec: Optional['CtrlSpec'] = None,
) -> 'CtrlSpec':
"""Helper method to support Cirq & Bloq style APIs for constructing controlled Bloqs.
This method can be used to construct a `CtrlSpec` from either the Bloq-style API that
already accepts a `CtrlSpec` and simply returns it OR a Cirq-style API which accepts
parameters expected by `cirq.Gate.controlled()` and converts them to a `CtrlSpec` object.
Users implementing custom `GateWithRegisters.controlled()` overrides can use this helper
to generate a CtrlSpec from the cirq-style API and thus easily support both Cirq & Bloq
APIs. For example
>>> class CustomGWR(GateWithRegisters):
>>> def controlled(self, *args, **kwargs) -> 'Bloq':
>>> ctrl_spec = self._get_ctrl_spec(*args, **kwargs)
>>> # Use ctrl_spec to construct a controlled version of `self`.
Args:
num_controls: Cirq style API to specify control specification -
Total number of control qubits.
control_values: Cirq style API to specify control specification -
Which control computational basis state to apply the
sub gate. A sequence of length `num_controls` where each
entry is an integer (or set of integers) corresponding to the
computational basis state (or set of possible values) where that
control is enabled. When all controls are enabled, the sub gate is
applied. If unspecified, control values default to 1.
control_qid_shape: Cirq style API to specify control specification -
The qid shape of the controls. A tuple of the
expected dimension of each control qid. Defaults to
`(2,) * num_controls`. Specify this argument when using qudits.
ctrl_spec: Bloq style API to specify a control specification -
An optional keyword argument `CtrlSpec`, which specifies how to control
the bloq. The default spec means the bloq will be active when one control qubit is
in the |1> state. See the CtrlSpec documentation for more possibilities including
negative controls, integer-equality control, and ndarrays of control values.
"""
from qualtran._infra.controlled import CtrlSpec

ok = True
if ctrl_spec is not None:
# Bloq API invoked via kwargs - bloq.controlled(ctrl_spec=ctrl_spec)
ok &= control_values is None and control_qid_shape is None and num_controls is None
elif isinstance(num_controls, CtrlSpec):
# Bloq API invoked via args - bloq.controlled(ctrl_spec)
ok &= control_values is None and control_qid_shape is None
if not ok:
raise ValueError(
'GateWithRegisters.controlled() must be called with either cirq-style API'
f'or Bloq style API. Found arguments: {num_controls=}, '
f'{control_values=}, {control_qid_shape=}, {ctrl_spec=}'
)

if isinstance(num_controls, CtrlSpec):
ctrl_spec = num_controls
elif ctrl_spec is None:
controlled_gate = cirq.ControlledGate(
self,
num_controls=num_controls,
control_values=control_values,
control_qid_shape=control_qid_shape,
)
ctrl_spec = CtrlSpec.from_cirq_cv(controlled_gate.control_values)
return ctrl_spec

# pylint: disable=arguments-renamed
@overload
def controlled(
self,
num_controls: Optional[int] = None,
control_values=None,
control_values: Optional[
Union[cirq.ops.AbstractControlValues, Sequence[Union[int, Collection[int]]]]
] = None,
control_qid_shape: Optional[Tuple[int, ...]] = None,
) -> 'cirq.Gate':
from qualtran.cirq_interop import BloqAsCirqGate

controlled_gate = cirq.ControlledGate(
self,
num_controls=num_controls,
control_values=control_values,
control_qid_shape=control_qid_shape,
) -> 'GateWithRegisters':
"""Cirq-style API to construct a controlled gate. See `cirq.Gate.controlled()`"""

# pylint: disable=signature-differs
@overload
def controlled(self, ctrl_spec: Optional['CtrlSpec'] = None) -> 'GateWithRegisters':
"""Bloq-style API to construct a controlled Bloq. See `Bloq.controlled()`."""

def controlled(
self,
num_controls: Union[Optional[int], 'CtrlSpec'] = None,
control_values: Optional[
Union[cirq.ops.AbstractControlValues, Sequence[Union[int, Collection[int]]]]
] = None,
control_qid_shape: Optional[Tuple[int, ...]] = None,
*,
ctrl_spec: Optional['CtrlSpec'] = None,
) -> 'Bloq':
"""Return a controlled version of self. Controls can be specified via Cirq/Bloq-style APIs.
If no arguments are specified, defaults to a single qubit control.
Supports both Cirq-style API and Bloq-style API to construct controlled Bloqs. The cirq-style
API is supported by intercepting the Cirq-style way of specifying a control specification;
via arguments `num_controls`, `control_values` and `control_qid_shape`, and constructing a
`CtrlSpec` object from it before delegating to `self.get_ctrl_system`.
By default, the system will use the `qualtran.Controlled` meta-bloq to wrap this
bloq. Bloqs authors can declare their own, custom controlled versions by overriding
`Bloq.get_ctrl_system` in the bloq.
If overriding the `GWR.controlled()` method directly, Bloq authors can use the
`self._get_ctrl_spec` helper to construct a `CtrlSpec` object from the input parameters of
`GWR.controlled()` and use it to return a custom controlled version of this Bloq.
Args:
num_controls: Cirq style API to specify control specification -
Total number of control qubits.
control_values: Cirq style API to specify control specification -
Which control computational basis state to apply the
sub gate. A sequence of length `num_controls` where each
entry is an integer (or set of integers) corresponding to the
computational basis state (or set of possible values) where that
control is enabled. When all controls are enabled, the sub gate is
applied. If unspecified, control values default to 1.
control_qid_shape: Cirq style API to specify control specification -
The qid shape of the controls. A tuple of the
expected dimension of each control qid. Defaults to
`(2,) * num_controls`. Specify this argument when using qudits.
ctrl_spec: Bloq style API to specify a control specification -
An optional keyword argument `CtrlSpec`, which specifies how to control
the bloq. The default spec means the bloq will be active when one control qubit is
in the |1> state. See the CtrlSpec documentation for more possibilities including
negative controls, integer-equality control, and ndarrays of control values.
Returns:
A controlled version of the bloq.
"""
ctrl_spec = self._get_ctrl_spec(
num_controls, control_values, control_qid_shape, ctrl_spec=ctrl_spec
)
ctrl_spec = CtrlSpec.from_cirq_cv(controlled_gate.control_values)
return BloqAsCirqGate(Controlled(self, ctrl_spec))
controlled_bloq, _ = self.get_ctrl_system(ctrl_spec=ctrl_spec)
return controlled_bloq

def _unitary_(self):
return NotImplemented
Expand Down
45 changes: 44 additions & 1 deletion qualtran/_infra/gate_with_registers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,19 @@
import numpy as np
import pytest

from qualtran import GateWithRegisters, QAny, QBit, Register, Side, Signature, SoquetT
from qualtran import (
Controlled,
CtrlSpec,
GateWithRegisters,
QAny,
QBit,
Register,
Side,
Signature,
SoquetT,
)
from qualtran.bloqs.basic_gates import XGate, YGate, ZGate
from qualtran.bloqs.util_bloqs import Power
from qualtran.testing import execute_notebook


Expand Down Expand Up @@ -51,6 +62,38 @@ def test_gate_with_registers():

np.testing.assert_allclose(cirq.unitary(tg), tg.tensor_contract())

# Test GWR.controlled() works correctly with Bloq and Cirq style API
ctrl = cirq.q('ctrl')
cop1 = tg.controlled().on(ctrl, *qubits[:5], *qubits[6:], qubits[5])
cop2 = tg.controlled().on_registers(ctrl=ctrl, r1=qubits[:5], r2=qubits[6:], r3=qubits[5])
cop3 = op1.controlled_by(ctrl)
cop4 = op2.controlled_by(ctrl)
assert cop1 == cop2 == cop3 == cop4
assert cop1.gate == cop2.gate == cop3.gate == cop4.gate == Controlled(tg, CtrlSpec())

assert (
tg.controlled(num_controls=1, control_values=[0])
== tg.controlled(control_values=[0], control_qid_shape=(2,))
== tg.controlled(CtrlSpec(cvs=0))
== tg.controlled(ctrl_spec=CtrlSpec(cvs=0))
)

# Test GWR.controlled() raises with incorrect invocation.
with pytest.raises(ValueError):
tg.controlled(control_values=[0], ctrl_spec=CtrlSpec())

with pytest.raises(ValueError):
tg.controlled(CtrlSpec(), control_values=[0])

with pytest.raises(ValueError):
tg.controlled(CtrlSpec(), ctrl_spec=CtrlSpec())

# Test GWR**pow
assert tg**-1 == tg.adjoint()
assert tg**1 is tg
assert tg**-10 == Power(tg.adjoint(), 10)
assert tg**10 == Power(tg, 10)


class _TestGateAtomic(GateWithRegisters):
@property
Expand Down
8 changes: 3 additions & 5 deletions qualtran/bloqs/phase_estimation/lp_resource_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,19 +64,17 @@ def decompose_from_registers(
yield [OnEach(self.bitsize, Hadamard()).on(*q), Hadamard().on(*anc)]
for i in range(self.bitsize):
rz_angle = -2 * np.pi * (2**i) / (2**self.bitsize + 1)
yield cirq.Rz(rads=rz_angle).controlled().on(q[i], *anc)
yield Rz(angle=rz_angle).controlled().on(q[i], *anc)
yield Rz(angle=-2 * np.pi / (2**self.bitsize + 1)).on(*anc)
yield Hadamard().on(*anc)

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
rz_angle = -2 * pi(self.bitsize) / (2**self.bitsize + 1)
ret = {(Rz(angle=rz_angle), 1), (Hadamard(), 2 + self.bitsize)}
if is_symbolic(self.bitsize):
ret |= {(Rz(angle=rz_angle).controlled().bloq, self.bitsize)}
ret |= {(Rz(angle=rz_angle).controlled(), self.bitsize)}
else:
ret |= {
(Rz(angle=rz_angle * (2**i)).controlled().bloq, 1) for i in range(self.bitsize)
}
ret |= {(Rz(angle=rz_angle * (2**i)).controlled(), 1) for i in range(self.bitsize)}
return ret

def _t_complexity_(self) -> 'TComplexity':
Expand Down
15 changes: 3 additions & 12 deletions qualtran/bloqs/qsp/generalized_qsp.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,7 @@
from numpy.polynomial import Polynomial
from numpy.typing import NDArray

from qualtran import (
bloq_example,
BloqDocSpec,
Controlled,
CtrlSpec,
GateWithRegisters,
QBit,
Register,
Signature,
)
from qualtran import bloq_example, BloqDocSpec, GateWithRegisters, QBit, Register, Signature
from qualtran.bloqs.basic_gates.su2_rotation import SU2RotationGate

if TYPE_CHECKING:
Expand Down Expand Up @@ -401,12 +392,12 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
counts = set(Counter(self.signal_rotations).items())

if degree > self.negative_power:
counts.add((Controlled(self.U, CtrlSpec(cvs=0)), degree - self.negative_power))
counts.add((self.U.controlled(control_values=[0]), degree - self.negative_power))
elif self.negative_power > degree:
counts.add((self.U.adjoint(), self.negative_power - degree))

if self.negative_power > 0:
counts.add((Controlled(self.U.adjoint(), CtrlSpec()), min(degree, self.negative_power)))
counts.add((self.U.adjoint().controlled(), min(degree, self.negative_power)))

return counts

Expand Down
Loading

0 comments on commit 7bf19b3

Please sign in to comment.