Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove attribute power and default to Power/GateWithRegisters.__pow__ #903

Merged
merged 13 commits into from
May 6, 2024
15 changes: 15 additions & 0 deletions qualtran/_infra/controlled.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,3 +462,18 @@ def as_cirq_op(
sub_op.controlled_by(*ctrl_qubits, control_values=self.ctrl_spec.to_cirq_cv()),
cirq_quregs | ctrl_regs,
)

def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo:
from qualtran.cirq_interop._bloq_to_cirq import _wire_symbol_to_cirq_diagram_info

if isinstance(self.subbloq, cirq.Gate):
sub_info = cirq.circuit_diagram_info(self.subbloq, args, None)
if sub_info is not None:
cv_info = cirq.circuit_diagram_info(self.ctrl_spec.to_cirq_cv())

return cirq.CircuitDiagramInfo(
wire_symbols=(*cv_info.wire_symbols, *sub_info.wire_symbols),
exponent=sub_info.exponent,
)

return _wire_symbol_to_cirq_diagram_info(self, args)
Comment on lines +466 to +479
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One change after the approval: I had to add a special case here because _wire_symbol_to_cirq_diagram_info doesn't recurse into the subgate. eg. XPowGate(0.25).controlled() was rendering X instead of X^0.25.

I added tests now to check that this works correctly

49 changes: 49 additions & 0 deletions qualtran/_infra/controlled_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
OneState,
Swap,
XGate,
XPowGate,
YGate,
ZeroState,
ZGate,
Expand Down Expand Up @@ -402,3 +403,51 @@ def test_controlled_tensor_for_and_bloq(ctrl_spec: CtrlSpec):
_verify_ctrl_tensor_for_and(ctrl_spec, (1, 0))
_verify_ctrl_tensor_for_and(ctrl_spec, (0, 1))
_verify_ctrl_tensor_for_and(ctrl_spec, (0, 0))


def test_controlled_diagrams():
ctrl_gate = XPowGate(0.25).controlled()
cirq.testing.assert_has_diagram(
cirq.Circuit(ctrl_gate.on_registers(**get_named_qubits(ctrl_gate.signature))),
'''
ctrl: ───@────────
q: ──────X^0.25───''',
)

ctrl_0_gate = XPowGate(0.25).controlled(CtrlSpec(cvs=0))
cirq.testing.assert_has_diagram(
cirq.Circuit(ctrl_0_gate.on_registers(**get_named_qubits(ctrl_0_gate.signature))),
'''
ctrl: ───(0)──────
q: ──────X^0.25───''',
)

multi_ctrl_gate = XPowGate(0.25).controlled(CtrlSpec(cvs=[0, 1]))
cirq.testing.assert_has_diagram(
cirq.Circuit(multi_ctrl_gate.on_registers(**get_named_qubits(multi_ctrl_gate.signature))),
'''
ctrl[0]: ───(0)──────
ctrl[1]: ───@────────
q: ─────────X^0.25───''',
)

ctrl_bloq = Swap(2).controlled(CtrlSpec(cvs=[0, 1]))
cirq.testing.assert_has_diagram(
cirq.Circuit(ctrl_bloq.on_registers(**get_named_qubits(ctrl_bloq.signature))),
'''
ctrl[0]: ───(0)────
ctrl[1]: ───@──────
x0: ────────×(x)───
x1: ────────×(x)───
y0: ────────×(y)───
y1: ────────×(y)───''',
)
11 changes: 6 additions & 5 deletions qualtran/bloqs/for_testing/atom.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from functools import cached_property
from typing import Any, Dict, Optional, TYPE_CHECKING

import attrs
import numpy as np
from attrs import frozen

