diff --git a/tests/filecheck/dialects/qssa/ops.mlir b/tests/filecheck/dialects/qssa/ops.mlir new file mode 100644 index 0000000000..5fbd8e7bf3 --- /dev/null +++ b/tests/filecheck/dialects/qssa/ops.mlir @@ -0,0 +1,28 @@ +// RUN: XDSL_ROUNDTRIP +// RUN: XDSL_GENERIC_ROUNDTRIP + +%q0, %q1 = qssa.alloc<2> + +// CHECK: %q0, %q1 = qssa.alloc<2> + +%q2 = qssa.h %q0 + +// CHECK-NEXT: %q2 = qssa.h %q0 + +%q3, %q4 = qssa.cz %q1, %q2 + +// CHECK-NEXT: %q3, %q4 = qssa.cz %q1, %q2 + +%q5, %q6 = qssa.cnot %q3, %q4 + +// CHECK-NEXT: %q5, %q6 = qssa.cnot %q3, %q4 + +%0 = qssa.measure %q6 + +// CHECK-NEXT: %0 = qssa.measure %q6 + +// CHECK-GENERIC: %q0, %q1 = "qssa.alloc"() : () -> (!qssa.qubit, !qssa.qubit) +// CHECK-GENERIC-NEXT: %q2 = "qssa.h"(%q0) : (!qssa.qubit) -> !qssa.qubit +// CHECK-GENERIC-NEXT: %q3, %q4 = "qssa.cz"(%q1, %q2) : (!qssa.qubit, !qssa.qubit) -> (!qssa.qubit, !qssa.qubit) +// CHECK-GENERIC-NEXT: %q5, %q6 = "qssa.cnot"(%q3, %q4) : (!qssa.qubit, !qssa.qubit) -> (!qssa.qubit, !qssa.qubit) +// CHECK-GENERIC-NEXT: %0 = "qssa.measure"(%q6) : (!qssa.qubit) -> i1 diff --git a/xdsl/dialects/qssa.py b/xdsl/dialects/qssa.py new file mode 100644 index 0000000000..29481755f3 --- /dev/null +++ b/xdsl/dialects/qssa.py @@ -0,0 +1,157 @@ +from __future__ import annotations + +from abc import ABC + +from xdsl.dialects.builtin import IntegerType +from xdsl.ir import Dialect, ParametrizedAttribute, SSAValue, TypeAttribute +from xdsl.irdl import ( + IRDLOperation, + VarOpResult, + irdl_attr_definition, + irdl_op_definition, + operand_def, + result_def, + var_result_def, +) +from xdsl.parser import Parser +from xdsl.printer import Printer + + +@irdl_attr_definition +class QubitAttr(ParametrizedAttribute, TypeAttribute): + """ + Type for a single qubit + """ + + name = "qssa.qubit" + + +qubit = QubitAttr() + + +class QubitBase(IRDLOperation, ABC): + pass + + +@irdl_op_definition +class QubitAllocOp(QubitBase): + name = "qssa.alloc" + + res: VarOpResult = var_result_def(qubit) + + def __init__(self, num_qubits: int): + super().__init__( + operands=(), + result_types=[[qubit] * num_qubits], + ) + + @property + def num_qubits(self): + return len(self.res) + + @classmethod + def parse(cls, parser: Parser) -> QubitAllocOp: + with parser.in_angle_brackets(): + num_qubits = parser.parse_integer() + attr_dict = parser.parse_optional_attr_dict() + return QubitAllocOp.create( + result_types=[qubit] * num_qubits, + attributes=attr_dict, + ) + + def print(self, printer: Printer): + with printer.in_angle_brackets(): + printer.print(self.num_qubits) + + printer.print_op_attributes(self.attributes) + + +@irdl_op_definition +class HGateOp(QubitBase): + name = "qssa.h" + + input = operand_def(qubit) + + output = result_def(qubit) + + assembly_format = "$input attr-dict" + + def __init__(self, input: SSAValue): + super().__init__( + operands=(input,), + result_types=(qubit,), + ) + + +@irdl_op_definition +class CNotGateOp(QubitBase): + name = "qssa.cnot" + + in1 = operand_def(qubit) + + in2 = operand_def(qubit) + + out1 = result_def(qubit) + + out2 = result_def(qubit) + + assembly_format = "$in1 `,` $in2 attr-dict" + + def __init__(self, in1: SSAValue, in2: SSAValue): + super().__init__( + operands=(in1, in2), + result_types=(qubit, qubit), + ) + + +@irdl_op_definition +class CZGateOp(QubitBase): + name = "qssa.cz" + + in1 = operand_def(qubit) + + in2 = operand_def(qubit) + + out1 = result_def(qubit) + + out2 = result_def(qubit) + + assembly_format = "$in1 `,` $in2 attr-dict" + + def __init__(self, in1: SSAValue, in2: SSAValue): + super().__init__( + operands=(in1, in2), + result_types=(qubit, qubit), + ) + + +@irdl_op_definition +class MeasureOp(QubitBase): + name = "qssa.measure" + + input = operand_def(qubit) + + output = result_def(IntegerType(1)) + + assembly_format = "$input attr-dict" + + def __init__(self, input: SSAValue): + super().__init__( + operands=[input], + result_types=[IntegerType(1)], + ) + + +QSSA = Dialect( + "qssa", + [ + CNotGateOp, + CZGateOp, + HGateOp, + MeasureOp, + QubitAllocOp, + ], + [ + QubitAttr, + ], +) diff --git a/xdsl/tools/command_line_tool.py b/xdsl/tools/command_line_tool.py index 1b588e6c21..bd2cb4e129 100644 --- a/xdsl/tools/command_line_tool.py +++ b/xdsl/tools/command_line_tool.py @@ -175,6 +175,11 @@ def get_printf(): return Printf + def get_qssa(): + from xdsl.dialects.qssa import QSSA + + return QSSA + def get_riscv_debug(): from xdsl.dialects.riscv_debug import RISCV_Debug @@ -298,6 +303,7 @@ def get_x86(): "onnx": get_onnx, "pdl": get_pdl, "printf": get_printf, + "qssa": get_qssa, "riscv": get_riscv, "riscv_debug": get_riscv_debug, "riscv_func": get_riscv_func,