From 635853288963ed02ddb8d5b48c9cdc9d3fc4adcf Mon Sep 17 00:00:00 2001 From: Alex Rice Date: Mon, 24 Jun 2024 17:14:53 +0100 Subject: [PATCH] Add qref dialect --- tests/filecheck/dialects/qref/ops.mlir | 28 +++++ xdsl/dialects/__init__.py | 6 + xdsl/dialects/qref.py | 147 +++++++++++++++++++++++++ 3 files changed, 181 insertions(+) create mode 100644 tests/filecheck/dialects/qref/ops.mlir create mode 100644 xdsl/dialects/qref.py diff --git a/tests/filecheck/dialects/qref/ops.mlir b/tests/filecheck/dialects/qref/ops.mlir new file mode 100644 index 0000000000..982c93842a --- /dev/null +++ b/tests/filecheck/dialects/qref/ops.mlir @@ -0,0 +1,28 @@ +// RUN: XDSL_ROUNDTRIP +// RUN: XDSL_GENERIC_ROUNDTRIP + +%q0, %q1 = qref.alloc<2> + +// CHECK: %q0, %q1 = qref.alloc<2> + +qref.h %q0 + +// CHECK-NEXT: qref.h %q0 + +qref.cz %q1, %q0 + +// CHECK-NEXT: qref.cz %q1, %q0 + +qref.cnot %q1, %q0 + +// CHECK-NEXT: qref.cnot %q1, %q0 + +%0 = qref.measure %q0 + +// CHECK-NEXT: %0 = qref.measure %q0 + +// CHECK-GENERIC: %q0, %q1 = "qref.alloc"() : () -> (!qref.qubit, !qref.qubit) +// CHECK-GENERIC-NEXT: "qref.h"(%q0) : (!qref.qubit) -> () +// CHECK-GENERIC-NEXT: "qref.cz"(%q1, %q0) : (!qref.qubit, !qref.qubit) -> () +// CHECK-GENERIC-NEXT: "qref.cnot"(%q1, %q0) : (!qref.qubit, !qref.qubit) -> () +// CHECK-GENERIC-NEXT: %0 = "qref.measure"(%q0) : (!qref.qubit) -> i1 diff --git a/xdsl/dialects/__init__.py b/xdsl/dialects/__init__.py index 7878d9db07..3defaf4817 100644 --- a/xdsl/dialects/__init__.py +++ b/xdsl/dialects/__init__.py @@ -166,6 +166,11 @@ def get_printf(): return Printf + def get_qref(): + from xdsl.dialects.qref import QREF + + return QREF + def get_qssa(): from xdsl.dialects.qssa import QSSA @@ -294,6 +299,7 @@ def get_x86(): "onnx": get_onnx, "pdl": get_pdl, "printf": get_printf, + "qref": get_qref, "qssa": get_qssa, "riscv": get_riscv, "riscv_debug": get_riscv_debug, diff --git a/xdsl/dialects/qref.py b/xdsl/dialects/qref.py new file mode 100644 index 0000000000..9de1817167 --- /dev/null +++ b/xdsl/dialects/qref.py @@ -0,0 +1,147 @@ +from __future__ import annotations + +from abc import ABC + +from xdsl.dialects.builtin import IntegerType +from xdsl.ir import Dialect, ParametrizedAttribute, SSAValue, TypeAttribute +from xdsl.irdl import ( + IRDLOperation, + VarOpResult, + irdl_attr_definition, + irdl_op_definition, + operand_def, + result_def, + var_result_def, +) +from xdsl.parser import Parser +from xdsl.printer import Printer + + +@irdl_attr_definition +class QubitAttr(ParametrizedAttribute, TypeAttribute): + """ + Reference to a qubit + """ + + name = "qref.qubit" + + +qubit = QubitAttr() + + +class QRefBase(IRDLOperation, ABC): + pass + + +@irdl_op_definition +class QRefAllocOp(QRefBase): + name = "qref.alloc" + + res: VarOpResult = var_result_def(qubit) + + def __init__(self, num_qubits: int): + super().__init__( + operands=(), + result_types=[[qubit] * num_qubits], + ) + + @property + def num_qubits(self): + return len(self.res) + + @classmethod + def parse(cls, parser: Parser) -> QRefAllocOp: + with parser.in_angle_brackets(): + num_qubits = parser.parse_integer() + attr_dict = parser.parse_optional_attr_dict() + return QRefAllocOp.create( + result_types=[qubit] * num_qubits, + attributes=attr_dict, + ) + + def print(self, printer: Printer): + with printer.in_angle_brackets(): + printer.print(self.num_qubits) + + printer.print_op_attributes(self.attributes) + + +@irdl_op_definition +class HGateOp(QRefBase): + name = "qref.h" + + input = operand_def(qubit) + + assembly_format = "$input attr-dict" + + def __init__(self, input: SSAValue): + super().__init__( + operands=(input,), + result_types=(), + ) + + +@irdl_op_definition +class CNotGateOp(QRefBase): + name = "qref.cnot" + + in1 = operand_def(qubit) + + in2 = operand_def(qubit) + + assembly_format = "$in1 `,` $in2 attr-dict" + + def __init__(self, in1: SSAValue, in2: SSAValue): + super().__init__( + operands=(in1, in2), + result_types=(qubit, qubit), + ) + + +@irdl_op_definition +class CZGateOp(QRefBase): + name = "qref.cz" + + in1 = operand_def(qubit) + + in2 = operand_def(qubit) + + assembly_format = "$in1 `,` $in2 attr-dict" + + def __init__(self, in1: SSAValue, in2: SSAValue): + super().__init__( + operands=(in1, in2), + result_types=(), + ) + + +@irdl_op_definition +class MeasureOp(QRefBase): + name = "qref.measure" + + input = operand_def(qubit) + + output = result_def(IntegerType(1)) + + assembly_format = "$input attr-dict" + + def __init__(self, input: SSAValue): + super().__init__( + operands=[input], + result_types=[IntegerType(1)], + ) + + +QREF = Dialect( + "qref", + [ + CNotGateOp, + CZGateOp, + HGateOp, + MeasureOp, + QRefAllocOp, + ], + [ + QubitAttr, + ], +)