Skip to content

Commit

Permalink
Remove sneaky fredkin decomposition (#803)
Browse files Browse the repository at this point in the history
* Remove sneaky fredkin decomposition

This should be a gate/bloq

* lint

* Port clifford=10 to qualtran

* break from cirq

* lint
  • Loading branch information
mpharrigan authored Mar 26, 2024
1 parent a0ba6e9 commit 5c57c95
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 121 deletions.
52 changes: 34 additions & 18 deletions qualtran/bloqs/basic_gates/swap.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@
SoquetT,
)
from qualtran.bloqs.util_bloqs import ArbitraryClifford
from qualtran.cirq_interop import CirqQuregT, decompose_from_cirq_style_method
from qualtran.cirq_interop.t_complexity_protocol import TComplexity
from qualtran.drawing import Circle, TextBox, WireSymbol
from qualtran.resource_counting.generalizers import ignore_split_join

from .t_gate import TGate

if TYPE_CHECKING:
from qualtran.cirq_interop import CirqQuregT
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
from qualtran.simulation.classical_sim import ClassicalValT

Expand Down Expand Up @@ -112,26 +112,40 @@ class TwoBitCSwap(Bloq):
ctrl: the control bit
x: the first bit
y: the second bit
"""
def short_name(self) -> str:
return 'swap'
References:
[An algorithm for the T-count](https://arxiv.org/abs/1308.4134).
Gosset et. al. 2013. Figure 5.2.
"""

@cached_property
def signature(self) -> Signature:
return Signature.build(ctrl=1, x=1, y=1)

def as_cirq_op(
def decompose_bloq(self) -> 'CompositeBloq':
return decompose_from_cirq_style_method(self)

def decompose_from_registers(
self,
qubit_manager: 'cirq.QubitManager',
ctrl: 'CirqQuregT',
x: 'CirqQuregT',
y: 'CirqQuregT',
) -> Tuple['cirq.Operation', Dict[str, 'CirqQuregT']]:
*,
context: cirq.DecompositionContext,
ctrl: NDArray[cirq.Qid],
x: NDArray[cirq.Qid],
y: NDArray[cirq.Qid],
) -> cirq.OP_TREE:
(ctrl,) = ctrl
(x,) = x
(y,) = y
return cirq.CSWAP.on(ctrl, x, y), {'ctrl': [ctrl], 'x': [x], 'y': [y]}
yield [cirq.CNOT(y, x)]
yield [cirq.CNOT(ctrl, x), cirq.H(y)]
yield [cirq.T(ctrl), cirq.T(x) ** -1, cirq.T(y)]
yield [cirq.CNOT(y, x)]
yield [cirq.CNOT(ctrl, y), cirq.T(x)]
yield [cirq.CNOT(ctrl, x), cirq.T(y) ** -1]
yield [cirq.T(x) ** -1, cirq.CNOT(ctrl, y)]
yield [cirq.CNOT(y, x)]
yield [cirq.T(x), cirq.H(y)]
yield [cirq.CNOT(y, x)]

