Skip to content

Commit

Permalink
Controlled Bloq learns add_my_tensors (#815)
Browse files Browse the repository at this point in the history
* Support multiple registers and corresponding ctrl values in CtrlSpec

* Revert changes to select_and_prepare

* Documentation updates and renames

* More documentation updates

* docstring for shapes

* QDTypes learn to_bits and from_bits

* Fix pylint

* Add conversion between CtrlSpec and cirq.SumOfProducts

* Fix merge conflict

* Update docstrings

* Add my tensors for Controlled Bloq

* Reformat using latest black

* Refactor into simulation/tensor/_tensor_data_manipulation.py

* Add more tests and fix formatting

* update docstring of active_space_for_ctrl_spec
  • Loading branch information
tanujkhattar authored Mar 26, 2024
1 parent d0d6bd1 commit a0ba6e9
Show file tree
Hide file tree
Showing 6 changed files with 384 additions and 40 deletions.
44 changes: 39 additions & 5 deletions qualtran/_infra/controlled.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import attrs
import cirq
import numpy as np
import quimb.tensor as qtn
from numpy.typing import NDArray

from .bloq import Bloq
Expand Down Expand Up @@ -250,7 +251,8 @@ class AddControlledT(Protocol):

def __call__(
self, bb: 'BloqBuilder', ctrl_soqs: Sequence['SoquetT'], in_soqs: Dict[str, 'SoquetT']
) -> Tuple[Iterable['SoquetT'], Iterable['SoquetT']]: ...
) -> Tuple[Iterable['SoquetT'], Iterable['SoquetT']]:
...


def _get_nice_ctrl_reg_names(reg_names: List[str], n: int) -> Tuple[str, ...]:
Expand Down Expand Up @@ -327,13 +329,16 @@ def ctrl_reg_names(self) -> Sequence[str]:
return _get_nice_ctrl_reg_names(reg_names, n)

