-
Notifications
You must be signed in to change notification settings - Fork 50
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
More improvements to the Controlled infrastructure #879
Changes from 2 commits
c72bcef
065cabf
fd1e74f
d92095b
0ba0f69
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -34,6 +34,7 @@ | |
|
||
from .bloq import Bloq | ||
from .data_types import QBit, QDType | ||
from .gate_with_registers import GateWithRegisters | ||
from .registers import Register, Side, Signature | ||
|
||
if TYPE_CHECKING: | ||
|
@@ -278,7 +279,7 @@ def _get_nice_ctrl_reg_names(reg_names: List[str], n: int) -> Tuple[str, ...]: | |
|
||
|
||
@attrs.frozen | ||
class Controlled(Bloq): | ||
class Controlled(GateWithRegisters): | ||
"""A controlled version of `subbloq`. | ||
|
||
This meta-bloq is part of the 'controlled' protocol. As a default fallback, | ||
|
@@ -410,6 +411,18 @@ def add_my_tensors( | |
# Add the data to the tensor network. | ||
tn.add(qtn.Tensor(data=data, inds=out_ind + in_ind, tags=[self.short_name(), tag])) | ||
|
||
def _unitary_(self): | ||
if isinstance(self.subbloq, GateWithRegisters): | ||
# subbloq is a cirq gate, use the cirq-style API to derive a unitary. | ||
return cirq.unitary( | ||
cirq.ControlledGate(self.subbloq, control_values=self.ctrl_spec.to_cirq_cv()) | ||
) | ||
if all(reg.side == Side.THRU for reg in self.subbloq.signature): | ||
# subbloq has only THRU registers, so the unitary is well defined. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. pycharm tells me this is supposed to be "well-defined" There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. reworded |
||
return self.tensor_contract() | ||
# Unable to determine the unitary effect. | ||
return NotImplemented | ||
|
||
def wire_symbol(self, soq: 'Soquet') -> 'WireSymbol': | ||
if soq.reg.name not in self.ctrl_reg_names: | ||
# Delegate to subbloq | ||
|
@@ -419,6 +432,9 @@ def wire_symbol(self, soq: 'Soquet') -> 'WireSymbol': | |
i = self.ctrl_reg_names.index(soq.reg.name) | ||
return self.ctrl_spec.wire_symbol(i, soq) | ||
|
||
def adjoint(self) -> 'Bloq': | ||
return self.subbloq.adjoint().controlled(self.ctrl_spec) | ||
|
||
def pretty_name(self) -> str: | ||
return f'C[{self.subbloq.pretty_name()}]' | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,7 +21,6 @@ | |
|
||
from qualtran._infra.bloq import Bloq, DecomposeNotImplementedError, DecomposeTypeError | ||
from qualtran._infra.composite_bloq import CompositeBloq | ||
from qualtran._infra.controlled import Controlled, CtrlSpec | ||
from qualtran._infra.quantum_graph import Soquet | ||
from qualtran._infra.registers import Register, Side | ||
|
||
|
@@ -311,23 +310,40 @@ def on_registers( | |
) -> cirq.Operation: | ||
return self.on(*merge_qubits(self.signature, **qubit_regs)) | ||
|
||
def __pow__(self, power: int) -> 'GateWithRegisters': | ||
if power == -1: | ||
return self.adjoint() | ||
if power == 1: | ||
return self | ||
raise NotImplementedError("pow is not implemented.") | ||
|
||
# pylint: disable=arguments-renamed | ||
def controlled( | ||
self, | ||
num_controls: Optional[int] = None, | ||
num_controls: Union[Optional[int], 'CtrlSpec'] = None, | ||
control_values=None, | ||
control_qid_shape: Optional[Tuple[int, ...]] = None, | ||
) -> 'cirq.Gate': | ||
from qualtran.cirq_interop import BloqAsCirqGate | ||
|
||
controlled_gate = cirq.ControlledGate( | ||
self, | ||
num_controls=num_controls, | ||
control_values=control_values, | ||
control_qid_shape=control_qid_shape, | ||
) | ||
ctrl_spec = CtrlSpec.from_cirq_cv(controlled_gate.control_values) | ||
return BloqAsCirqGate(Controlled(self, ctrl_spec)) | ||
) -> 'GateWithRegisters': | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would it be useful to provide overloaded type hints? # pylint: disable=arguments-renamed
@overload
def controlled(
self,
num_controls: Optional[int] = None,
control_values=None,
control_qid_shape: Optional[Tuple[int, ...]] = None,
) -> 'GateWithRegisters':
...
# pylint: disable=signature-differs
@overload
def controlled(self, ctrl_spec: 'CtrlSpec') -> 'GateWithRegisters':
...
# pylint: disable=arguments-renamed
def controlled(self, num_controls=None, control_values=None, control_qid_shape=None): Right now it's difficult to understand the usage without looking at the function body. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah, I think this needs a big docstring that describes the two modes of operation There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added a detailed docstring and the overloads. Also added tests to verify the different cases raise errors. |
||
from qualtran._infra.controlled import Controlled, CtrlSpec | ||
|
||
if isinstance(num_controls, CtrlSpec): | ||
if not (control_values is None and control_qid_shape is None): | ||
raise ValueError( | ||
'GateWithRegisters.controlled() must be called with either cirq-style API' | ||
f'or Bloq style API. Found arguments: {num_controls=}, ' | ||
f'{control_values=}, {control_qid_shape=}' | ||
) | ||
ctrl_spec = num_controls | ||
else: | ||
controlled_gate = cirq.ControlledGate( | ||
self, | ||
num_controls=num_controls, | ||
control_values=control_values, | ||
control_qid_shape=control_qid_shape, | ||
) | ||
ctrl_spec = CtrlSpec.from_cirq_cv(controlled_gate.control_values) | ||
|
||
return Controlled(self, ctrl_spec) | ||
|
||
def _unitary_(self): | ||
return NotImplemented | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,14 +12,16 @@ | |
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from functools import cached_property | ||
from typing import Set, Tuple | ||
from typing import Sequence, Set, Tuple | ||
|
||
import attrs | ||
import cirq | ||
|
||
from qualtran import Bloq, bloq_example, BloqDocSpec, GateWithRegisters, QFxp, Register, Signature | ||
from qualtran._infra.gate_with_registers import merge_qubits | ||
from qualtran.bloqs.basic_gates import Hadamard, OnEach | ||
from qualtran.bloqs.qft.qft_text_book import QFTTextBook | ||
from qualtran.bloqs.util_bloqs import Power | ||
from qualtran.resource_counting.symbolic_counting_utils import ( | ||
ceil, | ||
is_symbolic, | ||
|
@@ -220,13 +222,14 @@ def decompose_from_registers( | |
self, context: cirq.DecompositionContext, **quregs | ||
) -> cirq.OP_TREE: | ||
target_quregs = {reg.name: quregs[reg.name] for reg in self.target_registers} | ||
unitary_op = self.unitary.on_registers(**target_quregs) | ||
target_qubits = merge_qubits(self.target_registers, **target_quregs) | ||
|
||
phase_qubits = quregs['qpe_reg'] | ||
phase_qubits: Sequence[cirq.Qid] = quregs['qpe_reg'].flatten() | ||
|
||
yield self.ctrl_state_prep.on(*phase_qubits) | ||
for i, qbit in enumerate(phase_qubits[::-1]): | ||
yield cirq.pow(unitary_op.controlled_by(qbit), 2**i) | ||
yield Power(self.unitary.controlled(), 2**i).on(qbit, *target_qubits) | ||
# yield cirq.pow(unitary_op.controlled_by(qbit), 2**i) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remove old code There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Reverted the file so now there are no changes and the original implementation uses bloqs instead of cirq gates. See added tests in |
||
yield self.qft_inv.on(*phase_qubits) | ||
|
||
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,16 +20,8 @@ | |
from numpy.polynomial import Polynomial | ||
from numpy.typing import NDArray | ||
|
||
from qualtran import ( | ||
bloq_example, | ||
BloqDocSpec, | ||
Controlled, | ||
CtrlSpec, | ||
GateWithRegisters, | ||
QBit, | ||
Register, | ||
Signature, | ||
) | ||
from qualtran import bloq_example, BloqDocSpec, GateWithRegisters, QBit, Register, Signature | ||
from qualtran._infra.gate_with_registers import merge_qubits | ||
from qualtran.bloqs.basic_gates.su2_rotation import SU2RotationGate | ||
|
||
if TYPE_CHECKING: | ||
|
@@ -382,14 +374,15 @@ def decompose_from_registers( | |
num_inverse_applications = self.negative_power | ||
|
||
yield self.signal_rotations[0].on(signal_qubit) | ||
flat_qubits_for_u = merge_qubits(self.U.signature, **quregs) | ||
for signal_rotation in self.signal_rotations[1:]: | ||
if num_inverse_applications > 0: | ||
# apply C-U^\dagger | ||
yield self.U.adjoint().on_registers(**quregs).controlled_by(signal_qubit) | ||
yield self.U.adjoint().controlled().on(signal_qubit, *flat_qubits_for_u) | ||
num_inverse_applications -= 1 | ||
else: | ||
# apply C[0]-U | ||
yield self.U.on_registers(**quregs).controlled_by(signal_qubit, control_values=[0]) | ||
yield self.U.controlled(control_values=[0]).on(signal_qubit, *flat_qubits_for_u) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we default to using qualtran style controls There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Reverted this change because |
||
yield signal_rotation.on(signal_qubit) | ||
|
||
for _ in range(num_inverse_applications): | ||
|
@@ -401,12 +394,12 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']: | |
counts = set(Counter(self.signal_rotations).items()) | ||
|
||
if degree > self.negative_power: | ||
counts.add((Controlled(self.U, CtrlSpec(cvs=0)), degree - self.negative_power)) | ||
counts.add((self.U.controlled(control_values=[0]), degree - self.negative_power)) | ||
elif self.negative_power > degree: | ||
counts.add((self.U.adjoint(), self.negative_power - degree)) | ||
|
||
if self.negative_power > 0: | ||
counts.add((Controlled(self.U.adjoint(), CtrlSpec()), min(degree, self.negative_power))) | ||
counts.add((self.U.adjoint().controlled(), min(degree, self.negative_power))) | ||
|
||
return counts | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
return type annotation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done