Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dialects: (stim) add qubit attribute and qubit coordinate attribute #3114

Merged
merged 28 commits into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
e71d99a
Add stim dialect
kimxworrall Aug 15, 2024
838d11e
Add tests
kimxworrall Aug 15, 2024
82ae376
core: CanonicaliZation naming consistency (#3040)
PapyChacal Aug 15, 2024
22c247f
update stim
kimxworrall Aug 27, 2024
f3f785a
Pull region assembly format update by merging branch 'main' into kim/…
kimxworrall Aug 28, 2024
e6c2ebe
initialise stim print and parse, reorganise files to have stim relate…
kimxworrall Aug 28, 2024
06662bc
fix tests
kimxworrall Aug 28, 2024
32585bc
fix precommit
kimxworrall Aug 28, 2024
2cfbf1a
Update __init__.py
kimxworrall Aug 28, 2024
a5c8b34
Update tests/filecheck/dialects/stim/stim_ops.mlir
kimxworrall Aug 28, 2024
5e99704
Update tests/filecheck/dialects/stim/stim_ops.mlir
kimxworrall Aug 28, 2024
123328a
Add qubit coordinates annotation and attributes.
kimxworrall Aug 29, 2024
bf56513
add tests for stim printer
kimxworrall Aug 29, 2024
1c4e0a1
Clean tests and add qubit coordinates printer
kimxworrall Aug 29, 2024
a63cb4d
Remove unnecessary files
kimxworrall Aug 29, 2024
94ce681
dialects: (stim) Add qubit attribute and qubit coordinate attribute
kimxworrall Aug 29, 2024
febe161
Remove StimOp reference
kimxworrall Aug 29, 2024
ff2b0c0
Align with precommit
kimxworrall Aug 29, 2024
2428317
Remove unused test functions
kimxworrall Aug 29, 2024
dc1f1b3
align with precommit
kimxworrall Aug 29, 2024
57e4162
Re-Add print_lists
kimxworrall Aug 29, 2024
7069931
Apply suggestions from code review
kimxworrall Aug 30, 2024
ae824ba
Add tests for stim printer parser
kimxworrall Aug 30, 2024
bcd210c
move syntax tests to filecheck
kimxworrall Aug 30, 2024
6c74405
replace stimattr with stimprintable
kimxworrall Aug 30, 2024
fb5d65e
steramline print functions
kimxworrall Aug 30, 2024
53e1a5b
Merge branch 'main' into kim/stim/first-attributes
kimxworrall Aug 30, 2024
a09cc57
Merge branch 'main' into kim/stim/first-attributes
superlopuh Sep 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion tests/dialects/stim/test_stim_printer_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,22 @@
import pytest

from xdsl.dialects import stim
from xdsl.dialects.stim.stim_printer_parser import StimPrinter
from xdsl.dialects.stim.ops import QubitAttr, QubitMappingAttr
from xdsl.dialects.stim.stim_printer_parser import StimPrintable, StimPrinter
from xdsl.dialects.test import TestOp
from xdsl.ir import Block, Region

################################################################################
# Utils for this test file #
################################################################################


def check_stim_print(program: StimPrintable, expected_stim: str):
res_io = StringIO()
printer = StimPrinter(stream=res_io)
program.print_stim(printer)
assert expected_stim == res_io.getvalue()


def test_empty_circuit():
empty_block = Block()
Expand All @@ -27,3 +39,16 @@ def test_stim_circuit_ops_stim_printable():
printer = StimPrinter(stream=res_io)

module.print_stim(printer)


def test_print_stim_qubit_attr():
qubit = QubitAttr(0)
expected_stim = "0"
check_stim_print(qubit, expected_stim)


def test_print_stim_qubit_coord_attr():
qubit = QubitAttr(0)
qubit_coord = QubitMappingAttr([0, 0], qubit)
expected_stim = "(0, 0) 0"
check_stim_print(qubit_coord, expected_stim)
13 changes: 13 additions & 0 deletions tests/filecheck/dialects/stim/attrs.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// RUN: XDSL_ROUNDTRIP

"test.op"() {
qubit = !stim.qubit<0>,
qubitcoord = #stim.qubit_coord<(0,0), !stim.qubit<0>>
} : () -> ()

%qubit0 = "test.op"() : () -> (!stim.qubit<0>)

// CHECK: builtin.module {
// CHECK-NEXT: "test.op"() {"qubit" = !stim.qubit<0>, "qubitcoord" = #stim.qubit_coord<(0, 0), !stim.qubit<0>>} : () -> ()
// CHECK-NEXT: %qubit0 = "test.op"() : () -> !stim.qubit<0>
// CHECK-NEXT: }
6 changes: 5 additions & 1 deletion xdsl/dialects/stim/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from xdsl.ir import Dialect

from .ops import StimCircuitOp
from .ops import QubitAttr, QubitMappingAttr, StimCircuitOp

Stim = Dialect(
"stim",
[
StimCircuitOp,
],
[
QubitAttr,
QubitMappingAttr,
],
)
103 changes: 95 additions & 8 deletions xdsl/dialects/stim/ops.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,111 @@
from abc import ABC
from collections.abc import Sequence
from io import StringIO

from xdsl.ir import Region
from xdsl.dialects.builtin import ArrayAttr, IntAttr
from xdsl.dialects.stim.stim_printer_parser import StimPrintable, StimPrinter
from xdsl.ir import ParametrizedAttribute, Region, TypeAttribute
from xdsl.irdl import (
IRDLOperation,
PyRDLOpDefinitionError,
ParameterDef,
irdl_attr_definition,
irdl_op_definition,
region_def,
)
from xdsl.parser import AttrParser
from xdsl.printer import Printer

from .stim_printer_parser import StimPrintable, StimPrinter

@irdl_attr_definition
class QubitAttr(StimPrintable, ParametrizedAttribute, TypeAttribute):
"""
Type for a single qubit.
"""

name = "stim.qubit"

qubit: ParameterDef[IntAttr]

def __init__(self, qubit: int | IntAttr) -> None:
if not isinstance(qubit, IntAttr):
qubit = IntAttr(qubit)
super().__init__(parameters=[qubit])

@classmethod
def parse_parameters(cls, parser: AttrParser) -> Sequence[IntAttr]:
with parser.in_angle_brackets():
qubit = parser.parse_integer(allow_negative=False, allow_boolean=False)
return (IntAttr(qubit),)

def print_parameters(self, printer: Printer) -> None:
with printer.in_angle_brackets():
printer.print(self.qubit.data)

def print_stim(self, printer: StimPrinter):
printer.print_string(f"{self.qubit.data}")


@irdl_attr_definition
class QubitMappingAttr(StimPrintable, ParametrizedAttribute):
"""
This attribute provides a way to indicate the required connectivity or layout of `physical` qubits.

It consists of two parameters:
1. A co-ordinate array (currently it only anticipates a pair of qubits, but this is not fixed)
2. A value associated with a qubit referred to in a circuit.

The co-ordinates may be used as a physical address of a qubit, or the relative address with respect to some known physical address.

Operations that attach this as a property may represent the lattice-like structure of a physical quantum computer by having a property with an ArrayAttr[QubitCoordsAttr].
"""

name = "stim.qubit_coord"

class StimOp(IRDLOperation, ABC):
def print_stim(self, printer: StimPrinter) -> None:
raise (PyRDLOpDefinitionError("print_stim not implemented!"))
coords: ParameterDef[ArrayAttr[IntAttr]]
qubit_name: ParameterDef[QubitAttr]

def __init__(
self, coords: list[int] | ArrayAttr[IntAttr], qubit_name: int | QubitAttr
) -> None:
if not isinstance(qubit_name, QubitAttr):
qubit_name = QubitAttr(qubit_name)
if not isinstance(coords, ArrayAttr):
coords = ArrayAttr(IntAttr(c) for c in coords)
super().__init__(parameters=[coords, qubit_name])

@classmethod
def parse_parameters(
cls, parser: AttrParser
) -> tuple[ArrayAttr[IntAttr], QubitAttr]:
parser.parse_punctuation("<")
coords = parser.parse_comma_separated_list(
delimiter=parser.Delimiter.PAREN,
parse=lambda: IntAttr(parser.parse_integer(allow_boolean=False)),
)
parser.parse_punctuation(",")
qubit = parser.parse_attribute()
if not isinstance(qubit, QubitAttr):
parser.raise_error("Expected qubit attr", at_position=parser.pos)
parser.parse_punctuation(">")
return (ArrayAttr(coords), qubit)

def print_parameters(self, printer: Printer) -> None:
with printer.in_angle_brackets():
printer.print("(")
for i, elem in enumerate(self.coords):
if i:
printer.print_string(", ")
printer.print(elem.data)
printer.print("), ")
printer.print(self.qubit_name)

def print_stim(self, printer: StimPrinter):
printer.print_attribute(self.coords)
printer.print_string(" ")
self.qubit_name.print_stim(printer)


@irdl_op_definition
class StimCircuitOp(StimOp, IRDLOperation):
class StimCircuitOp(StimPrintable, IRDLOperation):
"""
Base operation containing a stim program
"""
Expand Down
39 changes: 34 additions & 5 deletions xdsl/dialects/stim/stim_printer_parser.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import abc
from collections.abc import Callable, Iterable
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Any
from typing import Any, TypeVar, cast

from xdsl.dialects.builtin import ArrayAttr, IntAttr
from xdsl.ir import Attribute


@dataclass(eq=False, repr=False)
Expand All @@ -14,10 +18,35 @@ def print_string(self, text: str) -> None:
@contextmanager
def in_braces(self):
self.print_string("{")
try:
yield
finally:
self.print_string("}")
yield
self.print_string("}")

@contextmanager
def in_parens(self):
self.print_string("(")
yield
self.print_string(")")

T = TypeVar("T")

def print_list(
self, elems: Iterable[T], print_fn: Callable[[T], Any], delimiter: str = ", "
) -> None:
for i, elem in enumerate(elems):
if i:
self.print_string(delimiter)
print_fn(elem)

def print_attribute(self, attribute: Attribute) -> None:
if isinstance(attribute, ArrayAttr):
attribute = cast(ArrayAttr[Attribute], attribute)
with self.in_parens():
self.print_list(attribute, self.print_attribute)
return
if isinstance(attribute, IntAttr):
self.print_string(f"{attribute.data}")
return
raise ValueError(f"Cannot print in stim format: {attribute}")


class StimPrintable(abc.ABC):
Expand Down
Loading