Skip to content

Commit

Permalink
Remove tket dependencies for PennyLane model (#195)
Browse files Browse the repository at this point in the history
This replaces the intermediate `tket` circuit with a circuital diagram that can be converted to a `PennyLaneCircuit` instance - making `tket` an optional dependency. This also makes the `PennyLaneModel` class use the native `lambeq.Symbol` class instead of `sympy` symbols for training speed-up.

---------

Co-authored-by: Dimitri Kartsaklis <dimkart@gmail.com>
Co-authored-by: Neil John D. Ortega <neil.ortega@quantinuum.com>
  • Loading branch information
3 people authored Feb 27, 2025
1 parent 93e9c89 commit fb9dacd
Show file tree
Hide file tree
Showing 12 changed files with 850 additions and 314 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ lambeq/version.py

# development files
*.egg-info/
.conda/
.coverage
.ipynb_checkpoints/
__pycache__/
Expand Down
230 changes: 68 additions & 162 deletions lambeq/backend/converters/tk.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,19 @@

import numpy as np
import pytket as tk
from pytket.circuit import (Bit, Command, Op, OpType, Qubit)
from pytket.circuit import (Bit, Command, Op, OpType)
from pytket.utils import probs_from_counts
import sympy
from typing_extensions import Self

from lambeq.backend import Functor, Symbol
from lambeq.backend.quantum import (bit, Box, Bra, CCX, CCZ, Controlled, CRx,
CRy, CRz, Daggered, Diagram, Discard,
GATES, Id, Ket, Measure, quantum, qubit,
Rx, Ry, Rz, Scalar, Swap, X, Y, Z)
from lambeq.backend import Symbol
from lambeq.backend.quantum import (bit, Box, Bra, CCX, CCZ,
Controlled, CRx, CRy, CRz,
Diagram, Discard, GATES, Id,
Ket, Measure, qubit,
readoff_circuital,
Rx, Ry, Rz, Scalar, Swap,
to_circuital, X, Y, Z)

OPTYPE_MAP = {'H': OpType.H,
'X': OpType.X,
Expand All @@ -52,7 +55,7 @@
'CRy': OpType.CRy,
'CRz': OpType.CRz,
'CCX': OpType.CCX,
'Swap': OpType.SWAP}
'SWAP': OpType.SWAP}


