Skip to content

Commit

Permalink
Raise an exception when applying ansatze on a diagram with frames (#222)
Browse files Browse the repository at this point in the history
* Added new `SplitTensorAnsatz` for tensor network ansatzes that splits large boxes into smaller units
* Reorganize tests and imports
  • Loading branch information
neiljdo authored Feb 28, 2025
1 parent fb9dacd commit 0d7eebd
Show file tree
Hide file tree
Showing 13 changed files with 290 additions and 199 deletions.
12 changes: 10 additions & 2 deletions lambeq/ansatz/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,21 @@
"""
from __future__ import annotations

__all__ = ['BaseAnsatz', 'Symbol']
__all__ = ['BaseAnsatz']

from abc import ABC, abstractmethod
from collections.abc import Mapping

from lambeq.backend import grammar, tensor
from lambeq.backend.symbol import Symbol


AnsatzWithFramesRuntimeError = RuntimeError(
'Attempting to apply an ansatz to a diagram '
'with frames. Try using `sandwich=True` when '
'calling `DisCoCircReader.text2circuit()` '
'or applying a custom functor that converts '
'frames to boxes before applying an ansatz.'
)


class BaseAnsatz(ABC):
Expand Down
6 changes: 5 additions & 1 deletion lambeq/ansatz/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@

import numpy as np

from lambeq.ansatz import BaseAnsatz
from lambeq.ansatz.base import AnsatzWithFramesRuntimeError, BaseAnsatz
from lambeq.backend.grammar import Box, Diagram, Functor, Ty
from lambeq.backend.quantum import (
Bra,
Expand Down Expand Up @@ -111,6 +111,10 @@ def __init__(self,

def __call__(self, diagram: Diagram) -> Circuit:
"""Convert a lambeq diagram into a lambeq circuit."""

if diagram.has_frames:
raise AnsatzWithFramesRuntimeError

return self.functor(diagram) # type: ignore[return-value]

def ob_size(self, pg_type: Ty) -> int:
Expand Down
40 changes: 27 additions & 13 deletions lambeq/ansatz/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@

__all__ = ['TensorAnsatz', 'MPSAnsatz', 'SpiderAnsatz']

from abc import abstractmethod
from collections.abc import Mapping

from lambeq.ansatz import BaseAnsatz
from lambeq.ansatz.base import AnsatzWithFramesRuntimeError, BaseAnsatz
from lambeq.backend import grammar, Symbol, tensor
from lambeq.backend.grammar import Cup, Spider, Ty, Word
from lambeq.backend.tensor import Dim
Expand Down Expand Up @@ -108,10 +109,33 @@ def _generate_directed_dom_cod(self, box: grammar.Box) -> tuple[Dim, Dim]:

def __call__(self, diagram: grammar.Diagram) -> tensor.Diagram:
"""Convert a diagram into a tensor."""

if diagram.has_frames:
raise AnsatzWithFramesRuntimeError

return self.functor(diagram) # type: ignore[return-value]


class MPSAnsatz(TensorAnsatz):
class SplitTensorAnsatz(TensorAnsatz):
"""Base class for tensor network ansatzes that splits large boxes
into smaller units."""

split_functor: grammar.Functor

@abstractmethod
def _split_ar(self, _: grammar.Functor, ar: Word) -> grammar.Diagrammable:
"""Split large boxes into smaller units."""

def __call__(self, diagram: grammar.Diagram) -> tensor.Diagram:
if diagram.has_frames:
raise AnsatzWithFramesRuntimeError

return self.functor(
self.split_functor(diagram)
) # type: ignore[return-value]


class MPSAnsatz(SplitTensorAnsatz):
"""Split large boxes into matrix product states."""

BOND_TYPE: Ty = Ty('B')
Expand Down Expand Up @@ -169,13 +193,8 @@ def _split_ar(self, _: grammar.Functor, ar: Word) -> grammar.Diagrammable:
return (grammar.Id().tensor(*boxes)
>> grammar.Id().tensor(*cups[:-1])) # type: ignore[arg-type]

def __call__(self, diagram: grammar.Diagram) -> tensor.Diagram:
return self.functor(
self.split_functor(diagram)
) # type: ignore[return-value]


class SpiderAnsatz(TensorAnsatz):
class SpiderAnsatz(SplitTensorAnsatz):
"""Split large boxes into spiders."""

def __init__(self,
Expand Down Expand Up @@ -220,8 +239,3 @@ def _split_ar(self, _: grammar.Functor, ar: Word) -> grammar.Diagrammable:

return (grammar.Id().tensor(*boxes)
>> grammar.Id().tensor(*spiders))

def __call__(self, diagram: grammar.Diagram) -> tensor.Diagram:
return self.functor(
self.split_functor(diagram)
) # type: ignore[return-value]
2 changes: 1 addition & 1 deletion lambeq/backend/pennylane.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,7 @@ def param_substitution(self, weights):
The weights to substitute for the symbols.
Returns
-------e
-------
:class:`torch.FloatTensor`
The concrete (non-symbolic) parameters for the
circuit.
Expand Down
2 changes: 1 addition & 1 deletion lambeq/training/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

from sympy import Symbol as SymPySymbol

from lambeq.ansatz.base import Symbol
from lambeq.backend.symbol import Symbol
from lambeq.backend.tensor import Diagram
from lambeq.training.checkpoint import Checkpoint
from lambeq.typing import StrPathT
Expand Down
2 changes: 1 addition & 1 deletion lambeq/training/pytorch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@

import torch

from lambeq.ansatz.base import Symbol
from lambeq.backend.numerical_backend import backend
from lambeq.backend.symbol import Symbol
from lambeq.backend.tensor import Diagram
from lambeq.training.checkpoint import Checkpoint
from lambeq.training.model import Model
Expand Down
2 changes: 1 addition & 1 deletion lambeq/training/pytorch_quantum_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@

import torch

from lambeq.ansatz.base import Symbol
from lambeq.backend.numerical_backend import backend
from lambeq.backend.quantum import Diagram as Circuit
from lambeq.backend.symbol import Symbol
from lambeq.backend.tensor import Diagram
from lambeq.training.checkpoint import Checkpoint
from lambeq.training.quantum_model import QuantumModel
Expand Down
Empty file added tests/ansatz/__init__.py
Empty file.
Loading

0 comments on commit 0d7eebd

Please sign in to comment.