@cached_property
def signature(self) -> 'Signature':
# Prepend register(s) corresponding to `ctrl_spec`.
ctrl_regs = tuple(
def ctrl_regs(self) -> Tuple[Register, ...]:
return tuple(
Register(name=self.ctrl_reg_names[i], dtype=qdtype, shape=shape, side=Side.THRU)
for i, (qdtype, shape) in enumerate(self.ctrl_spec.activation_function_dtypes())
)
return Signature(ctrl_regs + tuple(self.subbloq.signature))

@cached_property
def signature(self) -> 'Signature':
# Prepend register(s) corresponding to `ctrl_spec`.
return Signature(self.ctrl_regs + tuple(self.subbloq.signature))

def decompose_bloq(self) -> 'CompositeBloq':
# Use subbloq's decomposition but wire up the additional ctrl_soqs.
Expand Down Expand Up @@ -376,6 +381,35 @@ def on_classical_vals(self, **vals: 'ClassicalValT') -> Dict[str, 'ClassicalValT

return vals

def add_my_tensors(
self,
tn: 'qtn.TensorNetwork',
tag: Any,
*,
incoming: Dict[str, 'SoquetT'],
outgoing: Dict[str, 'SoquetT'],
):
from qualtran._infra.composite_bloq import _flatten_soquet_collection
from qualtran.simulation.tensor._tensor_data_manipulation import (
active_space_for_ctrl_spec,
eye_tensor_for_signature,
tensor_shape_from_signature,
)

# Create an identity tensor corresponding to the signature of current Bloq
data = eye_tensor_for_signature(self.signature)
# Verify it has the right shape
in_ind = _flatten_soquet_collection(incoming[reg.name] for reg in self.signature.lefts())
out_ind = _flatten_soquet_collection(outgoing[reg.name] for reg in self.signature.rights())
assert data.shape == tuple(2**soq.reg.bitsize for ind in [out_ind, in_ind] for soq in ind)
# Figure out the ctrl indexes for which the ctrl is "active"
active_idx = active_space_for_ctrl_spec(self.signature, self.ctrl_spec)
# Put the subbloq tensor at indices where ctrl is active.
subbloq_shape = tensor_shape_from_signature(self.subbloq.signature)
data[active_idx] = self.subbloq.tensor_contract().reshape(subbloq_shape)
# Add the data to the tensor network.
tn.add(qtn.Tensor(data=data, inds=out_ind + in_ind, tags=[self.short_name(), tag]))

def wire_symbol(self, soq: 'Soquet') -> 'WireSymbol':
if soq.reg.name not in self.ctrl_reg_names:
# Delegate to subbloq
Expand Down
111 changes: 108 additions & 3 deletions qualtran/_infra/controlled_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,41 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
from typing import Dict, List, Tuple

import attrs
import cirq
import numpy as np
import pytest

import qualtran.testing as qlt_testing
from qualtran import Bloq, Controlled, CtrlSpec, QBit, QInt, QUInt
from qualtran import (
Bloq,
BloqBuilder,
Controlled,
CtrlSpec,
QBit,
QInt,
QUInt,
Register,
Side,
Signature,
)
from qualtran._infra.gate_with_registers import get_named_qubits, merge_qubits
from qualtran.bloqs.basic_gates import Swap, XGate
from qualtran.bloqs.basic_gates import (
CSwap,
IntEffect,
IntState,
OneState,
Swap,
XGate,
YGate,
ZeroState,
ZGate,
)
from qualtran.bloqs.for_testing import TestAtom, TestParallelCombo, TestSerialCombo
from qualtran.bloqs.mcmt import And
from qualtran.cirq_interop.testing import GateHelper
from qualtran.drawing import get_musical_score_data
from qualtran.drawing.musical_score import Circle, SoqData, TextBox

Expand Down Expand Up @@ -294,3 +318,84 @@ def test_classical_sim_int_multi_reg():
@pytest.mark.notebook
def test_notebook():
qlt_testing.execute_notebook('../Controlled')


def _verify_ctrl_tensor_for_unitary(ctrl_spec: CtrlSpec, bloq: Bloq, gate: cirq.Gate):
cbloq = Controlled(bloq, ctrl_spec)
cgate = gate.controlled(control_values=ctrl_spec.to_cirq_cv())
np.testing.assert_array_equal(cbloq.tensor_contract(), cirq.unitary(cgate))


interesting_ctrl_specs = [
CtrlSpec(),
CtrlSpec(cvs=0),
CtrlSpec(qdtypes=QUInt(4), cvs=0b0110),
CtrlSpec(cvs=[0, 1, 1, 0]),
CtrlSpec(qdtypes=[QBit(), QBit()], cvs=[[0, 1], [1, 0]]),
]


@pytest.mark.parametrize('ctrl_spec', interesting_ctrl_specs)
def test_controlled_tensor_for_unitary(ctrl_spec: CtrlSpec):
# Test one qubit unitaries
_verify_ctrl_tensor_for_unitary(ctrl_spec, XGate(), cirq.X)
_verify_ctrl_tensor_for_unitary(ctrl_spec, YGate(), cirq.Y)
_verify_ctrl_tensor_for_unitary(ctrl_spec, ZGate(), cirq.Z)
# Test multi-qubit unitaries with non-trivial signature
_verify_ctrl_tensor_for_unitary(ctrl_spec, CSwap(3), CSwap(3))


@attrs.frozen
class TestCtrlStatePrepAnd(Bloq):
"""Decomposes into a Controlled-AND gate + int effects & targets where ctrl is active.
Tensor contraction should give the output state vector corresponding to applying an
`And(and_ctrl)`; assuming all the control bits are active.
"""

ctrl_spec: CtrlSpec
and_ctrl: Tuple[int, int]

@property
def signature(self) -> 'Signature':
return Signature([Register('x', QBit(), shape=(3,), side=Side.RIGHT)])

def build_composite_bloq(self, bb: 'BloqBuilder') -> Dict[str, 'SoquetT']:
one_or_zero = [ZeroState(), OneState()]
cbloq = Controlled(And(*self.and_ctrl), ctrl_spec=self.ctrl_spec)

ctrl_soqs = {}
for reg, cvs in zip(cbloq.ctrl_regs, self.ctrl_spec.cvs):
soqs = np.empty(shape=reg.shape, dtype=object)
for idx in reg.all_idxs():
soqs[idx] = bb.add(IntState(val=cvs[idx], bitsize=reg.dtype.num_qubits))
ctrl_soqs[reg.name] = soqs

and_ctrl = [bb.add(one_or_zero[cv]) for cv in self.and_ctrl]

ctrl_soqs = bb.add_d(cbloq, **ctrl_soqs, ctrl=and_ctrl)
out_soqs = [*ctrl_soqs.pop('ctrl'), ctrl_soqs.pop('target')]

for reg, cvs in zip(cbloq.ctrl_regs, self.ctrl_spec.cvs):
for idx in reg.all_idxs():
ctrl_soq = np.asarray(ctrl_soqs[reg.name])[idx]
bb.add(IntEffect(val=cvs[idx], bitsize=reg.dtype.num_qubits), val=ctrl_soq)
return {'x': out_soqs}


def _verify_ctrl_tensor_for_and(ctrl_spec: CtrlSpec, and_ctrl: Tuple[int, int]):
cbloq = TestCtrlStatePrepAnd(ctrl_spec, and_ctrl)
bloq_tensor = cbloq.tensor_contract()
cirq_state_vector = GateHelper(And(*and_ctrl)).circuit.final_state_vector(
initial_state=and_ctrl + (0,)
)
np.testing.assert_allclose(bloq_tensor, cirq_state_vector, atol=1e-8)


@pytest.mark.parametrize('ctrl_spec', interesting_ctrl_specs)
def test_controlled_tensor_for_and_bloq(ctrl_spec: CtrlSpec):
# Test AND gate with one-sided signature (aka controlled state preparation).
_verify_ctrl_tensor_for_and(ctrl_spec, (1, 1))
_verify_ctrl_tensor_for_and(ctrl_spec, (1, 0))
_verify_ctrl_tensor_for_and(ctrl_spec, (0, 1))
_verify_ctrl_tensor_for_and(ctrl_spec, (0, 0))
36 changes: 4 additions & 32 deletions qualtran/cirq_interop/_cirq_to_bloq.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
"""Cirq gates/circuits to Qualtran Bloqs conversion."""
import abc
import itertools
from collections import defaultdict
from functools import cached_property
from typing import Any, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union

Expand Down Expand Up @@ -46,6 +45,9 @@
)
from qualtran.cirq_interop._interop_qubit_manager import InteropQubitManager
from qualtran.cirq_interop.t_complexity_protocol import t_complexity, TComplexity
from qualtran.simulation.tensor._tensor_data_manipulation import (
tensor_data_from_unitary_and_signature,
)

if TYPE_CHECKING:
from qualtran.drawing import WireSymbol
Expand Down Expand Up @@ -194,37 +196,7 @@ def _add_my_tensors_from_gate(
f"CirqGateAsBloq.add_my_tensors is currently supported only for unitary gates. "
f"Found {gate}."
)
unitary_shape = []
reg_to_idx = defaultdict(list)
for reg in signature:
start = len(unitary_shape)
for i in range(int(np.prod(reg.shape))):
reg_to_idx[reg.name].append(start + i)
unitary_shape.append(2**reg.bitsize)

unitary_shape = (*unitary_shape, *unitary_shape)
unitary = cirq.unitary(gate).reshape(unitary_shape)
idx: List[Union[int, slice]] = [slice(x) for x in unitary_shape]
n = len(unitary_shape) // 2
for reg in signature:
if reg.side == Side.LEFT:
for i in reg_to_idx[reg.name]:
# LEFT register ends, extract right subspace that's equivalent to 0.
idx[i] = 0
if reg.side == Side.RIGHT:
for i in reg_to_idx[reg.name]:
# Right register begins, extract the left subspace that's equivalent to 0.
idx[i + n] = 0
unitary = unitary[tuple(idx)]
new_shape = tuple(
[
*itertools.chain.from_iterable(
(2**reg.bitsize,) * int(np.prod(reg.shape))
for reg in [*signature.rights(), *signature.lefts()]
)
]
)
assert unitary.shape == new_shape
unitary = tensor_data_from_unitary_and_signature(cirq.unitary(gate), signature)
incoming_list = [
*itertools.chain.from_iterable(
[np.array(incoming[reg.name]).flatten() for reg in signature.lefts()]
Expand Down
7 changes: 7 additions & 0 deletions qualtran/simulation/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,10 @@
from ._dense import bloq_to_dense, get_right_and_left_inds
from ._flattening import bloq_has_custom_tensors, flatten_for_tensor_contraction
from ._quimb import cbloq_as_contracted_tensor, cbloq_to_quimb
from ._tensor_data_manipulation import (
active_space_for_ctrl_spec,
eye_tensor_for_signature,
tensor_data_from_unitary_and_signature,
tensor_out_inp_shape_from_signature,
tensor_shape_from_signature,
)
Loading

0 comments on commit a0ba6e9

Please sign in to comment.