Expand Down Expand Up @@ -127,6 +128,7 @@ class TestGWRAtom(GateWithRegisters):
"""

tag: Optional[str] = None
is_adjoint: bool = False

@cached_property
def signature(self) -> Signature:
Expand Down Expand Up @@ -157,16 +159,15 @@ def _unitary_(self):
return np.eye(2)

def adjoint(self) -> 'Bloq':
return self
return attrs.evolve(self, is_adjoint=not self.is_adjoint)

def _t_complexity_(self) -> 'TComplexity':
return TComplexity(100)

def __repr__(self):
if self.tag:
return f'TestGWRAtom({self.tag!r})'
else:
return 'TestGWRAtom()'
tag = f'{self.tag!r}' if self.tag else ''
dagger = '†' if self.is_adjoint else ''
return f'TestGWRAtom({tag}){dagger}'

def short_name(self) -> str:
if self.tag:
Expand Down
2 changes: 1 addition & 1 deletion qualtran/bloqs/for_testing/atom_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,5 @@ def test_test_gwr_atom():
assert ta.short_name() == 'GWRAtom'
with pytest.raises(DecomposeTypeError):
ta.decompose_bloq()
assert ta.adjoint() == ta
assert ta.adjoint() == TestGWRAtom(is_adjoint=True)
np.testing.assert_allclose(cirq.unitary(ta), np.eye(2))
21 changes: 3 additions & 18 deletions qualtran/bloqs/mean_estimation/mean_estimation_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ class MeanEstimationOperator(GateWithRegisters):
cv: Tuple[int, ...] = attrs.field(
converter=lambda v: (v,) if isinstance(v, int) else tuple(v), default=()
)
power: int = 1
arctan_bitsize: int = 32

@cv.validator
Expand Down Expand Up @@ -124,17 +123,12 @@ def decompose_from_registers(
) -> cirq.OP_TREE:
select_reg = {reg.name: quregs[reg.name] for reg in self.select.signature}
reflect_reg = {reg.name: quregs[reg.name] for reg in self.reflect.signature}
select_op = self.select.on_registers(**select_reg)
reflect_op = self.reflect.on_registers(**reflect_reg)
for _ in range(self.power):
yield select_op
yield reflect_op
yield self.select.on_registers(**select_reg)
yield self.reflect.on_registers(**reflect_reg)

def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo:
wire_symbols = [] if self.cv == () else [["@(0)", "@"][self.cv[0]]]
wire_symbols = [] if self.cv == () else [["(0)", "@"][self.cv[0]]]
wire_symbols += ['U_ko'] * (total_bits(self.signature) - total_bits(self.control_registers))
if self.power != 1:
wire_symbols[-1] = f'U_ko^{self.power}'
return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols)

def controlled(
Expand All @@ -160,17 +154,8 @@ def controlled(
return MeanEstimationOperator(
CodeForRandomVariable(encoder=c_select, synthesizer=self.code.synthesizer),
cv=self.cv + (control_values[0],),
power=self.power,
arctan_bitsize=self.arctan_bitsize,
)
raise NotImplementedError(
f'Cannot create a controlled version of {self} with control_values={control_values}.'
)

def with_power(self, new_power: int) -> 'MeanEstimationOperator':
return MeanEstimationOperator(
self.code, cv=self.cv, power=new_power, arctan_bitsize=self.arctan_bitsize
)

def __pow__(self, power: int):
return self.with_power(self.power * power)
23 changes: 10 additions & 13 deletions qualtran/bloqs/mean_estimation/mean_estimation_operator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,22 +281,19 @@ def test_mean_estimation_operator_consistent_protocols():
with pytest.raises(NotImplementedError, match="Cannot create a controlled version"):
_ = mean_gate.controlled(num_controls=2)

# Test with_power
assert mean_gate.with_power(5) ** 2 == MeanEstimationOperator(
code, arctan_bitsize=arctan_bitsize, power=10
)
# Test diagrams
expected_symbols = ['U_ko'] * cirq.num_qubits(mean_gate)
assert cirq.circuit_diagram_info(mean_gate).wire_symbols == tuple(expected_symbols)
control_symbols = ['@']
n_qubits = cirq.num_qubits(mean_gate)

assert cirq.circuit_diagram_info(mean_gate).wire_symbols == tuple(['U_ko'] * n_qubits)

assert cirq.circuit_diagram_info(mean_gate.controlled()).wire_symbols == tuple(
control_symbols + expected_symbols
['@'] + ['U_ko'] * n_qubits
)
control_symbols = ['@(0)']

assert cirq.circuit_diagram_info(
mean_gate.controlled(control_values=(0,))
).wire_symbols == tuple(control_symbols + expected_symbols)
expected_symbols[-1] = 'U_ko^2'
).wire_symbols == tuple(['(0)'] + ['U_ko'] * n_qubits)

assert cirq.circuit_diagram_info(
mean_gate.with_power(2).controlled(control_values=(0,))
).wire_symbols == tuple(control_symbols + expected_symbols)
(mean_gate**2).controlled(control_values=(0,))
).wire_symbols == tuple(['(0)'] + ['U_ko^2'] * n_qubits)
3 changes: 1 addition & 2 deletions qualtran/bloqs/qubitization_walk_operator.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,7 @@
"#### Parameters\n",
" - `select`: The SELECT lcu gate implementing $SELECT=\\sum_{l}|l><l|H_{l}$.\n",
" - `prepare`: Then PREPARE lcu gate implementing $PREPARE|00...00> = \\sum_{l=0}^{L - 1}\\sqrt{\\frac{w_{l}}{\\lambda}} |l> = |\\ell>$\n",
" - `control_val`: If 0/1, a controlled version of the walk operator is constructed. Defaults to None, in which case the resulting walk operator is not controlled.\n",
" - `power`: Constructs $W^{power}$ by repeatedly decomposing into `power` copies of $W$. Defaults to 1. \n",
" - `control_val`: If 0/1, a controlled version of the walk operator is constructed. Defaults to None, in which case the resulting walk operator is not controlled. \n",
"\n",
"#### References\n",
" - [Encoding Electronic Spectra in Quantum Circuits with Linear T Complexity] (https://arxiv.org/abs/1805.03662). Babbush et. al. (2018). Figure 1.\n"
Expand Down
23 changes: 2 additions & 21 deletions qualtran/bloqs/qubitization_walk_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from qualtran._infra.gate_with_registers import total_bits
from qualtran.bloqs.reflection_using_prepare import ReflectionUsingPrepare
from qualtran.bloqs.select_and_prepare import PrepareOracle, SelectOracle
from qualtran.cirq_interop.t_complexity_protocol import t_complexity
from qualtran.resource_counting.generalizers import (
cirq_to_bloqs,
ignore_cliffords,
Expand Down Expand Up @@ -55,8 +54,6 @@ class QubitizationWalkOperator(GateWithRegisters):
$PREPARE|00...00> = \sum_{l=0}^{L - 1}\sqrt{\frac{w_{l}}{\lambda}} |l> = |\ell>$
control_val: If 0/1, a controlled version of the walk operator is constructed. Defaults to
None, in which case the resulting walk operator is not controlled.
power: Constructs $W^{power}$ by repeatedly decomposing into `power` copies of $W$.
Defaults to 1.

References:
[Encoding Electronic Spectra in Quantum Circuits with Linear T Complexity]
Expand All @@ -67,7 +64,6 @@ class QubitizationWalkOperator(GateWithRegisters):
select: SelectOracle
prepare: PrepareOracle
control_val: Optional[int] = None
power: int = 1

def __attrs_post_init__(self):
assert self.select.control_registers == self.reflect.control_registers
Expand Down Expand Up @@ -105,18 +101,14 @@ def decompose_from_registers(
**quregs: NDArray[cirq.Qid], # type:ignore[type-var]
) -> cirq.OP_TREE:
select_reg = {reg.name: quregs[reg.name] for reg in self.select.signature}
select_op = self.select.on_registers(**select_reg)
yield self.select.on_registers(**select_reg)

reflect_reg = {reg.name: quregs[reg.name] for reg in self.reflect.signature}
reflect_op = self.reflect.on_registers(**reflect_reg)
for _ in range(self.power):
yield select_op
yield reflect_op
yield self.reflect.on_registers(**reflect_reg)

def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo:
wire_symbols = ['@' if self.control_val else '@(0)'] * total_bits(self.control_registers)
wire_symbols += ['W'] * (total_bits(self.signature) - total_bits(self.control_registers))
wire_symbols[-1] = f'W^{self.power}' if self.power != 1 else 'W'
return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols)

def controlled(
Expand Down Expand Up @@ -144,17 +136,6 @@ def controlled(
f'Cannot create a controlled version of {self} with control_values={control_values}.'
)

def with_power(self, new_power: int) -> 'QubitizationWalkOperator':
return attrs.evolve(self, power=new_power)

def __pow__(self, power: int):
return self.with_power(self.power * power)

def _t_complexity_(self):
if self.power > 1:
return self.power * t_complexity(self.with_power(1))
return NotImplemented


@bloq_example(generalizer=[cirq_to_bloqs, ignore_split_join, ignore_cliffords])
def _walk_op() -> QubitizationWalkOperator:
Expand Down
11 changes: 9 additions & 2 deletions qualtran/bloqs/qubitization_walk_operator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,16 @@ def test_qubitization_walk_operator_diagrams():
target3: ──────SelectPauliLCU─────────
''',
)

# 2. Diagram for $W^{2} = SELECT.R_{L}.SELCT.R_{L}$
walk_squared_op = walk.with_power(2).on_registers(**g.quregs)
circuit = cirq.Circuit(cirq.decompose_once(walk_squared_op))
def decompose_twice(op):
ops = []
for sub_op in cirq.decompose_once(op):
ops += cirq.decompose_once(sub_op)
return ops

walk_squared_op = (walk**2).on_registers(**g.quregs)
circuit = cirq.Circuit(decompose_twice(walk_squared_op))
cirq.testing.assert_has_diagram(
circuit,
'''
Expand Down
26 changes: 24 additions & 2 deletions qualtran/bloqs/util_bloqs.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,11 @@
)
from qualtran.cirq_interop.t_complexity_protocol import TComplexity
from qualtran.drawing import directional_text_box, WireSymbol
from qualtran.resource_counting.symbolic_counting_utils import SymbolicInt
from qualtran.resource_counting.symbolic_counting_utils import is_symbolic, SymbolicInt
from qualtran.simulation.classical_sim import bits_to_ints, ints_to_bits

if TYPE_CHECKING:
import cirq
from numpy.typing import NDArray

from qualtran import AddControlledT, CtrlSpec
Expand Down Expand Up @@ -497,7 +498,10 @@ class Power(GateWithRegisters):
def __attrs_post_init__(self):
if any(reg.side != Side.THRU for reg in self.bloq.signature):
raise ValueError('Bloq to repeat must have only THRU registers')
if self.power < 1:

if not is_symbolic(self.power) and (
not isinstance(self.power, (int, np.integer)) or self.power < 1
):
raise ValueError(f'{self.power=} must be a positive integer.')

def adjoint(self) -> 'Bloq':
Expand All @@ -514,3 +518,21 @@ def build_composite_bloq(self, bb: 'BloqBuilder', **soqs: 'SoquetT') -> Dict[str

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
return {(self.bloq, self.power)}

def __pow__(self, power) -> 'Power':
bloq = self.bloq.adjoint() if power < 0 else self.bloq
return Power(bloq, self.power * abs(power))
mpharrigan marked this conversation as resolved.
Show resolved Hide resolved

def _circuit_diagram_info_(
self, args: 'cirq.CircuitDiagramInfoArgs'
) -> 'cirq.CircuitDiagramInfo':
import cirq

info = cirq.circuit_diagram_info(self.bloq, args, default=None)

if info is None:
info = super()._circuit_diagram_info_(args)

wire_symbols = [f'{symbol}^{self.power}' for symbol in info.wire_symbols]

return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols)
Loading
Loading