def add_my_tensors(
self,
Expand All @@ -156,13 +170,6 @@ def on_classical_vals(
raise ValueError("Bad control value for TwoBitCSwap classical simulation.")

def _t_complexity_(self) -> 'TComplexity':
"""The t complexity.
References:
[An algorithm for the T-count](https://arxiv.org/abs/1308.4134). Gosset et. al. 2013.
Figure 5.2.
"""
# https://arxiv.org/abs/1308.4134
return TComplexity(t=7, clifford=10)

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
Expand All @@ -171,6 +178,15 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
def adjoint(self) -> 'Bloq':
return self

def short_name(self) -> str:
return 'swap'

def wire_symbol(self, soq: 'Soquet') -> 'WireSymbol':
if soq.reg.name == 'ctrl':
return Circle(filled=True)
else:
return TextBox('×')


@frozen
class Swap(Bloq):
Expand Down
18 changes: 7 additions & 11 deletions qualtran/bloqs/basic_gates/swap_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
_swap_matrix,
_swap_small,
)
from qualtran.cirq_interop.t_complexity_protocol import t_complexity, TComplexity
from qualtran.resource_counting.generalizers import ignore_split_join


Expand Down Expand Up @@ -91,6 +92,9 @@ def _set_ctrl_two_bit_swap(ctrl_bit):
def test_two_bit_cswap():
cswap = TwoBitCSwap()
np.testing.assert_array_equal(cswap.tensor_contract(), cirq.unitary(cirq.CSWAP))
np.testing.assert_allclose(
cswap.decompose_bloq().tensor_contract(), cirq.unitary(cirq.CSWAP), atol=1e-8
)

# Zero ctrl -- it's identity
np.testing.assert_array_equal(np.eye(4), _set_ctrl_two_bit_swap(0).tensor_contract())
Expand All @@ -105,13 +109,6 @@ def test_two_bit_cswap():
ctrl, x, y = cswap.call_classically(ctrl=1, x=1, y=0)
assert (ctrl, x, y) == (1, 0, 1)

# cirq
c1 = cirq.Circuit([cirq.CSWAP(*cirq.LineQubit.range(3))]).freeze()
c2, _ = cswap.as_composite_bloq().to_cirq_circuit(
ctrl=[cirq.LineQubit(0)], x=[cirq.LineQubit(1)], y=[cirq.LineQubit(2)]
)
assert c1 == c2


def _set_ctrl_swap(ctrl_bit, bloq: CSwap):
states = [ZeroState(), OneState()]
Expand Down Expand Up @@ -154,10 +151,6 @@ def test_cswap_cirq_decomp():
y2: ─────×(y)───────────×───
''',
)
expected_circuit = cirq.Circuit(
cswap_op, [cirq.CSWAP(*quregs['ctrl'], x, y) for (x, y) in zip(quregs['x'], quregs['y'])]
)
cirq.testing.assert_same_circuits(circuit, expected_circuit)


def test_cswap_unitary():
Expand Down Expand Up @@ -195,6 +188,9 @@ def test_cswap_bloq_counts():
counts2 = bloq.decompose_bloq().bloq_counts(generalizer=ignore_split_join)
assert counts1 == counts2

assert t_complexity(CSwap(1)) == TComplexity(t=7, clifford=10)
assert t_complexity(TwoBitCSwap()) == TComplexity(t=7, clifford=10)


def test_cswap_symbolic():
n = sympy.symbols('n')
Expand Down
6 changes: 5 additions & 1 deletion qualtran/bloqs/swap_network/cswap_approx_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,12 @@ def get_t_count_and_clifford(bc: Dict[Bloq, int]) -> Tuple[int, int]:


@pytest.mark.parametrize("n", [*range(1, 6)])
def test_t_complexity(n):
def test_t_complexity_cswap(n):
cq_testing.assert_decompose_is_consistent_with_t_complexity(CSwap(n))


@pytest.mark.parametrize("n", [*range(1, 6)])
def test_t_complexity_cswap_approx(n):
cq_testing.assert_decompose_is_consistent_with_t_complexity(CSwapApprox(n))


Expand Down
57 changes: 1 addition & 56 deletions qualtran/cirq_interop/decompose_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,62 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, FrozenSet, Sequence
from typing import Any

import cirq
from cirq.protocols.decompose_protocol import DecomposeResult

_FREDKIN_GATESET = cirq.Gateset(cirq.FREDKIN, unroll_circuit_op=False)


def _fredkin(qubits: Sequence[cirq.Qid], context: cirq.DecompositionContext) -> cirq.OP_TREE:
"""Decomposition with 7 T and 10 clifford operations from https://arxiv.org/abs/1308.4134"""
c, t1, t2 = qubits
yield [cirq.CNOT(t2, t1)]
yield [cirq.CNOT(c, t1), cirq.H(t2)]
yield [cirq.T(c), cirq.T(t1) ** -1, cirq.T(t2)]
yield [cirq.CNOT(t2, t1)]
yield [cirq.CNOT(c, t2), cirq.T(t1)]
yield [cirq.CNOT(c, t1), cirq.T(t2) ** -1]
yield [cirq.T(t1) ** -1, cirq.CNOT(c, t2)]
yield [cirq.CNOT(t2, t1)]
yield [cirq.T(t1), cirq.H(t2)]
yield [cirq.CNOT(t2, t1)]


def _try_decompose_from_known_decompositions(
val: Any, context: cirq.DecompositionContext
) -> DecomposeResult:
"""Returns a flattened decomposition of the object into operations, if possible.
Args:
val: The object to decompose.
context: Decomposition context storing common configurable options for `cirq.decompose`.
Returns:
A flattened decomposition of `val` if it's a gate or operation with a known decomposition.
"""
if not isinstance(val, (cirq.Gate, cirq.Operation)):
return None
qubits = cirq.LineQid.for_gate(val) if isinstance(val, cirq.Gate) else val.qubits
known_decompositions = [(_FREDKIN_GATESET, _fredkin)]

classical_controls: FrozenSet[cirq.Condition] = frozenset()
if isinstance(val, cirq.ClassicallyControlledOperation):
classical_controls = val.classical_controls
val = val.without_classical_controls()

decomposition = None
for gateset, decomposer in known_decompositions:
if val in gateset:
decomposition = cirq.flatten_to_ops(decomposer(qubits, context))
break
return (
tuple(op.with_classical_controls(*classical_controls) for op in decomposition)
if decomposition
else None
)


def _decompose_once_considering_known_decomposition(val: Any) -> DecomposeResult:
"""Decomposes a value into operations, if possible.
Expand All @@ -84,10 +33,6 @@ def _decompose_once_considering_known_decomposition(val: Any) -> DecomposeResult
qubit_manager=cirq.GreedyQubitManager(prefix=f'_{uuid.uuid4()}', maximize_reuse=True)
)

decomposed = _try_decompose_from_known_decompositions(val, context)
if decomposed is not None:
return decomposed

if isinstance(val, cirq.Gate):
decomposed = cirq.decompose_once_with_qubits(
val, cirq.LineQid.for_gate(val), context=context, flatten=False, default=None
Expand Down
35 changes: 1 addition & 34 deletions qualtran/cirq_interop/decompose_protocol_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,41 +13,8 @@
# limitations under the License.

import cirq
import numpy as np
import pytest

from qualtran.cirq_interop.decompose_protocol import (
_decompose_once_considering_known_decomposition,
_fredkin,
_try_decompose_from_known_decompositions,
)


def test_fredkin_unitary():
c, t1, t2 = cirq.LineQid.for_gate(cirq.FREDKIN)
context = cirq.DecompositionContext(cirq.ops.SimpleQubitManager())
np.testing.assert_allclose(
cirq.Circuit(_fredkin((c, t1, t2), context)).unitary(),
cirq.unitary(cirq.FREDKIN(c, t1, t2)),
atol=1e-8,
)


@pytest.mark.parametrize('gate', [cirq.FREDKIN, cirq.FREDKIN**-1])
def test_decompose_fredkin(gate):
c, t1, t2 = cirq.LineQid.for_gate(cirq.FREDKIN)
op = cirq.FREDKIN(c, t1, t2)
context = cirq.DecompositionContext(cirq.ops.SimpleQubitManager())
want = tuple(cirq.flatten_op_tree(_fredkin((c, t1, t2), context)))
assert want == _try_decompose_from_known_decompositions(op, context)

op = cirq.FREDKIN(c, t1, t2).with_classical_controls('key')
classical_controls = op.classical_controls
want = tuple(
o.with_classical_controls(*classical_controls)
for o in cirq.flatten_op_tree(_fredkin((c, t1, t2), context))
)
assert want == _try_decompose_from_known_decompositions(op, context)
from qualtran.cirq_interop.decompose_protocol import _decompose_once_considering_known_decomposition


def test_known_decomposition_empty_unitary():
Expand Down
2 changes: 1 addition & 1 deletion qualtran/cirq_interop/t_complexity_protocol_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def test_gates():
assert t_complexity(And()) == TComplexity(t=4, clifford=9)
assert t_complexity(And() ** -1) == TComplexity(clifford=4)

assert t_complexity(cirq.FREDKIN) == TComplexity(t=7, clifford=10)
assert t_complexity(cirq.FREDKIN) == TComplexity(t=7, clifford=14)


def test_operations():
Expand Down

0 comments on commit 5c57c95

Please sign in to comment.