class Circuit(tk.Circuit):
Expand Down Expand Up @@ -192,161 +195,6 @@ def get_counts(self,
return counts


def to_tk(circuit: Diagram):
"""
Takes a :py:class:`lambeq.quantum.Diagram`, returns
a :py:class:`Circuit`.
"""
# bits and qubits are lists of register indices, at layer i we want
# len(bits) == circuit[:i].cod.count(bit) and same for qubits
tk_circ = Circuit()
bits: list[int] = []
qubits: list[int] = []
circuit = circuit.init_and_discard()

def remove_ketbra1(_, box: Box) -> Diagram | Box:
ob_map: dict[Box, Diagram]
ob_map = {Ket(1): Ket(0) >> X, # type: ignore[dict-item]
Bra(1): X >> Bra(0)} # type: ignore[dict-item]
return ob_map.get(box, box)

def prepare_qubits(qubits: list[int],
box: Box,
offset: int) -> list[int]:
renaming = dict()
start = (tk_circ.n_qubits if not qubits else 0
if not offset else qubits[offset - 1] + 1)
for i in range(start, tk_circ.n_qubits):
old = Qubit('q', i)
new = Qubit('q', i + len(box.cod))
renaming.update({old: new})
tk_circ.rename_units(renaming)
tk_circ.add_blank_wires(len(box.cod))
return (qubits[:offset] + list(range(start, start + len(box.cod)))
+ [i + len(box.cod) for i in qubits[offset:]])

def measure_qubits(qubits: list[int],
bits: list[int],
box: Box,
bit_offset: int,
qubit_offset: int) -> tuple[list[int], list[int]]:
if isinstance(box, Bra):
tk_circ.post_select({len(tk_circ.bits): box.bit})
for j, _ in enumerate(box.dom):
i_bit, i_qubit = len(tk_circ.bits), qubits[qubit_offset + j]
offset = len(bits) if isinstance(box, Measure) else None
tk_circ.add_bit(Bit(i_bit), offset=offset)
tk_circ.Measure(i_qubit, i_bit)
if isinstance(box, Measure):
bits = bits[:bit_offset + j] + [i_bit] + bits[bit_offset + j:]
# remove measured qubits
qubits = (qubits[:qubit_offset]
+ qubits[qubit_offset + len(box.dom):])
return bits, qubits

def swap(i: int, j: int, unit_factory=Qubit) -> None:
old, tmp, new = (
unit_factory(i), unit_factory('tmp', 0), unit_factory(j))
tk_circ.rename_units({old: tmp})
tk_circ.rename_units({new: old})
tk_circ.rename_units({tmp: new})

def add_gate(qubits: list[int], box: Box, offset: int) -> None:

is_dagger = False
if isinstance(box, Daggered):
box = box.dagger()
is_dagger = True

i_qubits = [qubits[offset + j] for j in range(len(box.dom))]

if isinstance(box, (Rx, Ry, Rz)):
phase = box.phase
if isinstance(box.phase, Symbol):
# Tket uses sympy, lambeq uses custom symbol
phase = box.phase.to_sympy()
op = Op.create(OPTYPE_MAP[box.name[:2]], 2 * phase)
elif isinstance(box, Controlled):
# The following works only for controls on single qubit gates

# reverse the distance order
dists = []
curr_box: Box | Controlled = box
while isinstance(curr_box, Controlled):
dists.append(curr_box.distance)
curr_box = curr_box.controlled
dists.reverse()

# Index of the controlled qubit is the last entry in rel_idx
rel_idx = [0]
for dist in dists:
if dist > 0:
# Add control to the left, offset by distance
rel_idx = [0] + [i + dist for i in rel_idx]
else:
# Add control to the right, don't offset
right_most_idx = max(rel_idx)
rel_idx.insert(-1, right_most_idx - dist)

i_qubits = [i_qubits[i] for i in rel_idx]

name = box.name.split('(')[0]
if box.name in ('CX', 'CZ', 'CCX'):
op = Op.create(OPTYPE_MAP[name])
elif name in ('CRx', 'CRz'):
phase = box.phase
if isinstance(box.phase, Symbol):
# Tket uses sympy, lambeq uses custom symbol
phase = box.phase.to_sympy()

op = Op.create(OPTYPE_MAP[name], 2 * phase)
elif name in ('CCX'):
op = Op.create(OPTYPE_MAP[name])
elif box.name in OPTYPE_MAP:
op = Op.create(OPTYPE_MAP[box.name])
else:
raise NotImplementedError(box)

if is_dagger:
op = op.dagger

tk_circ.add_gate(op, i_qubits)

circuit = Functor(target_category=quantum, # type: ignore [assignment]
ob=lambda _, x: x,
ar=remove_ketbra1)(circuit) # type: ignore [arg-type]
for left, box, _ in circuit:
if isinstance(box, Ket):
qubits = prepare_qubits(qubits, box, left.count(qubit))
elif isinstance(box, (Measure, Bra)):
bits, qubits = measure_qubits(
qubits, bits, box, left.count(bit), left.count(qubit))
elif isinstance(box, Discard):
qubits = (qubits[:left.count(qubit)]
+ qubits[left.count(qubit) + box.dom.count(qubit):])
elif isinstance(box, Swap):
if box == Swap(qubit, qubit):
off = left.count(qubit)
swap(qubits[off], qubits[off + 1])
elif box == Swap(bit, bit):
off = left.count(bit)
if tk_circ.post_processing:
right = Id(tk_circ.post_processing.cod[off + 2:])
tk_circ.post_process(
Id(bit ** off) @ Swap(bit, bit) @ right)
else:
swap(bits[off], bits[off + 1], unit_factory=Bit)
else: # pragma: no cover
continue # bits and qubits live in different registers.
elif isinstance(box, Scalar):
tk_circ.scale(abs(box.array) ** 2)
elif isinstance(box, Box):
add_gate(qubits, box, left.count(qubit))
else: # pragma: no cover
raise NotImplementedError
return tk_circ


def _tk_to_lmbq_param(theta):
if not isinstance(theta, sympy.Expr):
return theta
Expand All @@ -362,6 +210,64 @@ def _tk_to_lmbq_param(theta):
raise ValueError('Parameter must be a (possibly scaled) sympy Symbol')


def to_tk(diagram: Diagram) -> Circuit:
"""Takes a :py:class:`lambeq.quantum.Diagram`, returns
a :class:`lambeq.backend.converters.tk.Circuit`
for t|ket>.
Parameters
----------
diagram : :py:class:`~lambeq.backend.quantum.Diagram`
The :py:class:`Circuits <lambeq.backend.quantum.Diagram>`
to be converted to a tket circuit.
Returns
-------
tk_circuit : lambeq.backend.quantum
A :class:`lambeq.backend.converters.tk.Circuit`.
Notes
-----
* Converts to circuital.
* Copies the diagram to avoid modifying the original.
"""

if not diagram.is_circuital:
diagram = to_circuital(diagram)

circuit_info = readoff_circuital(diagram, use_sympy=True)

circuit = Circuit(circuit_info.total_qubits,
len(circuit_info.bitmap),
post_selection=circuit_info.postmap)

for gate in circuit_info.gates:
if gate.gtype == 'Scalar':
if gate.phase is None:
raise ValueError(f'Scalar gate {gate} has phase type None')
else:
circuit.scale(abs(gate.phase)**2) # type: ignore [arg-type]
continue
elif gate.gtype not in OPTYPE_MAP:
raise NotImplementedError(f'Gate {gate.gtype} not supported')

if gate.phase:
op = Op.create(OPTYPE_MAP[gate.gtype], 2 * gate.phase)
else:
op = Op.create(OPTYPE_MAP[gate.gtype])

if gate.dagger:
op = op.dagger

qubits = gate.qubits
circuit.add_gate(op, qubits)

for mq, bi in circuit_info.bitmap.items():
circuit.Measure(mq, bi)

return circuit


def from_tk(tk_circuit: tk.Circuit) -> Diagram:
"""Translates from tket to a lambeq Diagram."""
tk_circ: Circuit = Circuit.upgrade(tk_circuit)
Expand Down
67 changes: 65 additions & 2 deletions lambeq/backend/grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from copy import deepcopy
from dataclasses import dataclass, field, InitVar, replace
import json
import pickle
from typing import Any, ClassVar, Dict, Protocol, Type, TypeVar
from typing import cast, overload, TYPE_CHECKING

Expand Down Expand Up @@ -203,6 +204,62 @@ def __getitem__(self, index: int | slice) -> Self:
else:
return self._fromiter(objects[index])

def replace(self, other: Self, index: int) -> Self:
"""Replace a type at the specified index in the complex type list.
Parameters
----------
other : Ty
The type to insert. Can be atomic or complex.
index : int
The position where the type should be inserted.
"""
if not (index <= len(self) and index >= 0):
raise IndexError(f'Index {index} out of bounds for '
f'type {self} with length {len(self)}.')

if self.is_empty:
return other
else:
objects = self.objects.copy()

if len(objects) == 1:
return other

if index == 0:
objects = [*other] + objects[1:]
elif index == len(self):
objects = objects[:-1] + [*other]
else:
objects = objects[:index] + [*other] + objects[index+1:]

return self._fromiter(objects)

def insert(self, other: Self, index: int) -> Self:
"""Insert a type at the specified index in the complex type list.
Parameters
----------
other : Ty
The type to insert. Can be atomic or complex.
index : int
The position where the type should be inserted.
"""
if not (index <= len(self)):
raise IndexError(f'Index {index} out of bounds for '
f'type {self} with length {len(self)}.')

if self.is_empty:
return other
else:
if index == 0:
return other @ self
elif index == len(self):
return self @ other
objects = self.objects.copy()
objects = objects[:index] + [*other] + objects[index:]
return self._fromiter(objects)

@classmethod
def _fromiter(cls, objects: Iterable[Self]) -> Self:
"""Create a Ty from an iterable of atomic objects."""
Expand Down Expand Up @@ -1970,7 +2027,10 @@ def __call__(self, entity: Ty | Diagrammable) -> Ty | Diagrammable:
def ob_with_cache(self, ob: Ty) -> Ty:
"""Apply the functor to a type, caching the result."""
try:
return deepcopy(self.ob_cache[ob])
# Faster deepcopy
return pickle.loads( # type: ignore[no-any-return]
pickle.dumps(self.ob_cache[ob])
)
except KeyError:
pass

Expand All @@ -1985,7 +2045,10 @@ def ob_with_cache(self, ob: Ty) -> Ty:
def ar_with_cache(self, ar: Diagrammable) -> Diagrammable:
"""Apply the functor to a diagrammable, caching the result."""
try:
return deepcopy(self.ar_cache[ar])
# Faster deepcopy
return pickle.loads( # type: ignore[no-any-return]
pickle.dumps(self.ar_cache[ar])
)
except KeyError:
pass

Expand Down
Loading

0 comments on commit fb9dacd

Please sign in to comment.