From 8c988b2840cffa1f91ca3b587c952db488158801 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9s=20Goens?= Date: Fri, 10 Mar 2023 09:55:46 +0000 Subject: [PATCH 01/10] Mention Zulip (#528) --- README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/README.md b/README.md index d4a27407f8..7b85dd090f 100644 --- a/README.md +++ b/README.md @@ -88,3 +88,7 @@ format the code in a uniform manner. To automate the formatting within vim, one can use https://github.com/vim-autoformat/vim-autoformat and trigger a `:Autoformat` on save. + +### Discussion + +You can also join the discussion at our [Zulip chat room](https://xdsl.zulipchat.com), kindly supported by community hosting from [Zulip](https://zulip.com/). From ae1988d4743bc48b3099dbcf985552595490f751 Mon Sep 17 00:00:00 2001 From: George Bisbas Date: Fri, 10 Mar 2023 11:52:41 +0000 Subject: [PATCH 02/10] misc: Improve messages (#518) * misc: Improve messages * misc: fix yapf --- xdsl/dialects/arith.py | 6 +++--- xdsl/dialects/experimental/stencil.py | 10 ++++------ xdsl/irdl.py | 2 +- xdsl/utils/deprecation.py | 7 +++++-- 4 files changed, 13 insertions(+), 12 deletions(-) diff --git a/xdsl/dialects/arith.py b/xdsl/dialects/arith.py index aff069b9f1..7e48a67c60 100644 --- a/xdsl/dialects/arith.py +++ b/xdsl/dialects/arith.py @@ -5,9 +5,9 @@ from enum import Enum from typing import Annotated, TypeVar, Union, Set, Optional -from xdsl.dialects.builtin import (ContainerOf, Float16Type, Float64Type, IndexType, IntAttr, - IntegerType, Float32Type, IntegerAttr, FloatAttr, - Attribute, AnyFloat, AnyIntegerAttr) +from xdsl.dialects.builtin import (ContainerOf, Float16Type, Float64Type, IndexType, + IntAttr, IntegerType, Float32Type, IntegerAttr, + FloatAttr, Attribute, AnyFloat, AnyIntegerAttr) from xdsl.ir import Operation, SSAValue, Dialect, OpResult, Data from xdsl.irdl import (AnyOf, irdl_op_definition, OpAttr, AnyAttr, Operand, irdl_attr_definition) diff --git a/xdsl/dialects/experimental/stencil.py b/xdsl/dialects/experimental/stencil.py index 37a7563f02..100670869a 100644 --- a/xdsl/dialects/experimental/stencil.py +++ b/xdsl/dialects/experimental/stencil.py @@ -39,11 +39,10 @@ def from_shape(shape: list[int] | list[IntAttr]) -> FieldType: # TODO: why do we need all these casts here, can we tell pyright "trust me" if all(isinstance(elm, IntAttr) for elm in shape): shape = cast(list[IntAttr], shape) - return FieldType([ArrayAttr.from_list(shape)]) + return FieldType([ArrayAttr(shape)]) shape = cast(list[int], shape) - return FieldType( - [ArrayAttr.from_list([IntAttr.from_int(d) for d in shape])]) + return FieldType([ArrayAttr([IntAttr.from_int(d) for d in shape])]) @irdl_attr_definition @@ -65,10 +64,9 @@ def from_shape( if isinstance(shape[0], IntAttr): # the if above is a sufficient type guard, but pyright does not understand :/ - return TempType([ArrayAttr.from_list(shape)]) # type: ignore + return TempType([ArrayAttr(shape)]) # type: ignore shape = cast(list[int], shape) - return TempType( - [ArrayAttr.from_list([IntAttr.from_int(d) for d in shape])]) + return TempType([ArrayAttr([IntAttr.from_int(d) for d in shape])]) def __repr__(self): repr: str = "stencil.Temp<[" diff --git a/xdsl/irdl.py b/xdsl/irdl.py index f4cf30135f..eca49e912d 100644 --- a/xdsl/irdl.py +++ b/xdsl/irdl.py @@ -991,7 +991,7 @@ def irdl_op_builder( if not isinstance(attr, Attribute): raise ValueError(error_prefix + f"{attr_name} is expected to be an " - "attribute, but got {type(attr)}.") + f"attribute, but got {type(attr)}.") built_attributes[attr_name] = attr # Take care of variadic operand and result segment sizes. diff --git a/xdsl/utils/deprecation.py b/xdsl/utils/deprecation.py index 036e3a0ecc..3756abc1a6 100644 --- a/xdsl/utils/deprecation.py +++ b/xdsl/utils/deprecation.py @@ -16,7 +16,8 @@ def decorator(func: Callable[_P, _T]) -> Callable[_P, _T]: def new_func(*args: _P.args, **kwargs: _P.kwargs) -> _T: warnings.warn( - f'Call to deprecated method {func.__name__}: {reason}') + f'Call to deprecated method {str(func).split(" ")[1]}: {reason}' + ) return func(*args, **kwargs) return new_func @@ -25,4 +26,6 @@ def new_func(*args: _P.args, **kwargs: _P.kwargs) -> _T: def deprecated_constructor(func: Callable[_P, _T]) -> Callable[_P, _T]: - return deprecated(f'use the constructor (`ClassName(...)`) instead.')(func) + # TOFIX: improve printing + return deprecated(f'{"use the constructor (`ClassName(...)`) instead."}')( + func) From e2dbd3e4071e73eff7e733094c2df14439ad67a4 Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Fri, 10 Mar 2023 11:55:59 +0000 Subject: [PATCH 03/10] mpi: Add library info generated by mpi-info.c (#527) * mpi: add all info generated by mpi-info.c * mpi: add mpich library version * mpi: make magic values hex --- xdsl/transforms/lower_mpi.py | 94 +++++++++++++++++++++++------------- 1 file changed, 61 insertions(+), 33 deletions(-) diff --git a/xdsl/transforms/lower_mpi.py b/xdsl/transforms/lower_mpi.py index 53ce2ce566..142c039ee4 100644 --- a/xdsl/transforms/lower_mpi.py +++ b/xdsl/transforms/lower_mpi.py @@ -25,29 +25,65 @@ class MpiLibraryInfo: This way of doing it is inherently fragile, but we don't know of any better way. We plan to include a C file that automagically extracts all this information from MPI headers. + You can see the current C file used in this PR: https://github.com/xdslproject/xdsl/pull/526 + You can see the status of OpenMPI support here: https://github.com/xdslproject/xdsl/issues/523 - These defaults have been chosen to work with **our** version of OpenMPI. No guarantees of portability! + These defaults have been extracted from MPICH 3.3a2. We would highly suggest + running the mpi-info.c file yourself with your version of the library! """ - mpi_comm_world_val: int = 0x44000000 + # MPI_Datatype + MPI_Datatype_size: int = 4 + MPI_CHAR: int = 0x4c000101 + MPI_SIGNED_CHAR: int = 0x4c000118 + MPI_UNSIGNED_CHAR: int = 0x4c000102 + MPI_BYTE: int = 0x4c00010d + MPI_WCHAR: int = 0x4c00040e + MPI_SHORT: int = 0x4c000203 + MPI_UNSIGNED_SHORT: int = 0x4c000204 MPI_INT: int = 0x4c000405 MPI_UNSIGNED: int = 0x4c000406 MPI_LONG: int = 0x4c000807 MPI_UNSIGNED_LONG: int = 0x4c000808 MPI_FLOAT: int = 0x4c00040a MPI_DOUBLE: int = 0x4c00080b - MPI_UNSIGNED_CHAR: int = -1 - MPI_UNSIGNED_SHORT: int = -1 - MPI_UNSIGNED_LONG_LONG: int = -1 - MPI_CHAR: int = -1 - MPI_SHORT: int = -1 - MPI_LONG_LONG_INT: int = -1 - - MPI_STATUS_IGNORE: int = 1 - - request_size: int = 4 - status_size: int = 4 * 5 - mpi_comm_size: int = 4 + MPI_LONG_DOUBLE: int = 0x4c00100c + MPI_LONG_LONG_INT: int = 0x4c000809 + MPI_UNSIGNED_LONG_LONG: int = 0x4c000819 + MPI_LONG_LONG: int = 0x4c000809 + + # MPI_Op + MPI_Op_size: int = 4 + MPI_MAX: int = 0x58000001 + MPI_MIN: int = 0x58000002 + MPI_SUM: int = 0x58000003 + MPI_PROD: int = 0x58000004 + MPI_LAND: int = 0x58000005 + MPI_BAND: int = 0x58000006 + MPI_LOR: int = 0x58000007 + MPI_BOR: int = 0x58000008 + MPI_LXOR: int = 0x58000009 + MPI_BXOR: int = 0x5800000a + MPI_MINLOC: int = 0x5800000b + MPI_MAXLOC: int = 0x5800000c + MPI_REPLACE: int = 0x5800000d + MPI_NO_OP: int = 0x5800000e + + # MPI_Comm + MPI_Comm_size: int = 4 + MPI_COMM_WORLD: int = 0x44000000 + MPI_COMM_SELF: int = 0x44000001 + + # MPI_Request + MPI_Request_size: int = 4 + + # MPI_Status + MPI_Status_size: int = 20 + MPI_STATUS_IGNORE: int = 0x00000001 + MPI_STATUSES_IGNORE: int = 0x00000001 + MPI_Status_field_MPI_SOURCE: int = 8 # offset of field MPI_SOURCE in struct MPI_Status + MPI_Status_field_MPI_TAG: int = 12 # offset of field MPI_TAG in struct MPI_Status + MPI_Status_field_MPI_ERROR: int = 16 # offset of field MPI_ERROR in struct MPI_Status _RewriteT = TypeVar('_RewriteT', bound=mpi.MPIBaseOp) @@ -107,7 +143,7 @@ def _emit_mpi_status_obj( lit1 := arith.Constant.from_int_and_width(1, builtin.i64), res := llvm.AllocaOp.get(lit1, builtin.IntegerType( - 8 * self.info.status_size), + 8 * self.info.MPI_Status_size), as_untyped_ptr=True), ], [res.res], res @@ -293,14 +329,12 @@ def lower(self, return [ *count_ops, comm_global := - arith.Constant.from_int_and_width(self.info.mpi_comm_world_val, - i32), + arith.Constant.from_int_and_width(self.info.MPI_COMM_WORLD, i32), datatype := self._emit_mpi_type_load(memref_elm_typ), tag := arith.Constant.from_int_and_width(op.tag.value.data, i32), lit1 := arith.Constant.from_int_and_width(1, builtin.i64), - request := - llvm.AllocaOp.get(lit1, - builtin.IntegerType(8 * self.info.request_size)), + request := llvm.AllocaOp.get( + lit1, builtin.IntegerType(8 * self.info.MPI_Request_size)), *(ptr := self._memref_get_llvm_ptr(op.buffer))[0], func.Call.get(self._mpi_name(op), [ ptr[1], count_ssa_val, datatype, op.dest, tag, comm_global, @@ -335,12 +369,10 @@ def lower(self, datatype := self._emit_mpi_type_load(memref_elm_typ), tag := arith.Constant.from_int_and_width(op.tag.value.data, i32), comm_global := - arith.Constant.from_int_and_width(self.info.mpi_comm_world_val, - i32), + arith.Constant.from_int_and_width(self.info.MPI_COMM_WORLD, i32), lit1 := arith.Constant.from_int_and_width(1, builtin.i64), - request := - llvm.AllocaOp.get(lit1, - builtin.IntegerType(8 * self.info.request_size)), + request := llvm.AllocaOp.get( + lit1, builtin.IntegerType(8 * self.info.MPI_Request_size)), func.Call.get(self._mpi_name(op), [ ptr[1], count_ssa_val, datatype, op.source, tag, comm_global, request @@ -367,8 +399,7 @@ def lower(self, return [ comm_global := - arith.Constant.from_int_and_width(self.info.mpi_comm_world_val, - i32), + arith.Constant.from_int_and_width(self.info.MPI_COMM_WORLD, i32), func.Call.get(self._mpi_name(op), [ op.buffer, op.count, op.datatype, op.dest, op.tag, comm_global ], [i32]), @@ -398,8 +429,7 @@ def lower(self, return [ *mpi_status_ops, comm_global := - arith.Constant.from_int_and_width(self.info.mpi_comm_world_val, - i32), + arith.Constant.from_int_and_width(self.info.MPI_COMM_WORLD, i32), func.Call.get(self._mpi_name(op), [ op.buffer, op.count, op.datatype, op.source, op.tag, comm_global, status @@ -462,8 +492,7 @@ def lower( """ return [ comm_global := - arith.Constant.from_int_and_width(self.info.mpi_comm_world_val, - i32), + arith.Constant.from_int_and_width(self.info.MPI_COMM_WORLD, i32), lit1 := arith.Constant.from_int_and_width(1, 64), int_ptr := llvm.AllocaOp.get(lit1, i32), func.Call.get(self._mpi_name(op), [comm_global, int_ptr], [i32]), @@ -488,8 +517,7 @@ def lower( """ return [ comm_global := - arith.Constant.from_int_and_width(self.info.mpi_comm_world_val, - i32), + arith.Constant.from_int_and_width(self.info.MPI_COMM_WORLD, i32), lit1 := arith.Constant.from_int_and_width(1, 64), int_ptr := llvm.AllocaOp.get(lit1, i32), func.Call.get(self._mpi_name(op), [comm_global, int_ptr], [i32]), From 21d2a7d91883ec065c78f0a74ce39439b61186dc Mon Sep 17 00:00:00 2001 From: Michel Weber <55622065+webmiche@users.noreply.github.com> Date: Fri, 10 Mar 2023 17:45:49 +0100 Subject: [PATCH 04/10] CI: Fix append-timestamp and naming (#529) This PR removes the append-timestamp that is not needed anymore after removing ccache from the repo, and renames one of the pyright CIs (they had the same name). --- .github/workflows/ci-mlir.yml | 1 - .github/workflows/ci-pyright.yml | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/ci-mlir.yml b/.github/workflows/ci-mlir.yml index 47b8b0a07e..6e7957045a 100644 --- a/.github/workflows/ci-mlir.yml +++ b/.github/workflows/ci-mlir.yml @@ -54,7 +54,6 @@ jobs: path: llvm-project/build key: binaries-${{ runner.os }}-${{ env.MLIR-Version }} restore-keys: binaries-${{ runner.os }}-${{ env.MLIR-Version }} - append-timestamp: false - name: Checkout MLIR if: steps.cache-binary.outputs.cache-hit != 'true' diff --git a/.github/workflows/ci-pyright.yml b/.github/workflows/ci-pyright.yml index 13fcb758e4..080edc23f6 100644 --- a/.github/workflows/ci-pyright.yml +++ b/.github/workflows/ci-pyright.yml @@ -1,7 +1,7 @@ # This workflow will install Python dependencies, run tests and lint with a single version of Python # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions -name: CI - Pyright +name: CI - Pyright - Reviewdog on: # Trigger the workflow on push or pull request, From b71e558d737a9d516cf7f3e89603a7c66c081963 Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Sat, 11 Mar 2023 19:06:09 +0000 Subject: [PATCH 05/10] misc: lint parser.py (#524) --- pyright-ci.json | 1 - xdsl/dialects/builtin.py | 2 +- xdsl/dialects/memref.py | 6 +- xdsl/ir.py | 3 +- xdsl/parser.py | 201 ++++++++++----------------------------- xdsl/utils/exceptions.py | 10 +- xdsl/utils/lexer.py | 124 ++++++++++++++++++++++++ 7 files changed, 185 insertions(+), 162 deletions(-) create mode 100644 xdsl/utils/lexer.py diff --git a/pyright-ci.json b/pyright-ci.json index 62f4bc659c..fa0972c607 100644 --- a/pyright-ci.json +++ b/pyright-ci.json @@ -4,7 +4,6 @@ "reportUnnecessaryIsInstance": false, "typeCheckingMode": "strict", "exclude": [ - "xdsl/parser.py", "xdsl/irdl_mlir_printer.py", "xdsl/irdl.py", "xdsl/ir.py" diff --git a/xdsl/dialects/builtin.py b/xdsl/dialects/builtin.py index 534f9448b7..eec1899350 100644 --- a/xdsl/dialects/builtin.py +++ b/xdsl/dialects/builtin.py @@ -491,7 +491,7 @@ def get_shape(self) -> List[int]: @staticmethod def from_element_type_and_shape( referenced_type: _VectorTypeElems, - shape: List[int | IntegerAttr[IndexType]] + shape: Sequence[int | IntegerAttr[IndexType]] ) -> VectorType[_VectorTypeElems]: return VectorType([ ArrayAttr([ diff --git a/xdsl/dialects/memref.py b/xdsl/dialects/memref.py index 1160e9b7df..2940d0e130 100644 --- a/xdsl/dialects/memref.py +++ b/xdsl/dialects/memref.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Annotated, TypeVar, Optional, List, TypeAlias, cast +from typing import Annotated, Sequence, TypeVar, Optional, List, TypeAlias, cast from xdsl.dialects.builtin import (DenseIntOrFPElementsAttr, IntegerAttr, DenseArrayBase, IndexType, ArrayAttr, @@ -33,8 +33,8 @@ def get_shape(self) -> List[int]: @staticmethod def from_element_type_and_shape( - referenced_type: _MemRefTypeElement, - shape: List[int | AnyIntegerAttr] + referenced_type: _MemRefTypeElement, + shape: Sequence[int | AnyIntegerAttr] ) -> MemRefType[_MemRefTypeElement]: return MemRefType([ ArrayAttr[AnyIntegerAttr]([ diff --git a/xdsl/ir.py b/xdsl/ir.py index d1e6cbd7ab..c25c8b47a5 100644 --- a/xdsl/ir.py +++ b/xdsl/ir.py @@ -14,6 +14,7 @@ from xdsl.parser import BaseParser from xdsl.printer import Printer from xdsl.irdl import OpDef, ParamAttrDef + from xdsl.utils.lexer import Span OpT = TypeVar('OpT', bound='Operation') @@ -735,7 +736,7 @@ def irdl_definition(cls) -> OpDef | None: class Block(IRNode): """A sequence of operations""" - declared_at: 'Span' | None = None + declared_at: Span | None = None _args: tuple[BlockArgument, ...] = field(default_factory=lambda: (), init=False) diff --git a/xdsl/parser.py b/xdsl/parser.py index 4ee144515e..976d216022 100644 --- a/xdsl/parser.py +++ b/xdsl/parser.py @@ -15,16 +15,18 @@ from typing import Any, TypeVar, Iterable, IO, cast from xdsl.utils.exceptions import ParseError, MultipleSpansParseError +from xdsl.utils.lexer import Input, Span from xdsl.dialects.memref import MemRefType, UnrankedMemrefType from xdsl.dialects.builtin import ( - AnyFloat, AnyTensorType, AnyUnrankedTensorType, AnyVectorType, - DenseResourceAttr, DictionaryAttr, Float16Type, Float32Type, Float64Type, - FloatAttr, FunctionType, IndexType, IntegerType, Signedness, StringAttr, - IntegerAttr, ArrayAttr, TensorType, UnrankedTensorType, VectorType, - SymbolRefAttr, DenseArrayBase, DenseIntOrFPElementsAttr, UnregisteredOp, - OpaqueAttr, NoneAttr, ModuleOp, UnitAttr, i64) + AnyArrayAttr, AnyFloat, AnyFloatAttr, AnyTensorType, AnyUnrankedTensorType, + AnyVectorType, DenseResourceAttr, DictionaryAttr, Float16Type, Float32Type, + Float64Type, FloatAttr, FunctionType, IndexType, IntegerType, Signedness, + StringAttr, IntegerAttr, ArrayAttr, TensorType, UnrankedTensorType, + VectorType, SymbolRefAttr, DenseArrayBase, DenseIntOrFPElementsAttr, + UnregisteredOp, OpaqueAttr, NoneAttr, ModuleOp, UnitAttr, i64) from xdsl.ir import (SSAValue, Block, Callable, Attribute, Operation, Region, BlockArgument, MLContext, ParametrizedAttribute, Data) +from xdsl.utils.hints import isa @dataclass @@ -85,81 +87,6 @@ def __hash__(self): return id(self) -@dataclass(frozen=True) -class Span: - """ - Parts of the input are always passed around as spans, so we know where " - "they originated. - """ - - start: int - """ - Start of tokens location in source file, global byte offset in file - """ - end: int - """ - End of tokens location in source file, global byte offset in file - """ - input: Input - """ - The input being operated on - """ - - def __len__(self): - return self.len - - @property - def len(self): - return self.end - self.start - - @property - def text(self): - return self.input.content[self.start:self.end] - - def get_line_col(self) -> tuple[int, int]: - info = self.input.get_lines_containing(self) - if info is None: - return -1, -1 - lines, offset_of_first_line, line_no = info - return line_no, self.start - offset_of_first_line - - def print_with_context(self, msg: str | None = None) -> str: - """ - returns a string containing lines relevant to the span. The Span's contents - are highlighted by up-carets beneath them (`^`). The message msg is printed - along these. - """ - info = self.input.get_lines_containing(self) - if info is None: - return "Unknown location of span {}. Error: ".format(msg) - lines, offset_of_first_line, line_no = info - # Offset relative to the first line: - offset = self.start - offset_of_first_line - remaining_len = max(self.len, 1) - capture = StringIO() - print("{}:{}:{}".format(self.input.name, line_no, offset), - file=capture) - for line in lines: - print(line, file=capture) - if remaining_len < 0: - continue - len_on_this_line = min(remaining_len, len(line) - offset) - remaining_len -= len_on_this_line - print("{}{}".format(" " * offset, "^" * max(len_on_this_line, 1)), - file=capture) - if msg is not None: - print("{}{}".format(" " * offset, msg), file=capture) - msg = None - offset = 0 - if msg is not None: - print(msg, file=capture) - return capture.getvalue() - - def __repr__(self): - return "{}[{}:{}](text='{}')".format(self.__class__.__name__, - self.start, self.end, self.text) - - @dataclass(frozen=True, repr=False) class StringLiteral(Span): @@ -184,52 +111,6 @@ def string_contents(self): return ast.literal_eval(self.text) -@dataclass(frozen=True) -class Input: - """ - This is a very simple class that is used to keep track of the input. - """ - content: str = field(repr=False) - name: str - - @property - def len(self): - return len(self.content) - - def __len__(self): - return self.len - - def get_lines_containing(self, - span: Span) -> tuple[list[str], int, int] | None: - # A pointer to the start of the first line - start = 0 - line_no = 0 - source = self.content - while True: - next_start = source.find('\n', start) - line_no += 1 - # Handle eof - if next_start == -1: - if span.start > len(source): - return None - return [source[start:]], start, line_no - # As long as the next newline comes before the spans start we can continue - if next_start < span.start: - start = next_start + 1 - continue - # If the whole span is on one line, we are good as well - if next_start >= span.end: - return [source[start:next_start]], start, line_no - while next_start < span.end: - next_start = source.find('\n', next_start + 1) - return source[start:next_start].split('\n'), start, line_no - - def at(self, i: int): - if i >= self.len: - raise EOFError() - return self.content[i] - - save_t = tuple[int, tuple[str, ...]] @@ -657,11 +538,13 @@ def parse_block(self) -> Block: block.declared_at = block_id else: if block_id.text in self.blocks: + block = self.blocks[block_id.text] + assert block.declared_at is not None, "Parsed block must have a span" raise MultipleSpansParseError( block_id, "Re-declaration of block {}".format(block_id.text), "Originally declared here:", - [(self.blocks[block_id.text].declared_at, None)], + [(block.declared_at, None)], self.tokenizer.history, ) block = Block(block_id) @@ -746,14 +629,13 @@ def parse_list_of(self, as it will read [3,4,4], then see another separator, and expects the next `try_parse` call to succeed (which won't as i32 is not a valid integer literal) """ - items = list() first_item = try_parse() if first_item is None: if allow_empty: - return items + return [] self.raise_error(error_msg) - items.append(first_item) + items = [first_item] while (match := self.tokenizer.next_token_of_pattern(separator_pattern) ) is not None: @@ -932,9 +814,12 @@ def unimplemented() -> ParametrizedAttribute: def parse_complex_attrs(self): self.raise_error("ComplexType is unimplemented!") - def parse_memref_attrs(self) -> MemRefType | UnrankedMemrefType: + def parse_memref_attrs( + self) -> MemRefType[Attribute] | UnrankedMemrefType[Attribute]: dims = self._parse_tensor_or_memref_dims() - type = self.try_parse_type() + type = self.expect( + self.try_parse_type, + "Type cannot be nil when parsing memref attributes") if dims is None: return UnrankedMemrefType.from_type(type) return MemRefType.from_element_type_and_shape(type, dims) @@ -1094,6 +979,10 @@ def parse_operation(self) -> Operation: # Check for custom op format op_name = self.try_parse_bare_id() if op_name is not None: + assert isinstance( + self, XDSLParser + ), "Only xDSL format currently supports custom op parsing" + assert ret_types is not None, "Return types must be in xDSL format" op_type = self._get_op_by_name(op_name) op = op_type.parse(ret_types, self) else: @@ -1104,7 +993,7 @@ def parse_operation(self) -> Operation: "Expected an operation name here, either a bare-id, or a string " "literal!") - args, successors, attrs, regions, func_type = self._parse_operation_details( + args, successors, attrs, regions, func_type = self.parse_operation_details( ) if ret_types is None: @@ -1293,6 +1182,9 @@ def _parse_builtin_dense_attr(self, _name: Span) -> Attribute | None: self.parse_characters(":", err_msg) type = self.expect(self.try_parse_type, "Dense attribute must be typed!") + + assert isa(type, AnyTensorType) + return DenseIntOrFPElementsAttr.from_list(type, info) def _parse_builtin_opaque_attr(self, _name: Span): @@ -1399,7 +1291,8 @@ def try_parse_ref_attr(self) -> SymbolRefAttr | None: else: return None - def try_parse_builtin_int_attr(self) -> IntegerAttr | None: + def try_parse_builtin_int_attr( + self) -> IntegerAttr[IntegerType | IndexType] | None: bool = self.try_parse_builtin_boolean_attr() if bool is not None: return bool @@ -1411,9 +1304,14 @@ def try_parse_builtin_int_attr(self) -> IntegerAttr | None: if self.tokenizer.next_token(peek=True).text != ':': return IntegerAttr.from_params(int(value.text), i64) type = self._parse_attribute_type() + + if not isinstance(type, IntegerType | IndexType): + self.raise_error( + f"Expected IntegerType | IndexType, got {type}") + return IntegerAttr.from_params(int(value.text), type) - def try_parse_builtin_float_attr(self) -> FloatAttr | None: + def try_parse_builtin_float_attr(self) -> AnyFloatAttr | None: with self.tokenizer.backtracking("float literal"): value = self.expect( self.try_parse_float_literal, @@ -1429,7 +1327,8 @@ def try_parse_builtin_float_attr(self) -> FloatAttr | None: "Float attribute must be typed with a float type!") return FloatAttr(float(value.text), type) - def try_parse_builtin_boolean_attr(self) -> IntegerAttr | None: + def try_parse_builtin_boolean_attr( + self) -> IntegerAttr[IntegerType | IndexType] | None: span = self.try_parse_boolean_literal() if span is None: @@ -1448,7 +1347,7 @@ def try_parse_builtin_str_attr(self): self.raise_error("Invalid string literal") return StringAttr(literal.string_contents) - def try_parse_builtin_arr_attr(self) -> ArrayAttr | None: + def try_parse_builtin_arr_attr(self) -> AnyArrayAttr | None: if not self.tokenizer.starts_with("["): return None with self.tokenizer.backtracking("array literal"): @@ -1526,7 +1425,7 @@ def parse_function_type(self) -> FunctionType: return FunctionType.from_lists(args, self._parse_type_or_type_list_parens()) - def _parse_type_or_type_list_parens(self) -> list[Attribute | None]: + def _parse_type_or_type_list_parens(self) -> list[Attribute]: """ Parses type-or-type-list-parens, which is used in function-type. @@ -1535,15 +1434,15 @@ def _parse_type_or_type_list_parens(self) -> list[Attribute | None]: type-list-no-parens ::= type (`,` type)* """ if self.tokenizer.next_token_of_pattern("(") is not None: - args: list[Attribute | None] = self.parse_list_of( - self.try_parse_type, "Expected type here!") + args = self.parse_list_of(self.try_parse_type, + "Expected type here!") self.parse_characters(")", "Unclosed function type argument list!") else: - args = [self.try_parse_type()] - if args[0] is None: - self.raise_error( - "Function type must either be single type or list of types in" - " parenthesis!") + arg = self.expect( + self.try_parse_type, + "Function type must either be single type or list of types in parentheses" + ) + args = [arg] return args def try_parse_function_type(self) -> FunctionType | None: @@ -1591,7 +1490,7 @@ def _parse_builtin_type_with_name(self, name: Span): return self._parse_builtin_parametrized_type(name) @abstractmethod - def _parse_operation_details( + def parse_operation_details( self, ) -> tuple[list[Span], list[Span], dict[str, Attribute], list[Region], FunctionType | None]: @@ -1631,7 +1530,7 @@ def parse_op_with_default_format( """ # TODO: remove this function and restructure custom op / irdl parsing assert isinstance(self, XDSLParser) - args, successors, attributes, regions, _ = self._parse_operation_details( + args, successors, attributes, regions, _ = self.parse_operation_details( ) for x in args: @@ -1745,7 +1644,7 @@ def _parse_op_result_list( def parse_optional_attr_dict(self) -> dict[str, Attribute]: return self.parse_optional_dictionary_attr_dict() - def _parse_operation_details( + def parse_operation_details( self, ) -> tuple[list[Span], list[Span], dict[str, Attribute], list[Region], FunctionType | None]: @@ -1891,7 +1790,7 @@ def parse_optional_attr_dict(self) -> dict[str, Attribute]: return self._attr_dict_from_tuple_list(attrs) - def _parse_operation_details( + def parse_operation_details( self, ) -> tuple[list[Span], list[Span], dict[str, Attribute], list[Region], FunctionType | None]: @@ -1974,7 +1873,7 @@ def Parser(ctx: MLContext, prog: str, source: Source = Source.XDSL, filename: str = '', - allow_unregistered_ops=False) -> BaseParser: + allow_unregistered_ops: bool = False) -> BaseParser: selected_parser = { Source.XDSL: XDSLParser, Source.MLIR: MLIRParser diff --git a/xdsl/utils/exceptions.py b/xdsl/utils/exceptions.py index 3cbb3361e4..7e2d61409e 100644 --- a/xdsl/utils/exceptions.py +++ b/xdsl/utils/exceptions.py @@ -94,12 +94,12 @@ def __contains__(self, item: str): class ParseError(Exception): - span: 'Span' + span: Span msg: str history: 'BacktrackingHistory' | None def __init__(self, - span: 'Span', + span: Span, msg: str, history: 'BacktrackingHistory' | None = None): super().__init__(DeferredExceptionMessage(lambda: repr(self))) @@ -125,14 +125,14 @@ def __repr__(self): class MultipleSpansParseError(ParseError): ref_text: str | None - refs: list[tuple['Span', str | None]] + refs: list[tuple[Span, str | None]] def __init__( self, - span: 'Span', + span: Span, msg: str, ref_text: str, - refs: list[tuple['Span', str | None]], + refs: list[tuple[Span, str | None]], history: 'BacktrackingHistory' | None = None, ): super(MultipleSpansParseError, self).__init__(span, msg, history) diff --git a/xdsl/utils/lexer.py b/xdsl/utils/lexer.py new file mode 100644 index 0000000000..42eb089815 --- /dev/null +++ b/xdsl/utils/lexer.py @@ -0,0 +1,124 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from io import StringIO + + +@dataclass(frozen=True) +class Input: + """ + Used to keep track of the input when parsing. + """ + content: str = field(repr=False) + name: str + + @property + def len(self): + return len(self.content) + + def __len__(self): + return self.len + + def get_lines_containing(self, + span: Span) -> tuple[list[str], int, int] | None: + # A pointer to the start of the first line + start = 0 + line_no = 0 + source = self.content + while True: + next_start = source.find('\n', start) + line_no += 1 + # Handle eof + if next_start == -1: + if span.start > len(source): + return None + return [source[start:]], start, line_no + # As long as the next newline comes before the spans start we can continue + if next_start < span.start: + start = next_start + 1 + continue + # If the whole span is on one line, we are good as well + if next_start >= span.end: + return [source[start:next_start]], start, line_no + while next_start < span.end: + next_start = source.find('\n', next_start + 1) + return source[start:next_start].split('\n'), start, line_no + + def at(self, i: int): + if i >= self.len: + raise EOFError() + return self.content[i] + + +@dataclass(frozen=True) +class Span: + """ + Parts of the input are always passed around as spans, so we know where they originated. + """ + + start: int + """ + Start of tokens location in source file, global byte offset in file + """ + end: int + """ + End of tokens location in source file, global byte offset in file + """ + input: Input + """ + The input being operated on + """ + + def __len__(self): + return self.len + + @property + def len(self): + return self.end - self.start + + @property + def text(self): + return self.input.content[self.start:self.end] + + def get_line_col(self) -> tuple[int, int]: + info = self.input.get_lines_containing(self) + if info is None: + return -1, -1 + _lines, offset_of_first_line, line_no = info + return line_no, self.start - offset_of_first_line + + def print_with_context(self, msg: str | None = None) -> str: + """ + returns a string containing lines relevant to the span. The Span's contents + are highlighted by up-carets beneath them (`^`). The message msg is printed + along these. + """ + info = self.input.get_lines_containing(self) + if info is None: + return "Unknown location of span {}. Error: ".format(msg) + lines, offset_of_first_line, line_no = info + # Offset relative to the first line: + offset = self.start - offset_of_first_line + remaining_len = max(self.len, 1) + capture = StringIO() + print("{}:{}:{}".format(self.input.name, line_no, offset), + file=capture) + for line in lines: + print(line, file=capture) + if remaining_len < 0: + continue + len_on_this_line = min(remaining_len, len(line) - offset) + remaining_len -= len_on_this_line + print("{}{}".format(" " * offset, "^" * max(len_on_this_line, 1)), + file=capture) + if msg is not None: + print("{}{}".format(" " * offset, msg), file=capture) + msg = None + offset = 0 + if msg is not None: + print(msg, file=capture) + return capture.getvalue() + + def __repr__(self): + return "{}[{}:{}](text='{}')".format(self.__class__.__name__, + self.start, self.end, self.text) From f05a0c4aa657887bfe6234ac315df9554990f04f Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Mon, 13 Mar 2023 10:11:37 +0000 Subject: [PATCH 06/10] MPI: Move `isend` and `irecv` to pointers (#530) * mpi: change ISend and IRecv to take pointers * mpi: add filecheck for lowering isend, irecv * format: remove unused import * format: make mpi imports conform to col limit * mpi: add back asserts to mpi baseof tests * mpi: change AnyAttr() to Attribute * format: remove unused import --- tests/dialects/test_mpi.py | 33 +++++-- tests/dialects/test_mpi_lowering.py | 14 ++- .../dialects/mpi/mpi-hello-world-async.mlir | 44 +++++++++ xdsl/dialects/mpi.py | 99 ++++++++++--------- xdsl/transforms/lower_mpi.py | 24 +---- 5 files changed, 129 insertions(+), 85 deletions(-) create mode 100644 tests/filecheck/mlir-conversion/with-mlir/dialects/mpi/mpi-hello-world-async.mlir diff --git a/tests/dialects/test_mpi.py b/tests/dialects/test_mpi.py index 63fc8ac028..8a8291705d 100644 --- a/tests/dialects/test_mpi.py +++ b/tests/dialects/test_mpi.py @@ -4,15 +4,32 @@ def test_mpi_baseop(): + """ + This test is used to track changes in `.get` and other accessors + """ alloc0 = memref.Alloc.get(f64, 32, [100, 14, 14]) dest = Constant.from_int_and_width(1, i32) - send = mpi.ISend.get(alloc0, dest, 1) - recv = mpi.IRecv.get(dest, alloc0.memref, 1) + unwrap = mpi.UnwrapMemrefOp.get(alloc0) + tag = Constant.from_int_and_width(1, i32) + send = mpi.ISend.get(unwrap.ptr, unwrap.len, unwrap.typ, dest, tag) + wait = mpi.Wait.get(send.request, ignore_status=False) + recv = mpi.IRecv.get(unwrap.ptr, unwrap.len, unwrap.typ, dest, tag) test_res = mpi.Test.get(recv.request) - code2 = mpi.Wait.get(recv.request) + source = mpi.GetStatusField.get(wait.status, + mpi.StatusTypeField.MPI_SOURCE) - assert send.operands[0] is alloc0.results[0] - assert send.operands[1] is dest.results[0] - assert recv.operands[0] is send.operands[1] - assert code2.operands[0] is recv.results[0] - assert test_res.operands[0] is recv.results[0] + assert unwrap.ref == alloc0.memref + assert send.buffer == unwrap.ptr + assert send.count == unwrap.len + assert send.datatype == unwrap.typ + assert send.dest == dest.result + assert send.tag == tag.result + assert wait.request == send.request + assert recv.buffer == unwrap.ptr + assert recv.count == unwrap.len + assert recv.datatype == unwrap.typ + assert recv.source == dest.result + assert recv.tag == tag.result + assert test_res.request == recv.request + assert source.status == wait.status + assert source.field.data == mpi.StatusTypeField.MPI_SOURCE.value diff --git a/tests/dialects/test_mpi_lowering.py b/tests/dialects/test_mpi_lowering.py index 406b24b200..06edbddba3 100644 --- a/tests/dialects/test_mpi_lowering.py +++ b/tests/dialects/test_mpi_lowering.py @@ -153,12 +153,11 @@ def test_lower_mpi_send(): def test_lower_mpi_isend(): - buff, dest = CreateTestValsOp.get( - mpi.MemRefType.from_element_type_and_shape(builtin.f64, [32, 32, 32]), - i32).results + ptr, count, dtype, dest, tag = CreateTestValsOp.get( + llvm.LLVMPointerType.opaque(), i32, mpi.DataType(), i32, i32).results ops, result = lower_mpi.LowerMpiISend(info).lower( - mpi.ISend.get(buff, dest, 1)) + mpi.ISend.get(ptr, count, dtype, dest, tag)) """ Check for function with signature like: int MPI_Isend(const void *buf, int count, MPI_Datatype datatype, int dest, @@ -219,16 +218,15 @@ def test_lower_mpi_recv_with_status(): def test_lower_mpi_irecv(): - buff, source = CreateTestValsOp.get( - mpi.MemRefType.from_element_type_and_shape(builtin.f64, [32, 32, 32]), - i32).results + ptr, count, dtype, source, tag = CreateTestValsOp.get( + llvm.LLVMPointerType.opaque(), i32, mpi.DataType(), i32, i32).results """ int MPI_Irecv(void *buf, int count, MPI_Datatype datatype, int source, int tag, MPI_Comm comm, MPI_Request *request) """ ops, result = lower_mpi.LowerMpiIRecv(info).lower( - mpi.IRecv.get(source, buff, tag=3)) + mpi.IRecv.get(ptr, count, dtype, source, tag)) assert len(result) == 1 diff --git a/tests/filecheck/mlir-conversion/with-mlir/dialects/mpi/mpi-hello-world-async.mlir b/tests/filecheck/mlir-conversion/with-mlir/dialects/mpi/mpi-hello-world-async.mlir new file mode 100644 index 0000000000..14145c7ae2 --- /dev/null +++ b/tests/filecheck/mlir-conversion/with-mlir/dialects/mpi/mpi-hello-world-async.mlir @@ -0,0 +1,44 @@ +// RUN: xdsl-opt %s -t mlir -p lower-mpi | mlir-opt --convert-func-to-llvm --convert-memref-to-llvm --reconcile-unrealized-casts | filecheck %s + +"builtin.module"() ({ + "func.func"() ({ + "mpi.init"() : () -> () + %0 = "mpi.comm.rank"() : () -> i32 + %1 = "arith.constant"() {"value" = 0 : i32} : () -> i32 + %2 = "arith.cmpi"(%0, %1) {"predicate" = 0 : i64} : (i32, i32) -> i1 + %ref = "memref.alloc"() {"alignment" = 32 : i64, "operand_segment_sizes" = array} : () -> memref<100x14x14xf64> + %tag = "arith.constant"() {"value" = 1 : i32} : () -> i32 + %buff, %count, %dtype = "mpi.unwrap_memref"(%ref) : (memref<100x14x14xf64>) -> (!llvm.ptr, i32, !mpi.datatype) + "scf.if"(%2) ({ + %dest = "arith.constant"() {"value" = 1 : i32} : () -> i32 + %req = "mpi.isend"(%buff, %count, %dtype, %dest, %tag) : (!llvm.ptr, i32, !mpi.datatype, i32, i32) -> !mpi.request + "mpi.wait"(%req) : (!mpi.request) -> () + "scf.yield"() : () -> () + }, { + %source = "arith.constant"() {"value" = 0 : i32} : () -> i32 + %req = "mpi.irecv"(%buff, %count, %dtype, %source, %tag) {"tag" = 1 : i32} : (!llvm.ptr, i32, !mpi.datatype, i32, i32) -> !mpi.request + %status = "mpi.wait"(%req) : (!mpi.request) -> !mpi.status + "scf.yield"() : () -> () + }) : (i1) -> () + "mpi.finalize"() : () -> () + "func.return"() : () -> () + }) {"sym_name" = "main", "function_type" = () -> (), "sym_visibility" = "private"} : () -> () +}) : () -> () + +// we don't really care about the whole structure, we just want to make sure mlir-opt can lower all this down to llvm + +// CHECK: llvm.call @MPI_Init({{%\d+}}, {{%\d+}}) : (!llvm.ptr, !llvm.ptr) -> i32 +// CHECK: llvm.call @MPI_Comm_rank({{%\d+}}, {{%\d+}}) : (i32, !llvm.ptr) -> i32 +// CHECK: llvm.call @MPI_Isend({{%\d+}}, {{%\d+}}, {{%\d+}}, {{%\d+}}, {{%\d+}}, {{%\d+}}, {{%\d+}}) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32 +// CHECK: llvm.call @MPI_Wait({{%\d+}}, {{%\d+}}) : (!llvm.ptr, !llvm.ptr) -> i32 +// CHECK: llvm.call @MPI_Irecv({{%\d+}}, {{%\d+}}, {{%\d+}}, {{%\d+}}, {{%\d+}}, {{%\d+}}, {{%\d+}}) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32 +// CHECK: llvm.call @MPI_Wait({{%\d+}}, {{%\d+}}) : (!llvm.ptr, !llvm.ptr) -> i32 + +// also check that external funcs were declared correctly: + +// CHECK: llvm.func @MPI_Init(!llvm.ptr, !llvm.ptr) -> i32 attributes {sym_visibility = "private"} +// CHECK: llvm.func @MPI_Comm_rank(i32, !llvm.ptr) -> i32 attributes {sym_visibility = "private"} +// CHECK: llvm.func @MPI_Isend(!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32 attributes {sym_visibility = "private"} +// CHECK: llvm.func @MPI_Wait(!llvm.ptr, !llvm.ptr) -> i32 attributes {sym_visibility = "private"} +// CHECK: llvm.func @MPI_Irecv(!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32 attributes {sym_visibility = "private"} +// CHECK: llvm.func @MPI_Finalize() -> i32 attributes {sym_visibility = "private"} diff --git a/xdsl/dialects/mpi.py b/xdsl/dialects/mpi.py index 4acbc5b4ff..f2fc8962cd 100644 --- a/xdsl/dialects/mpi.py +++ b/xdsl/dialects/mpi.py @@ -5,11 +5,12 @@ from typing import cast from xdsl.dialects import llvm -from xdsl.dialects.builtin import (IntegerType, Signedness, IntegerAttr, - StringAttr, AnyFloat, i32) +from xdsl.dialects.builtin import (IntegerType, Signedness, StringAttr, + AnyFloat, i32) from xdsl.dialects.memref import MemRefType -from xdsl.ir import Operation, Attribute, SSAValue, OpResult, ParametrizedAttribute, Dialect, MLIRType -from xdsl.irdl import (Operand, Annotated, irdl_op_definition, AnyAttr, +from xdsl.ir import (Operation, Attribute, SSAValue, OpResult, + ParametrizedAttribute, Dialect, MLIRType) +from xdsl.irdl import (Operand, Annotated, irdl_op_definition, irdl_attr_definition, OpAttr, OptOpResult) t_bool: IntegerType = IntegerType(1, Signedness.SIGNLESS) @@ -61,15 +62,6 @@ class MPIBaseOp(Operation, ABC): pass -def _build_attr_dict_with_optional_tag( - tag: int | None = None) -> dict[str, Attribute]: - """ - Helper function for building attribute dicts that have an optional `tag` entry - """ - - return {} if tag is None else {'tag': IntegerAttr.from_params(tag, i32)} - - @irdl_op_definition class ISend(MPIBaseOp): """ @@ -90,26 +82,33 @@ class ISend(MPIBaseOp): ## Our Abstraction: - - We Summarize buf, count and datatype by using memrefs - - We assume that tag is compile-time constant - - We omit the possibility of using multiple communicators + - We omit the possibility of using multiple communicators, defaulting + to MPI_COMM_WORLD """ name = 'mpi.isend' - buffer: Annotated[Operand, MemRefType[AnyNumericType]] + buffer: Annotated[Operand, Attribute] + count: Annotated[Operand, i32] + datatype: Annotated[Operand, DataType] dest: Annotated[Operand, i32] - - tag: OpAttr[IntegerAttr[Annotated[IntegerType, i32]]] + tag: Annotated[Operand, i32] request: Annotated[OpResult, RequestType] @classmethod - def get(cls, buff: SSAValue | Operation, dest: SSAValue | Operation, - tag: int | None): - return cls.build(operands=[buff, dest], - attributes=_build_attr_dict_with_optional_tag(tag), - result_types=[RequestType()]) + def get( + cls, + buffer: SSAValue | Operation, + count: SSAValue | Operation, + datatype: SSAValue | Operation, + dest: SSAValue | Operation, + tag: SSAValue | Operation, + ): + return cls.build( + operands=[buffer, count, datatype, dest, tag], + result_types=[RequestType()], + ) @irdl_op_definition @@ -132,14 +131,13 @@ class Send(MPIBaseOp): ## Our Abstraction: - - We Summarize buf, count and datatype by using memrefs - - We assume that tag is compile-time constant - - We omit the possibility of using multiple communicators + - We omit the possibility of using multiple communicators, defaulting + to MPI_COMM_WORLD """ name = 'mpi.send' - buffer: Annotated[Operand, AnyAttr()] + buffer: Annotated[Operand, Attribute] count: Annotated[Operand, i32] datatype: Annotated[Operand, DataType] dest: Annotated[Operand, i32] @@ -174,27 +172,33 @@ class IRecv(MPIBaseOp): ## Our Abstractions: - - We bundle buf, count and datatype into the type definition and use `memref` - - We assume tag is compile-time known - - We omit the possibility of using multiple communicators + - We omit the possibility of using multiple communicators, defaulting + to MPI_COMM_WORLD """ name = "mpi.irecv" + buffer: Annotated[Operand, Attribute] + count: Annotated[Operand, i32] + datatype: Annotated[Operand, DataType] source: Annotated[Operand, i32] - buffer: Annotated[Operand, MemRefType[AnyNumericType]] - tag: OpAttr[IntegerAttr[Annotated[IntegerType, i32]]] + tag: Annotated[Operand, i32] request: Annotated[OpResult, RequestType] @classmethod - def get(cls, - source: SSAValue | Operation, - buffer: SSAValue | Operation, - tag: int | None = None): - return cls.build(operands=[source, buffer], - attributes=_build_attr_dict_with_optional_tag(tag), - result_types=[RequestType()]) + def get( + cls, + buffer: SSAValue | Operation, + count: SSAValue | Operation, + datatype: SSAValue | Operation, + source: SSAValue | Operation, + tag: SSAValue | Operation, + ): + return cls.build( + operands=[buffer, count, datatype, source, tag], + result_types=[RequestType()], + ) @irdl_op_definition @@ -218,15 +222,13 @@ class Recv(MPIBaseOp): ## Our Abstractions: - - We bundle buf, count and datatype into the type definition and use `memref` - - We assume this type information is compile-time known - - We assume tag is compile-time known - - We omit the possibility of using multiple communicators + - We omit the possibility of using multiple communicators, defaulting + to MPI_COMM_WORLD """ name = "mpi.recv" - buffer: Annotated[Operand, AnyAttr()] + buffer: Annotated[Operand, Attribute] count: Annotated[Operand, i32] datatype: Annotated[Operand, DataType] source: Annotated[Operand, i32] @@ -292,11 +294,11 @@ class Wait(MPIBaseOp): name = "mpi.wait" request: Annotated[Operand, RequestType] - status: Annotated[OptOpResult, i32] + status: Annotated[OptOpResult, StatusType] @classmethod def get(cls, request: Operand, ignore_status: bool = True): - result_types: list[list[Attribute]] = [[i32]] + result_types: list[list[Attribute]] = [[StatusType()]] if ignore_status: result_types = [[]] @@ -394,7 +396,7 @@ class UnwrapMemrefOp(MPIBaseOp): typ: Annotated[OpResult, DataType] @staticmethod - def get(ref: SSAValue | Operation): + def get(ref: SSAValue | Operation) -> UnwrapMemrefOp: ssa_val = SSAValue.get(ref) assert isinstance(ssa_val.typ, MemRefType) elem_typ = cast(MemRefType[AnyNumericType], ssa_val.typ).element_type @@ -437,6 +439,7 @@ def get(typ: Attribute): Test, Recv, Send, + Wait, GetStatusField, Init, Finalize, diff --git a/xdsl/transforms/lower_mpi.py b/xdsl/transforms/lower_mpi.py index 142c039ee4..74b1d5b8ac 100644 --- a/xdsl/transforms/lower_mpi.py +++ b/xdsl/transforms/lower_mpi.py @@ -320,24 +320,15 @@ def lower(self, int MPI_Isend(const void *buf, int count, MPI_Datatype datatype, int dest, int tag, MPI_Comm comm, MPI_Request *request) """ - count_ops, count_ssa_val = self._emit_memref_counts(op.buffer) - - assert isinstance(op.buffer.typ, memref.MemRefType) - memref_elm_typ = cast(memref.MemRefType[Attribute], - op.buffer.typ).element_type return [ - *count_ops, comm_global := arith.Constant.from_int_and_width(self.info.MPI_COMM_WORLD, i32), - datatype := self._emit_mpi_type_load(memref_elm_typ), - tag := arith.Constant.from_int_and_width(op.tag.value.data, i32), lit1 := arith.Constant.from_int_and_width(1, builtin.i64), request := llvm.AllocaOp.get( lit1, builtin.IntegerType(8 * self.info.MPI_Request_size)), - *(ptr := self._memref_get_llvm_ptr(op.buffer))[0], func.Call.get(self._mpi_name(op), [ - ptr[1], count_ssa_val, datatype, op.dest, tag, comm_global, + op.buffer, op.count, op.datatype, op.dest, op.tag, comm_global, request ], [i32]), ], [request.results[0]] @@ -357,25 +348,16 @@ def lower(self, int MPI_Irecv(void *buf, int count, MPI_Datatype datatype, int source, int tag, MPI_Comm comm, MPI_Request *request) """ - count_ops, count_ssa_val = self._emit_memref_counts(op.buffer) - - assert isinstance(op.buffer.typ, memref.MemRefType) - memref_elm_typ = cast(memref.MemRefType[Attribute], - op.buffer.typ).element_type return [ - *count_ops, - *(ptr := self._memref_get_llvm_ptr(op.buffer))[0], - datatype := self._emit_mpi_type_load(memref_elm_typ), - tag := arith.Constant.from_int_and_width(op.tag.value.data, i32), comm_global := arith.Constant.from_int_and_width(self.info.MPI_COMM_WORLD, i32), lit1 := arith.Constant.from_int_and_width(1, builtin.i64), request := llvm.AllocaOp.get( lit1, builtin.IntegerType(8 * self.info.MPI_Request_size)), func.Call.get(self._mpi_name(op), [ - ptr[1], count_ssa_val, datatype, op.source, tag, comm_global, - request + op.buffer, op.count, op.datatype, op.source, op.tag, + comm_global, request ], [i32]), ], [request.res] From e5591a86f23ab5a051bddda729e64d07df971268 Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Mon, 13 Mar 2023 11:02:35 +0000 Subject: [PATCH 07/10] core: always print a newline in a block, make consistent with MLIR syntax (#532) --- .../dialects/gpu/all_reduce_types.mlir | 3 +- tests/filecheck/dialects/gpu/example.mlir | 12 ++- tests/filecheck/dialects/gpu/launch_args.mlir | 3 +- .../dialects/gpu/terminator_launch.mlir | 3 +- .../dialects/gpu/terminator_terminate.mlir | 3 +- tests/filecheck/frontend/programs/programs.py | 3 +- tests/filecheck/func_ops.xdsl | 3 +- .../with-mlir/dialects/gpu/example.mlir | 12 ++- .../parser-printer/escaped_characters.xdsl | 30 +++++--- tests/test_pattern_rewriter.py | 76 +++++++++++++------ tests/test_rewriter.py | 12 ++- tests/xdsl_opt/empty_program.wrong | 3 +- tests/xdsl_opt/empty_program.xdsl | 3 +- xdsl/printer.py | 43 +++++------ 14 files changed, 130 insertions(+), 79 deletions(-) diff --git a/tests/filecheck/dialects/gpu/all_reduce_types.mlir b/tests/filecheck/dialects/gpu/all_reduce_types.mlir index 49a1689df9..d0c42b38ba 100644 --- a/tests/filecheck/dialects/gpu/all_reduce_types.mlir +++ b/tests/filecheck/dialects/gpu/all_reduce_types.mlir @@ -3,7 +3,8 @@ "builtin.module"()({ "gpu.module"()({ %init = "arith.constant"() {"value" = 42 : index} : () -> index - %sum = "gpu.all_reduce"(%init) ({}) {"op" = #gpu} : (index) -> f32 + %sum = "gpu.all_reduce"(%init) ({ + }) {"op" = #gpu} : (index) -> f32 "gpu.module_end"() : () -> () }) {"sym_name" = "gpu"} : () -> () }) {} : () -> () diff --git a/tests/filecheck/dialects/gpu/example.mlir b/tests/filecheck/dialects/gpu/example.mlir index 09eacb042a..966e38b6ea 100644 --- a/tests/filecheck/dialects/gpu/example.mlir +++ b/tests/filecheck/dialects/gpu/example.mlir @@ -39,7 +39,8 @@ %subgroupid = "gpu.subgroup_id"() : () -> index %subgroupsize = "gpu.subgroup_size"() : () -> index - %globalprodx = "gpu.all_reduce"(%globalidx) ({}) {"op" = #gpu} : (index) -> index + %globalprodx = "gpu.all_reduce"(%globalidx) ({ + }) {"op" = #gpu} : (index) -> index %globalsumy = "gpu.all_reduce"(%globalidy) ({ ^bb(%lhs : index, %rhs : index): @@ -52,7 +53,8 @@ %tx : index, %ty : index, %tz : index, %num_bx : index, %num_by : index, %num_bz : index, %num_tx : index, %num_ty : index, %num_tz : index): - %sum = "gpu.all_reduce"(%tx) ({}) {"op" = #gpu} : (index) -> index + %sum = "gpu.all_reduce"(%tx) ({ + }) {"op" = #gpu} : (index) -> index %final = "arith.muli"(%sum, %one) : (index, index) -> index "gpu.terminator"() : () -> () }) {"operand_segment_sizes" = array} : (index, index, index, index, index, index) -> () @@ -102,7 +104,8 @@ // CHECK-NEXT: %{{.*}} = "gpu.subgroup_id"() : () -> index // CHECK-NEXT: %{{.*}} = "gpu.subgroup_size"() : () -> index -// CHECK-NEXT: %{{.*}} = "gpu.all_reduce"(%{{.*}}) ({}) {"op" = #gpu} : (index) -> index +// CHECK-NEXT: %{{.*}} = "gpu.all_reduce"(%{{.*}}) ({ +// CHECK-NEXT: }) {"op" = #gpu} : (index) -> index // CHECK-NEXT: %{{.*}} = "gpu.all_reduce"(%{{.*}}) ({ // CHECK-NEXT: ^{{.*}}(%{{.*}} : index, %{{.*}} : index): @@ -115,7 +118,8 @@ // CHECK-SAME: %{{.*}} : index, %{{.*}} : index, %{{.*}} : index, // CHECK-SAME: %{{.*}} : index, %{{.*}} : index, %{{.*}} : index, // CHECK-SAME: %{{.*}} : index, %{{.*}} : index, %{{.*}} : index): -// CHECK-NEXT: %{{.*}} = "gpu.all_reduce"(%{{.*}}) ({}) {"op" = #gpu} : (index) -> index +// CHECK-NEXT: %{{.*}} = "gpu.all_reduce"(%{{.*}}) ({ +// CHECK-NEXT: }) {"op" = #gpu} : (index) -> index // CHECK-NEXT: %{{.*}} = "arith.muli"(%{{.*}}, %{{.*}}) : (index, index) -> index // CHECK-NEXT: "gpu.terminator"() : () -> () // CHECK-NEXT: }) {"operand_segment_sizes" = array} : (index, index, index, index, index, index) -> () diff --git a/tests/filecheck/dialects/gpu/launch_args.mlir b/tests/filecheck/dialects/gpu/launch_args.mlir index 864736ebd5..d84dcbaa6b 100644 --- a/tests/filecheck/dialects/gpu/launch_args.mlir +++ b/tests/filecheck/dialects/gpu/launch_args.mlir @@ -8,7 +8,8 @@ "gpu.launch"(%one, %one, %n, %one, %one, %one) ({ ^bb0(%bx : index, %by : index, %bz : index, %tx : index, %ty : index, %tz : index): - %sum = "gpu.all_reduce"(%tx) ({}) {"op" = #gpu} : (index) -> index + %sum = "gpu.all_reduce"(%tx) ({ + }) {"op" = #gpu} : (index) -> index %final = "arith.muli"(%sum, %one) : (index, index) -> index "gpu.terminator"() : () -> () }) {"operand_segment_sizes" = array} : (index, index, index, index, index, index) -> () diff --git a/tests/filecheck/dialects/gpu/terminator_launch.mlir b/tests/filecheck/dialects/gpu/terminator_launch.mlir index 4987f0864c..63c17262b6 100644 --- a/tests/filecheck/dialects/gpu/terminator_launch.mlir +++ b/tests/filecheck/dialects/gpu/terminator_launch.mlir @@ -10,7 +10,8 @@ %tx : index, %ty : index, %tz : index, %num_bx : index, %num_by : index, %num_bz : index, %num_tx : index, %num_ty : index, %num_tz : index): - %sum = "gpu.all_reduce"(%tx) ({}) {"op" = #gpu} : (index) -> index + %sum = "gpu.all_reduce"(%tx) ({ + }) {"op" = #gpu} : (index) -> index %final = "arith.muli"(%sum, %one) : (index, index) -> index }) {"operand_segment_sizes" = array} : (index, index, index, index, index, index) -> () diff --git a/tests/filecheck/dialects/gpu/terminator_terminate.mlir b/tests/filecheck/dialects/gpu/terminator_terminate.mlir index 70b994158e..ee0b846690 100644 --- a/tests/filecheck/dialects/gpu/terminator_terminate.mlir +++ b/tests/filecheck/dialects/gpu/terminator_terminate.mlir @@ -10,7 +10,8 @@ %tx : index, %ty : index, %tz : index, %num_bx : index, %num_by : index, %num_bz : index, %num_tx : index, %num_ty : index, %num_tz : index): - %sum = "gpu.all_reduce"(%tx) ({}) {"op" = #gpu} : (index) -> index + %sum = "gpu.all_reduce"(%tx) ({ + }) {"op" = #gpu} : (index) -> index "gpu.terminator"() : () -> () %final = "arith.muli"(%sum, %one) : (index, index) -> index }) {"operand_segment_sizes" = array} : (index, index, index, index, index, index) -> () diff --git a/tests/filecheck/frontend/programs/programs.py b/tests/filecheck/frontend/programs/programs.py index 708c367f1b..bbd4c0426d 100644 --- a/tests/filecheck/frontend/programs/programs.py +++ b/tests/filecheck/frontend/programs/programs.py @@ -5,7 +5,8 @@ p = FrontendProgram() with CodeContext(p): - # CHECK: builtin.module() {} + # CHECK: builtin.module() { + # CHECK-NEXT: } pass p.compile(desymref=False) diff --git a/tests/filecheck/func_ops.xdsl b/tests/filecheck/func_ops.xdsl index 54ec9f4bc3..a20d6b4662 100644 --- a/tests/filecheck/func_ops.xdsl +++ b/tests/filecheck/func_ops.xdsl @@ -34,6 +34,7 @@ builtin.module() { func.func() ["sym_name" = "external_fn", "function_type" = !fun<[!i32], [!i32]>, "sym_visibility" = "private"] { } - // CHECK: func.func() ["sym_name" = "external_fn", "function_type" = !fun<[!i32], [!i32]>, "sym_visibility" = "private"] {} + // CHECK: func.func() ["sym_name" = "external_fn", "function_type" = !fun<[!i32], [!i32]>, "sym_visibility" = "private"] { + // CHECK-NEXT: } } diff --git a/tests/filecheck/mlir-conversion/with-mlir/dialects/gpu/example.mlir b/tests/filecheck/mlir-conversion/with-mlir/dialects/gpu/example.mlir index 6814409379..8fb10ca8d7 100644 --- a/tests/filecheck/mlir-conversion/with-mlir/dialects/gpu/example.mlir +++ b/tests/filecheck/mlir-conversion/with-mlir/dialects/gpu/example.mlir @@ -39,7 +39,8 @@ %subgroupid = "gpu.subgroup_id"() : () -> index %subgroupsize = "gpu.subgroup_size"() : () -> index - %globalprodx = "gpu.all_reduce"(%globalidx) ({}) {"op" = #gpu} : (index) -> index + %globalprodx = "gpu.all_reduce"(%globalidx) ({ + }) {"op" = #gpu} : (index) -> index %globalsumy = "gpu.all_reduce"(%globalidy) ({ ^bb(%lhs : index, %rhs : index): @@ -52,7 +53,8 @@ %tx : index, %ty : index, %tz : index, %num_bx : index, %num_by : index, %num_bz : index, %num_tx : index, %num_ty : index, %num_tz : index): - %sum = "gpu.all_reduce"(%tx) ({}) {"op" = #gpu} : (index) -> index + %sum = "gpu.all_reduce"(%tx) ({ + }) {"op" = #gpu} : (index) -> index %final = "arith.muli"(%sum, %one) : (index, index) -> index "gpu.terminator"() : () -> () }) {"operand_segment_sizes" = array} : (index, index, index, index, index, index) -> () @@ -102,7 +104,8 @@ // CHECK-NEXT: %{{.*}} = "gpu.subgroup_id"() : () -> index // CHECK-NEXT: %{{.*}} = "gpu.subgroup_size"() : () -> index -// CHECK-NEXT: %{{.*}} = "gpu.all_reduce"(%{{.*}}) ({}) {"op" = #gpu} : (index) -> index +// CHECK-NEXT: %{{.*}} = "gpu.all_reduce"(%{{.*}}) ({ +// CHECK-NEXT: }) {"op" = #gpu} : (index) -> index // CHECK-NEXT: %{{.*}} = "gpu.all_reduce"(%{{.*}}) ({ // CHECK-NEXT: ^{{.*}}(%{{.*}} : index, %{{.*}} : index): @@ -115,7 +118,8 @@ // CHECK-SAME: %{{.*}} : index, %{{.*}} : index, %{{.*}} : index, // CHECK-SAME: %{{.*}} : index, %{{.*}} : index, %{{.*}} : index, // CHECK-SAME: %{{.*}} : index, %{{.*}} : index, %{{.*}} : index): -// CHECK-NEXT: %{{.*}} = "gpu.all_reduce"(%{{.*}}) ({}) {"op" = #gpu} : (index) -> index +// CHECK-NEXT: %{{.*}} = "gpu.all_reduce"(%{{.*}}) ({ +// CHECK-NEXT: }) {"op" = #gpu} : (index) -> index // CHECK-NEXT: %{{.*}} = "arith.muli"(%{{.*}}, %{{.*}}) : (index, index) -> index // CHECK-NEXT: "gpu.terminator"() : () -> () // CHECK-NEXT: }) {"operand_segment_sizes" = array} : (index, index, index, index, index, index) -> () diff --git a/tests/filecheck/parser-printer/escaped_characters.xdsl b/tests/filecheck/parser-printer/escaped_characters.xdsl index 71d81e9627..f4e5b91acd 100644 --- a/tests/filecheck/parser-printer/escaped_characters.xdsl +++ b/tests/filecheck/parser-printer/escaped_characters.xdsl @@ -1,15 +1,25 @@ // RUN: xdsl-opt %s | xdsl-opt | filecheck %s builtin.module() { - func.func() ["sym_name" = "\"", "function_type" = !fun<[], []>, "sym_visibility" = "private"] {} - func.func() ["sym_name" = "\n", "function_type" = !fun<[], []>, "sym_visibility" = "private"] {} - func.func() ["sym_name" = "\t", "function_type" = !fun<[], []>, "sym_visibility" = "private"] {} - func.func() ["sym_name" = "\\", "function_type" = !fun<[], []>, "sym_visibility" = "private"] {} - func.func() ["sym_name" = "\r", "function_type" = !fun<[], []>, "sym_visibility" = "private"] {} + func.func() ["sym_name" = "\"", "function_type" = !fun<[], []>, "sym_visibility" = "private"] { + } + func.func() ["sym_name" = "\n", "function_type" = !fun<[], []>, "sym_visibility" = "private"] { + } + func.func() ["sym_name" = "\t", "function_type" = !fun<[], []>, "sym_visibility" = "private"] { + } + func.func() ["sym_name" = "\\", "function_type" = !fun<[], []>, "sym_visibility" = "private"] { + } + func.func() ["sym_name" = "\r", "function_type" = !fun<[], []>, "sym_visibility" = "private"] { + } } -// CHECK: func.func() ["sym_name" = "\"", "function_type" = !fun<[], []>, "sym_visibility" = "private"] {} -// CHECK-NEXT: func.func() ["sym_name" = "\n", "function_type" = !fun<[], []>, "sym_visibility" = "private"] {} -// CHECK-NEXT: func.func() ["sym_name" = "\t", "function_type" = !fun<[], []>, "sym_visibility" = "private"] {} -// CHECK-NEXT: func.func() ["sym_name" = "\\", "function_type" = !fun<[], []>, "sym_visibility" = "private"] {} -// CHECK-NEXT: func.func() ["sym_name" = "\r", "function_type" = !fun<[], []>, "sym_visibility" = "private"] {} +// CHECK: func.func() ["sym_name" = "\"", "function_type" = !fun<[], []>, "sym_visibility" = "private"] { +// CHECK-NEXT: } +// CHECK-NEXT: func.func() ["sym_name" = "\n", "function_type" = !fun<[], []>, "sym_visibility" = "private"] { +// CHECK-NEXT: } +// CHECK-NEXT: func.func() ["sym_name" = "\t", "function_type" = !fun<[], []>, "sym_visibility" = "private"] { +// CHECK-NEXT: } +// CHECK-NEXT: func.func() ["sym_name" = "\\", "function_type" = !fun<[], []>, "sym_visibility" = "private"] { +// CHECK-NEXT: } +// CHECK-NEXT: func.func() ["sym_name" = "\r", "function_type" = !fun<[], []>, "sym_visibility" = "private"] { +// CHECK-NEXT: } diff --git a/tests/test_pattern_rewriter.py b/tests/test_pattern_rewriter.py index f189436570..e39222cbc8 100644 --- a/tests/test_pattern_rewriter.py +++ b/tests/test_pattern_rewriter.py @@ -404,7 +404,8 @@ def test_operation_deletion(): }""" expected = \ -"""builtin.module() {}""" +"""builtin.module() { +}""" @op_type_rewrite_pattern def match_and_rewrite(op: Constant, rewriter: PatternRewriter): @@ -428,7 +429,8 @@ def test_operation_deletion_reversed(): }""" expected = \ -"""builtin.module() {}""" +"""builtin.module() { +}""" def match_and_rewrite(op: Operation, rewriter: PatternRewriter): if not isinstance(op, ModuleOp): @@ -478,7 +480,8 @@ def test_delete_inner_op(): }""" expected = \ -"""builtin.module() {}""" +"""builtin.module() { +}""" @op_type_rewrite_pattern def match_and_rewrite(op: ModuleOp, rewriter: PatternRewriter): @@ -529,7 +532,8 @@ def test_block_argument_type_change(): scf.if(%0 : !i1) { ^0(%1 : !i64): %2 : !i32 = arith.addi(%1 : !i64, %1 : !i64) - } {} + } { + } }""" @op_type_rewrite_pattern @@ -557,7 +561,9 @@ def test_block_argument_erasure(): expected = \ """builtin.module() { %0 : !i1 = arith.constant() ["value" = true] - scf.if(%0 : !i1) {} {} + scf.if(%0 : !i1) { + } { + } }""" @op_type_rewrite_pattern @@ -576,7 +582,9 @@ def test_block_argument_insertion(): prog = \ """builtin.module() { %0 : !i1 = arith.constant() ["value" = true] -scf.if(%0 : !i1) {} {} +scf.if(%0 : !i1) { +} { +} }""" expected = \ @@ -584,7 +592,8 @@ def test_block_argument_insertion(): %0 : !i1 = arith.constant() ["value" = true] scf.if(%0 : !i1) { ^0(%1 : !i32): - } {} + } { + } }""" @op_type_rewrite_pattern @@ -606,8 +615,10 @@ def test_inline_block_at_pos(): scf.if(%0 : !i1) { scf.if(%0 : !i1) { %1 : !i32 = arith.constant() ["value" = 2 : !i32] - } {} -} {} + } { + } +} { +} }""" expected = \ @@ -615,8 +626,11 @@ def test_inline_block_at_pos(): %0 : !i1 = arith.constant() ["value" = true] scf.if(%0 : !i1) { %1 : !i32 = arith.constant() ["value" = 2 : !i32] - scf.if(%0 : !i1) {} {} - } {} + scf.if(%0 : !i1) { + } { + } + } { + } }""" @op_type_rewrite_pattern @@ -648,7 +662,9 @@ def test_inline_block_before_matched_op(): """builtin.module() { %0 : !i1 = arith.constant() ["value" = true] %1 : !i32 = arith.constant() ["value" = 2 : !i32] - scf.if(%0 : !i1) {} {} + scf.if(%0 : !i1) { + } { + } }""" @op_type_rewrite_pattern @@ -671,7 +687,8 @@ def test_inline_block_before(): scf.if(%0 : !i1) { %1 : !i32 = arith.constant() ["value" = 2 : !i32] } {} -} {} +} { +} }""" expected = \ @@ -679,8 +696,11 @@ def test_inline_block_before(): %0 : !i1 = arith.constant() ["value" = true] scf.if(%0 : !i1) { %1 : !i32 = arith.constant() ["value" = 2 : !i32] - scf.if(%0 : !i1) {} {} - } {} + scf.if(%0 : !i1) { + } { + } + } { + } }""" @op_type_rewrite_pattern @@ -706,14 +726,17 @@ def test_inline_block_at_before_when_op_is_matched_op(): %0 : !i1 = arith.constant() ["value" = true] scf.if(%0 : !i1) { %1 : !i32 = arith.constant() ["value" = 2 : !i32] -} {} +} { +} }""" expected = \ """builtin.module() { %0 : !i1 = arith.constant() ["value" = true] %1 : !i32 = arith.constant() ["value" = 2 : !i32] - scf.if(%0 : !i1) {} {} + scf.if(%0 : !i1) { + } { + } }""" @op_type_rewrite_pattern @@ -735,17 +758,22 @@ def test_inline_block_after(): scf.if(%0 : !i1) { scf.if(%0 : !i1) { %1 : !i32 = arith.constant() ["value" = 2 : !i32] - } {} - } {} + } { + } + } { + } }""" expected = \ """builtin.module() { %0 : !i1 = arith.constant() ["value" = true] scf.if(%0 : !i1) { - scf.if(%0 : !i1) {} {} + scf.if(%0 : !i1) { + } { + } %1 : !i32 = arith.constant() ["value" = 2 : !i32] - } {} + } { + } }""" @op_type_rewrite_pattern @@ -777,10 +805,12 @@ def test_move_region_contents_to_new_regions(): expected = \ """builtin.module() { %0 : !i1 = arith.constant() ["value" = true] - scf.if(%0 : !i1) {} + scf.if(%0 : !i1) { + } scf.if(%0 : !i1) { %1 : !i32 = arith.constant() ["value" = 2 : !i32] - } {} + } { + } }""" @op_type_rewrite_pattern diff --git a/tests/test_rewriter.py b/tests/test_rewriter.py index a490fccbe2..a374b0e361 100644 --- a/tests/test_rewriter.py +++ b/tests/test_rewriter.py @@ -36,7 +36,8 @@ def test_operation_deletion(): }""" expected = \ -"""builtin.module() {}""" +"""builtin.module() { +}""" def transformation(module: ModuleOp, rewriter: Rewriter) -> None: constant_op = module.ops[0] @@ -129,7 +130,8 @@ def test_inline_block_at_pos(): """builtin.module() { %0 : !i1 = arith.constant() ["value" = true] %1 : !i32 = arith.constant() ["value" = 2 : !i32] - scf.if(%0 : !i1) {} + scf.if(%0 : !i1) { + } }""" def transformation(module: ModuleOp, rewriter: Rewriter) -> None: @@ -156,7 +158,8 @@ def test_inline_block_before(): """builtin.module() { %0 : !i1 = arith.constant() ["value" = true] %1 : !i32 = arith.constant() ["value" = 2 : !i32] - scf.if(%0 : !i1) {} + scf.if(%0 : !i1) { + } }""" def transformation(module: ModuleOp, rewriter: Rewriter) -> None: @@ -182,7 +185,8 @@ def test_inline_block_after(): """builtin.module() { %0 : !i1 = arith.constant() ["value" = true] %1 : !i32 = arith.constant() ["value" = 2 : !i32] - scf.if(%0 : !i1) {} + scf.if(%0 : !i1) { + } }""" def transformation(module: ModuleOp, rewriter: Rewriter) -> None: diff --git a/tests/xdsl_opt/empty_program.wrong b/tests/xdsl_opt/empty_program.wrong index e37af90b8f..355b53d3ce 100644 --- a/tests/xdsl_opt/empty_program.wrong +++ b/tests/xdsl_opt/empty_program.wrong @@ -1 +1,2 @@ -builtin.module() {} +builtin.module() { +} diff --git a/tests/xdsl_opt/empty_program.xdsl b/tests/xdsl_opt/empty_program.xdsl index e37af90b8f..355b53d3ce 100644 --- a/tests/xdsl_opt/empty_program.xdsl +++ b/tests/xdsl_opt/empty_program.xdsl @@ -1 +1,2 @@ -builtin.module() {} +builtin.module() { +} diff --git a/xdsl/printer.py b/xdsl/printer.py index 5ee491352f..a04177ec32 100644 --- a/xdsl/printer.py +++ b/xdsl/printer.py @@ -59,6 +59,7 @@ def print(self, *argv: Any) -> None: continue if isinstance(arg, Block): self.print_block(arg) + self._print_new_line() continue if isinstance(arg, Operation): self.print_op(arg) @@ -211,35 +212,32 @@ def _print_operand(self, operand: SSAValue) -> None: self.print(" : ") self.print_attribute(operand.typ) - def _print_ops(self, ops: List[Operation]) -> None: - self._indent += 1 - for op in ops: - self._print_new_line() - self.print_op(op) - self._indent -= 1 - if len(ops) > 0: - self._print_new_line() - def print_block_name(self, block: Block) -> None: self.print("^") if block not in self._block_names: self._block_names[block] = self._get_new_valid_block_id() self.print(self._block_names[block]) - def print_block(self, block: Block) -> None: + def print_block(self, block: Block, print_block_name: bool = True) -> None: if not isinstance(block, Block): raise TypeError('Expected a Block; got %s' % type(block).__name__) - self.print_block_name(block) - if len(block.args) > 0: + print_block_args = len(block.args) > 0 + if print_block_args or print_block_name: + self._print_new_line() + self.print_block_name(block) + if print_block_args: self.print("(") self.print_list(block.args, self._print_block_arg) self.print(")") - self.print(":") - if len(block.ops) > 0: - self._print_ops(block.ops) - else: + if print_block_args or print_block_name: + self.print(":") + + self._indent += 1 + for op in block.ops: self._print_new_line() + self.print_op(op) + self._indent -= 1 def _print_block_arg(self, arg: BlockArgument) -> None: self.print("%") @@ -252,20 +250,13 @@ def print_region(self, region: Region) -> None: if not isinstance(region, Region): raise TypeError('Expected a Region; got %s' % type(region).__name__) - if len(region.blocks) == 0: - self.print("{}") - return - if len(region.blocks) == 1 and len(region.blocks[0].args) == 0: - self.print("{") - self._print_ops(region.blocks[0].ops) - self.print("}") - return + print_block_name = len(region.blocks) != 1 self.print("{") - self._print_new_line() for block in region.blocks: - self.print_block(block) + self.print_block(block, print_block_name=print_block_name) + self._print_new_line() self.print("}") def print_regions(self, regions: List[Region]) -> None: From 020250601a13f1c0c1250a8afebaadf8c8be1086 Mon Sep 17 00:00:00 2001 From: Nick Brown Date: Mon, 13 Mar 2023 12:08:29 +0000 Subject: [PATCH 08/10] dialects: Initial math dialect (#536) * Initial math dialect with sqrt and pow operations * Fix for incorrect return types * Updates to fix typing for Pyright * Issuing SSAValue.get on RHS too across operations to ensure that it is consistent --------- Co-authored-by: Nick Brown --- xdsl/dialects/experimental/math.py | 173 +++++++++++++++++++++++++++++ 1 file changed, 173 insertions(+) create mode 100644 xdsl/dialects/experimental/math.py diff --git a/xdsl/dialects/experimental/math.py b/xdsl/dialects/experimental/math.py new file mode 100644 index 0000000000..348d465631 --- /dev/null +++ b/xdsl/dialects/experimental/math.py @@ -0,0 +1,173 @@ +from __future__ import annotations + +from typing import Annotated, Union + +from xdsl.dialects.builtin import IntegerType, AnyFloat, Attribute +from xdsl.ir import Operation, SSAValue, OpResult +from xdsl.irdl import irdl_op_definition, OptOpAttr, Operand +from xdsl.dialects.arith import FastMathFlagsAttr + + +@irdl_op_definition +class FPowIOp(Operation): + """ + Syntax: + operation ::= ssa-id `=` `math.fpowi` ssa-use `,` ssa-use `:` type + + The fpowi operation takes a `base` operand of floating point type + (i.e. scalar, tensor or vector) and a `power` operand of integer type + (also scalar, tensor or vector) and returns one result of the same type + as `base`. The result is `base` raised to the power of `power`. + The operation is elementwise for non-scalars, e.g.: + + %v = math.fpowi %base, %power : vector<2xf32>, vector<2xi32 + + The result is a vector of: + + [, ] + + Example: + + // Scalar exponentiation. + %a = math.fpowi %base, %power : f64, i32 + """ + name: str = "math.fpowi" + fastmath: OptOpAttr[FastMathFlagsAttr] + lhs: Annotated[Operand, AnyFloat] + rhs: Annotated[Operand, IntegerType] + result: Annotated[OpResult, AnyFloat] + + @staticmethod + def get(lhs: Union[Operation, SSAValue], + rhs: Union[Operation, SSAValue], + fastmath: FastMathFlagsAttr | None = None) -> FPowIOp: + attributes: dict[str, Attribute] = {} + if fastmath is not None: + attributes["fastmath"] = fastmath + + lhs = SSAValue.get(lhs) + rhs = SSAValue.get(rhs) + return FPowIOp.build(attributes=attributes, + operands=[lhs, rhs], + result_types=[lhs.typ]) + + +@irdl_op_definition +class IPowIOp(Operation): + """ + Syntax: + operation ::= ssa-id `=` `math.ipowi` ssa-use `,` ssa-use `:` type + + The ipowi operation takes two operands of integer type (i.e., scalar, + tensor or vector) and returns one result of the same type. Operands + must have the same type. + + Example: + // Scalar signed integer exponentiation. + %a = math.ipowi %b, %c : i32 + """ + name: str = "math.ipowi" + lhs: Annotated[Operand, IntegerType] + rhs: Annotated[Operand, IntegerType] + result: Annotated[OpResult, IntegerType] + + @staticmethod + def get(lhs: Union[Operation, SSAValue], rhs: Union[Operation, + SSAValue]) -> IPowIOp: + lhs = SSAValue.get(lhs) + rhs = SSAValue.get(rhs) + return IPowIOp.build(operands=[lhs, rhs], result_types=[lhs.typ]) + + +@irdl_op_definition +class PowFOp(Operation): + """ + Syntax: + operation ::= ssa-id `=` `math.powf` ssa-use `,` ssa-use `:` type + + The powf operation takes two operands of floating point type (i.e., + scalar, tensor or vector) and returns one result of the same type. Operands + must have the same type. + + Example: + + // Scalar exponentiation. + %a = math.powf %b, %c : f64 + """ + name: str = "math.powf" + fastmath: OptOpAttr[FastMathFlagsAttr] + lhs: Annotated[Operand, AnyFloat] + rhs: Annotated[Operand, AnyFloat] + result: Annotated[OpResult, AnyFloat] + + @staticmethod + def get(lhs: Union[Operation, SSAValue], + rhs: Union[Operation, SSAValue], + fastmath: FastMathFlagsAttr | None = None) -> PowFOp: + attributes: dict[str, Attribute] = {} + if fastmath is not None: + attributes["fastmath"] = fastmath + + lhs = SSAValue.get(lhs) + rhs = SSAValue.get(rhs) + return PowFOp.build(attributes=attributes, + operands=[lhs, rhs], + result_types=[lhs.typ]) + + +@irdl_op_definition +class RsqrtOp(Operation): + """ + The rsqrt operation computes the reciprocal of the square root. It takes + one operand of floating point type (i.e., scalar, tensor or vector) and returns + one result of the same type. It has no standard attributes. + + Example: + // Scalar reciprocal square root value. + %a = math.rsqrt %b : f64 + """ + name: str = "math.rsqrt" + fastmath: OptOpAttr[FastMathFlagsAttr] + operand: Annotated[Operand, AnyFloat] + result: Annotated[OpResult, AnyFloat] + + @staticmethod + def get(operand: Union[Operation, SSAValue], + fastmath: FastMathFlagsAttr | None = None) -> RsqrtOp: + attributes: dict[str, Attribute] = {} + if fastmath is not None: + attributes["fastmath"] = fastmath + + operand = SSAValue.get(operand) + return RsqrtOp.build(attributes=attributes, + operands=[operand], + result_types=[operand.typ]) + + +@irdl_op_definition +class SqrtOp(Operation): + """ + The sqrt operation computes the square root. It takes one operand of + floating point type (i.e., scalar, tensor or vector) and returns one result of + the same type. It has no standard attributes. + + Example: + // Scalar square root value. + %a = math.sqrt %b : f64 + """ + name: str = "math.sqrt" + fastmath: OptOpAttr[FastMathFlagsAttr] + operand: Annotated[Operand, AnyFloat] + result: Annotated[OpResult, AnyFloat] + + @staticmethod + def get(operand: Union[Operation, SSAValue], + fastmath: FastMathFlagsAttr | None = None) -> SqrtOp: + attributes: dict[str, Attribute] = {} + if fastmath is not None: + attributes["fastmath"] = fastmath + + operand = SSAValue.get(operand) + return SqrtOp.build(attributes=attributes, + operands=[operand], + result_types=[operand.typ]) From 8695a1f9862bda5acfff0ba10975f573847fb83e Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Mon, 13 Mar 2023 16:00:35 +0000 Subject: [PATCH 09/10] misc: add __init__.py to dialects/experimental folder (#538) --- xdsl/dialects/experimental/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 xdsl/dialects/experimental/__init__.py diff --git a/xdsl/dialects/experimental/__init__.py b/xdsl/dialects/experimental/__init__.py new file mode 100644 index 0000000000..e69de29bb2 From 6bc3779e0aa4b289a47e1e53cf0610864a368db5 Mon Sep 17 00:00:00 2001 From: Prathamesh Tagore <63031630+meshtag@users.noreply.github.com> Date: Mon, 13 Mar 2023 22:20:27 +0530 Subject: [PATCH 10/10] dialects: Add filecheck tests for stencil dialect and initial lowering (#499) Co-authored-by: Emilien Bauer --- tests/dialects/test_scf.py | 4 +- tests/filecheck/dialects/stencil/copy.mlir | 35 ++++ tests/filecheck/dialects/stencil/hdiff.mlir | 57 +++++++ tests/filecheck/dialects/stencil/laplace.mlir | 55 +++++++ .../stencil/test_funcop_lowering.mlir | 16 ++ tests/xdsl_opt/test_xdsl_opt.py | 4 +- xdsl/dialects/experimental/stencil.py | 155 ++++++++++-------- .../experimental/ConvertStencilToLLMLIR.py | 68 ++++++++ xdsl/xdsl_opt_main.py | 7 + 9 files changed, 329 insertions(+), 72 deletions(-) create mode 100644 tests/filecheck/dialects/stencil/copy.mlir create mode 100644 tests/filecheck/dialects/stencil/hdiff.mlir create mode 100644 tests/filecheck/dialects/stencil/laplace.mlir create mode 100644 tests/filecheck/dialects/stencil/test_funcop_lowering.mlir create mode 100644 xdsl/transforms/experimental/ConvertStencilToLLMLIR.py diff --git a/tests/dialects/test_scf.py b/tests/dialects/test_scf.py index af88b192a8..4611ab87c7 100644 --- a/tests/dialects/test_scf.py +++ b/tests/dialects/test_scf.py @@ -1,5 +1,5 @@ -from xdsl.dialects.arith import Constant, IndexType -from xdsl.dialects.builtin import Region +from xdsl.dialects.arith import Constant +from xdsl.dialects.builtin import Region, IndexType from xdsl.dialects.cf import Block from xdsl.dialects.scf import For diff --git a/tests/filecheck/dialects/stencil/copy.mlir b/tests/filecheck/dialects/stencil/copy.mlir new file mode 100644 index 0000000000..847438cfbd --- /dev/null +++ b/tests/filecheck/dialects/stencil/copy.mlir @@ -0,0 +1,35 @@ +// RUN: xdsl-opt %s -t mlir | filecheck %s + +"builtin.module"() ({ + "func.func"() ({ + ^0(%0 : !stencil.field<[-1 : i32, -1 : i32, -1 : i32], f64>, %1 : !stencil.field<[-1 : i32, -1 : i32, -1 : i32], f64>): + %2 = "stencil.cast"(%0) {"lb" = #stencil.index<[-4 : i32, -4 : i32, -4 : i32]>, "ub" = #stencil.index<[68 : i32, 68 : i32, 68 : i32]>} : (!stencil.field<[-1 : i32, -1 : i32, -1 : i32], f64>) -> !stencil.field<[72 : i32, 72 : i32, 72 : i32], f64> + %3 = "stencil.cast"(%1) {"lb" = #stencil.index<[-4 : i32, -4 : i32, -4 : i32]>, "ub" = #stencil.index<[68 : i32, 68 : i32, 68 : i32]>} : (!stencil.field<[-1 : i32, -1 : i32, -1 : i32], f64>) -> !stencil.field<[72 : i32, 72 : i32, 72 : i32], f64> + %4 = "stencil.load"(%2) {"lb" = #stencil.index<[0 : i32, 0 : i32, 0 : i32]>, "ub" = #stencil.index<[64 : i32, 64 : i32, 64 : i32]>} : (!stencil.field<[72 : i32, 72 : i32, 72 : i32], f64>) -> !stencil.temp<[64 : i32, 64 : i32, 64 : i32], f64> + %5 = "stencil.apply"(%4) ({ + ^1(%6 : !stencil.temp<[64 : i32, 64 : i32, 64 : i32], f64>): + %7 = "stencil.access"(%6) {"offset" = #stencil.index<[0 : i32, 0 : i32, 0 : i32]>} : (!stencil.temp<[64 : i32, 64 : i32, 64 : i32], f64>) -> f64 + %8 = "stencil.store_result"(%7) : (f64) -> !stencil.result + "stencil.return"(%8) : (!stencil.result) -> () + }) {"lb" = #stencil.index<[0 : i32, 0 : i32, 0 : i32]>, "ub" = #stencil.index<[64 : i32, 64 : i32, 64 : i32]>} : (!stencil.temp<[64 : i32, 64 : i32, 64 : i32], f64>) -> !stencil.temp<[64 : i32, 64 : i32, 64 : i32], f64> + "stencil.store"(%5, %3) {"lb" = #stencil.index<[0 : i32, 0 : i32, 0 : i32]>, "ub" = #stencil.index<[64 : i32, 64 : i32, 64 : i32]>} : (!stencil.temp<[64 : i32, 64 : i32, 64 : i32], f64>, !stencil.field<[72 : i32, 72 : i32, 72 : i32], f64>) -> () + "func.return"() : () -> () + }) {"sym_name" = "stencil_copy", "function_type" = (!stencil.field<[-1 : i32, -1 : i32, -1 : i32], f64>, !stencil.field<[-1 : i32, -1 : i32, -1 : i32], f64>) -> (), "sym_visibility" = "private"} : () -> () +}) : () -> () + +// CHECK-NEXT: "builtin.module"() ({ +// CHECK-NEXT: "func.func"() ({ +// CHECK-NEXT: ^0(%0 : !stencil.field<[-1 : i32, -1 : i32, -1 : i32], f64>, %1 : !stencil.field<[-1 : i32, -1 : i32, -1 : i32], f64>): +// CHECK-NEXT: %2 = "stencil.cast"(%0) {"lb" = #stencil.index<[-4 : i32, -4 : i32, -4 : i32]>, "ub" = #stencil.index<[68 : i32, 68 : i32, 68 : i32]>} : (!stencil.field<[-1 : i32, -1 : i32, -1 : i32], f64>) -> !stencil.field<[72 : i32, 72 : i32, 72 : i32], f64> +// CHECK-NEXT: %3 = "stencil.cast"(%1) {"lb" = #stencil.index<[-4 : i32, -4 : i32, -4 : i32]>, "ub" = #stencil.index<[68 : i32, 68 : i32, 68 : i32]>} : (!stencil.field<[-1 : i32, -1 : i32, -1 : i32], f64>) -> !stencil.field<[72 : i32, 72 : i32, 72 : i32], f64> +// CHECK-NEXT: %4 = "stencil.load"(%2) {"lb" = #stencil.index<[0 : i32, 0 : i32, 0 : i32]>, "ub" = #stencil.index<[64 : i32, 64 : i32, 64 : i32]>} : (!stencil.field<[72 : i32, 72 : i32, 72 : i32], f64>) -> !stencil.temp<[64 : i32, 64 : i32, 64 : i32], f64> +// CHECK-NEXT: %5 = "stencil.apply"(%4) ({ +// CHECK-NEXT: ^1(%6 : !stencil.temp<[64 : i32, 64 : i32, 64 : i32], f64>): +// CHECK-NEXT: %7 = "stencil.access"(%6) {"offset" = #stencil.index<[0 : i32, 0 : i32, 0 : i32]>} : (!stencil.temp<[64 : i32, 64 : i32, 64 : i32], f64>) -> f64 +// CHECK-NEXT: %8 = "stencil.store_result"(%7) : (f64) -> !stencil.result +// CHECK-NEXT: "stencil.return"(%8) : (!stencil.result) -> () +// CHECK-NEXT: }) {"lb" = #stencil.index<[0 : i32, 0 : i32, 0 : i32]>, "ub" = #stencil.index<[64 : i32, 64 : i32, 64 : i32]>} : (!stencil.temp<[64 : i32, 64 : i32, 64 : i32], f64>) -> !stencil.temp<[64 : i32, 64 : i32, 64 : i32], f64> +// CHECK-NEXT: "stencil.store"(%5, %3) {"lb" = #stencil.index<[0 : i32, 0 : i32, 0 : i32]>, "ub" = #stencil.index<[64 : i32, 64 : i32, 64 : i32]>} : (!stencil.temp<[64 : i32, 64 : i32, 64 : i32], f64>, !stencil.field<[72 : i32, 72 : i32, 72 : i32], f64>) -> () +// CHECK-NEXT: "func.return"() : () -> () +// CHECK-NEXT: }) {"sym_name" = "stencil_copy", "function_type" = (!stencil.field<[-1 : i32, -1 : i32, -1 : i32], f64>, !stencil.field<[-1 : i32, -1 : i32, -1 : i32], f64>) -> (), "sym_visibility" = "private"} : () -> () +// CHECK-NEXT: }) : () -> () diff --git a/tests/filecheck/dialects/stencil/hdiff.mlir b/tests/filecheck/dialects/stencil/hdiff.mlir new file mode 100644 index 0000000000..71defdd636 --- /dev/null +++ b/tests/filecheck/dialects/stencil/hdiff.mlir @@ -0,0 +1,57 @@ +// RUN: xdsl-opt %s -t mlir | filecheck %s + +"builtin.module"() ({ + "func.func"() ({ + ^0(%0 : !stencil.field<[-1 : i64, -1 : i64, -1 : i64], f64>, %1 : !stencil.field<[-1 : i64, -1 : i64, -1 : i64], f64>, %2 : !stencil.field<[-1 : i64, -1 : i64, -1 : i64], f64>): + %3 = "stencil.cast"(%0) {"lb" = #stencil.index<[-4 : i64, -4 : i64, -4 : i64]>, "ub" = #stencil.index<[68 : i64, 68 : i64, 68 : i64]>} : (!stencil.field<[-1 : i64, -1 : i64, -1 : i64], f64>) -> !stencil.field<[72 : i64, 72 : i64, 72 : i64], f64> + %4 = "stencil.cast"(%1) {"lb" = #stencil.index<[-4 : i64, -4 : i64, -4 : i64]>, "ub" = #stencil.index<[68 : i64, 68 : i64, 68 : i64]>} : (!stencil.field<[-1 : i64, -1 : i64, -1 : i64], f64>) -> !stencil.field<[72 : i64, 72 : i64, 72 : i64], f64> + %5 = "stencil.cast"(%2) {"lb" = #stencil.index<[-4 : i64, -4 : i64, -4 : i64]>, "ub" = #stencil.index<[68 : i64, 68 : i64, 68 : i64]>} : (!stencil.field<[-1 : i64, -1 : i64, -1 : i64], f64>) -> !stencil.field<[72 : i64, 72 : i64, 72 : i64], f64> + %6 = "stencil.load"(%3) : (!stencil.field<[72 : i64, 72 : i64, 72 : i64], f64>) -> !stencil.temp<[-1 : i64, -1 : i64, -1 : i64], f64> + %7 = "stencil.load"(%4) : (!stencil.field<[72 : i64, 72 : i64, 72 : i64], f64>) -> !stencil.temp<[-1 : i64, -1 : i64, -1 : i64], f64> + %8 = "stencil.apply"(%6) ({ + ^1(%9 : !stencil.temp<[-1 : i64, -1 : i64, -1 : i64], f64>): + %10 = "stencil.access"(%9) {"offset" = #stencil.index<[-1 : i64, 0 : i64, 0 : i64]>} : (!stencil.temp<[-1 : i64, -1 : i64, -1 : i64], f64>) -> f64 + %11 = "stencil.access"(%9) {"offset" = #stencil.index<[1 : i64, 0 : i64, 0 : i64]>} : (!stencil.temp<[-1 : i64, -1 : i64, -1 : i64], f64>) -> f64 + %12 = "stencil.access"(%9) {"offset" = #stencil.index<[0 : i64, 1 : i64, 0 : i64]>} : (!stencil.temp<[-1 : i64, -1 : i64, -1 : i64], f64>) -> f64 + %13 = "stencil.access"(%9) {"offset" = #stencil.index<[0 : i64, -1 : i64, 0 : i64]>} : (!stencil.temp<[-1 : i64, -1 : i64, -1 : i64], f64>) -> f64 + %14 = "stencil.access"(%9) {"offset" = #stencil.index<[0 : i64, 0 : i64, 0 : i64]>} : (!stencil.temp<[-1 : i64, -1 : i64, -1 : i64], f64>) -> f64 + %15 = "arith.addf"(%10, %11) : (f64, f64) -> f64 + %16 = "arith.addf"(%12, %13) : (f64, f64) -> f64 + %17 = "arith.addf"(%15, %16) : (f64, f64) -> f64 + %cst = "arith.constant"() {"value" = -4.0 : f32} : () -> f64 + %18 = "arith.mulf"(%14, %cst) : (f64, f64) -> f64 + %19 = "arith.addf"(%18, %17) : (f64, f64) -> f64 + %20 = "stencil.store_result"(%19) : (f64) -> !stencil.result + "stencil.return"(%20) : (!stencil.result) -> () + }) : (!stencil.temp<[-1 : i64, -1 : i64, -1 : i64], f64>) -> !stencil.temp<[-1 : i64, -1 : i64, -1 : i64], f64> + }) {"function_type" = (!stencil.field<[-1 : i64, -1 : i64, -1 : i64], f64>, !stencil.field<[-1 : i64, -1 : i64, -1 : i64], f64>, !stencil.field<[-1 : i64, -1 : i64, -1 : i64], f64>) -> !stencil.field<[-1 : i64, -1 : i64, -1 : i64], f64>, "sym_name" = "stencil_hdiff"} : () -> () +}) : () -> () + + +// CHECK-NEXT: "builtin.module"() ({ +// CHECK-NEXT: "func.func"() ({ +// CHECK-NEXT: ^0(%0 : !stencil.field<[-1 : i64, -1 : i64, -1 : i64], f64>, %1 : !stencil.field<[-1 : i64, -1 : i64, -1 : i64], f64>, %2 : !stencil.field<[-1 : i64, -1 : i64, -1 : i64], f64>): +// CHECK-NEXT: %3 = "stencil.cast"(%0) {"lb" = #stencil.index<[-4 : i64, -4 : i64, -4 : i64]>, "ub" = #stencil.index<[68 : i64, 68 : i64, 68 : i64]>} : (!stencil.field<[-1 : i64, -1 : i64, -1 : i64], f64>) -> !stencil.field<[72 : i64, 72 : i64, 72 : i64], f64> +// CHECK-NEXT: %4 = "stencil.cast"(%1) {"lb" = #stencil.index<[-4 : i64, -4 : i64, -4 : i64]>, "ub" = #stencil.index<[68 : i64, 68 : i64, 68 : i64]>} : (!stencil.field<[-1 : i64, -1 : i64, -1 : i64], f64>) -> !stencil.field<[72 : i64, 72 : i64, 72 : i64], f64> +// CHECK-NEXT: %5 = "stencil.cast"(%2) {"lb" = #stencil.index<[-4 : i64, -4 : i64, -4 : i64]>, "ub" = #stencil.index<[68 : i64, 68 : i64, 68 : i64]>} : (!stencil.field<[-1 : i64, -1 : i64, -1 : i64], f64>) -> !stencil.field<[72 : i64, 72 : i64, 72 : i64], f64> +// CHECK-NEXT: %6 = "stencil.load"(%3) : (!stencil.field<[72 : i64, 72 : i64, 72 : i64], f64>) -> !stencil.temp<[-1 : i64, -1 : i64, -1 : i64], f64> +// CHECK-NEXT: %7 = "stencil.load"(%4) : (!stencil.field<[72 : i64, 72 : i64, 72 : i64], f64>) -> !stencil.temp<[-1 : i64, -1 : i64, -1 : i64], f64> +// CHECK-NEXT: %8 = "stencil.apply"(%6) ({ +// CHECK-NEXT: ^1(%9 : !stencil.temp<[-1 : i64, -1 : i64, -1 : i64], f64>): +// CHECK-NEXT: %10 = "stencil.access"(%9) {"offset" = #stencil.index<[-1 : i64, 0 : i64, 0 : i64]>} : (!stencil.temp<[-1 : i64, -1 : i64, -1 : i64], f64>) -> f64 +// CHECK-NEXT: %11 = "stencil.access"(%9) {"offset" = #stencil.index<[1 : i64, 0 : i64, 0 : i64]>} : (!stencil.temp<[-1 : i64, -1 : i64, -1 : i64], f64>) -> f64 +// CHECK-NEXT: %12 = "stencil.access"(%9) {"offset" = #stencil.index<[0 : i64, 1 : i64, 0 : i64]>} : (!stencil.temp<[-1 : i64, -1 : i64, -1 : i64], f64>) -> f64 +// CHECK-NEXT: %13 = "stencil.access"(%9) {"offset" = #stencil.index<[0 : i64, -1 : i64, 0 : i64]>} : (!stencil.temp<[-1 : i64, -1 : i64, -1 : i64], f64>) -> f64 +// CHECK-NEXT: %14 = "stencil.access"(%9) {"offset" = #stencil.index<[0 : i64, 0 : i64, 0 : i64]>} : (!stencil.temp<[-1 : i64, -1 : i64, -1 : i64], f64>) -> f64 +// CHECK-NEXT: %15 = "arith.addf"(%10, %11) : (f64, f64) -> f64 +// CHECK-NEXT: %16 = "arith.addf"(%12, %13) : (f64, f64) -> f64 +// CHECK-NEXT: %17 = "arith.addf"(%15, %16) : (f64, f64) -> f64 +// CHECK-NEXT: %cst = "arith.constant"() {"value" = -4.0 : f32} : () -> f64 +// CHECK-NEXT: %18 = "arith.mulf"(%14, %cst) : (f64, f64) -> f64 +// CHECK-NEXT: %19 = "arith.addf"(%18, %17) : (f64, f64) -> f64 +// CHECK-NEXT: %20 = "stencil.store_result"(%19) : (f64) -> !stencil.result +// CHECK-NEXT: "stencil.return"(%20) : (!stencil.result) -> () +// CHECK-NEXT: }) : (!stencil.temp<[-1 : i64, -1 : i64, -1 : i64], f64>) -> !stencil.temp<[-1 : i64, -1 : i64, -1 : i64], f64> +// CHECK-NEXT: }) {"function_type" = (!stencil.field<[-1 : i64, -1 : i64, -1 : i64], f64>, !stencil.field<[-1 : i64, -1 : i64, -1 : i64], f64>, !stencil.field<[-1 : i64, -1 : i64, -1 : i64], f64>) -> !stencil.field<[-1 : i64, -1 : i64, -1 : i64], f64>, "sym_name" = "stencil_hdiff"} : () -> () +// CHECK-NEXT: }) : () -> () + diff --git a/tests/filecheck/dialects/stencil/laplace.mlir b/tests/filecheck/dialects/stencil/laplace.mlir new file mode 100644 index 0000000000..9dab46022e --- /dev/null +++ b/tests/filecheck/dialects/stencil/laplace.mlir @@ -0,0 +1,55 @@ +// RUN: xdsl-opt %s -t mlir | filecheck %s + +"builtin.module"() ({ + "func.func"() ({ + ^0(%0 : !stencil.field<[-1 : i32, -1 : i32, -1 : i32], f64>, %1 : !stencil.field<[-1 : i32, -1 : i32, -1 : i32], f64>): + %2 = "stencil.cast"(%0) {"lb" = #stencil.index<[-4 : i32, -4 : i32, -4 : i32]>, "ub" = #stencil.index<[68 : i32, 68 : i32, 68 : i32]>} : (!stencil.field<[-1 : i32, -1 : i32, -1 : i32], f64>) -> !stencil.field<[72 : i32, 72 : i32, 72 : i32], f64> + %3 = "stencil.cast"(%1) {"lb" = #stencil.index<[-4 : i32, -4 : i32, -4 : i32]>, "ub" = #stencil.index<[68 : i32, 68 : i32, 68 : i32]>} : (!stencil.field<[-1 : i32, -1 : i32, -1 : i32], f64>) -> !stencil.field<[72 : i32, 72 : i32, 72 : i32], f64> + %4 = "stencil.load"(%2) {"lb" = #stencil.index<[-1 : i32, -1 : i32, 0 : i32]>, "ub" = #stencil.index<[65 : i32, 65 : i32, 64 : i32]>} : (!stencil.field<[72 : i32, 72 : i32, 72 : i32], f64>) -> !stencil.temp<[66 : i32, 66 : i32, 64 : i32], f64> + %5 = "stencil.apply"(%4) ({ + ^1(%6 : !stencil.temp<[66 : i32, 66 : i32, 64 : i32], f64>): + %7 = "stencil.access"(%6) {"offset" = #stencil.index<[-1 : i32, 0 : i32, 0 : i32]>} : (!stencil.temp<[66 : i32, 66 : i32, 64 : i32], f64>) -> f64 + %8 = "stencil.access"(%6) {"offset" = #stencil.index<[1 : i32, 0 : i32, 0 : i32]>} : (!stencil.temp<[66 : i32, 66 : i32, 64 : i32], f64>) -> f64 + %9 = "stencil.access"(%6) {"offset" = #stencil.index<[0 : i32, 1 : i32, 0 : i32]>} : (!stencil.temp<[66 : i32, 66 : i32, 64 : i32], f64>) -> f64 + %10 = "stencil.access"(%6) {"offset" = #stencil.index<[0 : i32, -1 : i32, 0 : i32]>} : (!stencil.temp<[66 : i32, 66 : i32, 64 : i32], f64>) -> f64 + %11 = "stencil.access"(%6) {"offset" = #stencil.index<[0 : i32, 0 : i32, 0 : i32]>} : (!stencil.temp<[66 : i32, 66 : i32, 64 : i32], f64>) -> f64 + %12 = "arith.addf"(%7, %8) : (f64, f64) -> f64 + %13 = "arith.addf"(%9, %10) : (f64, f64) -> f64 + %14 = "arith.addf"(%12, %13) : (f64, f64) -> f64 + %15 = "arith.constant"() {"value" = -4.0 : f64} : () -> f64 + %16 = "arith.mulf"(%11, %15) : (f64, f64) -> f64 + %17 = "arith.mulf"(%16, %13) : (f64, f64) -> f64 + %18 = "stencil.store_result"(%17) : (f64) -> !stencil.result + "stencil.return"(%18) : (!stencil.result) -> () + }) {"lb" = #stencil.index<[0 : i32, 0 : i32, 0 : i32]>, "ub" = #stencil.index<[64 : i32, 64 : i32, 64 : i32]>} : (!stencil.temp<[66 : i32, 66 : i32, 64 : i32], f64>) -> !stencil.temp<[64 : i32, 64 : i32, 64 : i32], f64> + "stencil.store"(%5, %3) {"lb" = #stencil.index<[0 : i32, 0 : i32, 0 : i32]>, "ub" = #stencil.index<[64 : i32, 64 : i32, 64 : i32]>} : (!stencil.temp<[64 : i32, 64 : i32, 64 : i32], f64>, !stencil.field<[72 : i32, 72 : i32, 72 : i32], f64>) -> () + "func.return"() : () -> () + }) {"sym_name" = "stencil_laplace", "function_type" = (!stencil.field<[-1 : i32, -1 : i32, -1 : i32], f64>, !stencil.field<[-1 : i32, -1 : i32, -1 : i32], f64>) -> (), "sym_visibility" = "private"} : () -> () +}) : () -> () + +// CHECK-NEXT: "builtin.module"() ({ +// CHECK-NEXT: "func.func"() ({ +// CHECK-NEXT: ^0(%0 : !stencil.field<[-1 : i32, -1 : i32, -1 : i32], f64>, %1 : !stencil.field<[-1 : i32, -1 : i32, -1 : i32], f64>): +// CHECK-NEXT: %2 = "stencil.cast"(%0) {"lb" = #stencil.index<[-4 : i32, -4 : i32, -4 : i32]>, "ub" = #stencil.index<[68 : i32, 68 : i32, 68 : i32]>} : (!stencil.field<[-1 : i32, -1 : i32, -1 : i32], f64>) -> !stencil.field<[72 : i32, 72 : i32, 72 : i32], f64> +// CHECK-NEXT: %3 = "stencil.cast"(%1) {"lb" = #stencil.index<[-4 : i32, -4 : i32, -4 : i32]>, "ub" = #stencil.index<[68 : i32, 68 : i32, 68 : i32]>} : (!stencil.field<[-1 : i32, -1 : i32, -1 : i32], f64>) -> !stencil.field<[72 : i32, 72 : i32, 72 : i32], f64> +// CHECK-NEXT: %4 = "stencil.load"(%2) {"lb" = #stencil.index<[-1 : i32, -1 : i32, 0 : i32]>, "ub" = #stencil.index<[65 : i32, 65 : i32, 64 : i32]>} : (!stencil.field<[72 : i32, 72 : i32, 72 : i32], f64>) -> !stencil.temp<[66 : i32, 66 : i32, 64 : i32], f64> +// CHECK-NEXT: %5 = "stencil.apply"(%4) ({ +// CHECK-NEXT: ^1(%6 : !stencil.temp<[66 : i32, 66 : i32, 64 : i32], f64>): +// CHECK-NEXT: %7 = "stencil.access"(%6) {"offset" = #stencil.index<[-1 : i32, 0 : i32, 0 : i32]>} : (!stencil.temp<[66 : i32, 66 : i32, 64 : i32], f64>) -> f64 +// CHECK-NEXT: %8 = "stencil.access"(%6) {"offset" = #stencil.index<[1 : i32, 0 : i32, 0 : i32]>} : (!stencil.temp<[66 : i32, 66 : i32, 64 : i32], f64>) -> f64 +// CHECK-NEXT: %9 = "stencil.access"(%6) {"offset" = #stencil.index<[0 : i32, 1 : i32, 0 : i32]>} : (!stencil.temp<[66 : i32, 66 : i32, 64 : i32], f64>) -> f64 +// CHECK-NEXT: %10 = "stencil.access"(%6) {"offset" = #stencil.index<[0 : i32, -1 : i32, 0 : i32]>} : (!stencil.temp<[66 : i32, 66 : i32, 64 : i32], f64>) -> f64 +// CHECK-NEXT: %11 = "stencil.access"(%6) {"offset" = #stencil.index<[0 : i32, 0 : i32, 0 : i32]>} : (!stencil.temp<[66 : i32, 66 : i32, 64 : i32], f64>) -> f64 +// CHECK-NEXT: %12 = "arith.addf"(%7, %8) : (f64, f64) -> f64 +// CHECK-NEXT: %13 = "arith.addf"(%9, %10) : (f64, f64) -> f64 +// CHECK-NEXT: %14 = "arith.addf"(%12, %13) : (f64, f64) -> f64 +// CHECK-NEXT: %15 = "arith.constant"() {"value" = -4.0 : f64} : () -> f64 +// CHECK-NEXT: %16 = "arith.mulf"(%11, %15) : (f64, f64) -> f64 +// CHECK-NEXT: %17 = "arith.mulf"(%16, %13) : (f64, f64) -> f64 +// CHECK-NEXT: %18 = "stencil.store_result"(%17) : (f64) -> !stencil.result +// CHECK-NEXT: "stencil.return"(%18) : (!stencil.result) -> () +// CHECK-NEXT: }) {"lb" = #stencil.index<[0 : i32, 0 : i32, 0 : i32]>, "ub" = #stencil.index<[64 : i32, 64 : i32, 64 : i32]>} : (!stencil.temp<[66 : i32, 66 : i32, 64 : i32], f64>) -> !stencil.temp<[64 : i32, 64 : i32, 64 : i32], f64> +// CHECK-NEXT: "stencil.store"(%5, %3) {"lb" = #stencil.index<[0 : i32, 0 : i32, 0 : i32]>, "ub" = #stencil.index<[64 : i32, 64 : i32, 64 : i32]>} : (!stencil.temp<[64 : i32, 64 : i32, 64 : i32], f64>, !stencil.field<[72 : i32, 72 : i32, 72 : i32], f64>) -> () +// CHECK-NEXT: "func.return"() : () -> () +// CHECK-NEXT: }) {"sym_name" = "stencil_laplace", "function_type" = (!stencil.field<[-1 : i32, -1 : i32, -1 : i32], f64>, !stencil.field<[-1 : i32, -1 : i32, -1 : i32], f64>) -> (), "sym_visibility" = "private"} : () -> () +// CHECK-NEXT: }) : () -> () diff --git a/tests/filecheck/dialects/stencil/test_funcop_lowering.mlir b/tests/filecheck/dialects/stencil/test_funcop_lowering.mlir new file mode 100644 index 0000000000..fc154ea544 --- /dev/null +++ b/tests/filecheck/dialects/stencil/test_funcop_lowering.mlir @@ -0,0 +1,16 @@ +// RUN: xdsl-opt %s -t mlir -p convert-stencil-to-ll-mlir | filecheck %s + +"builtin.module"() ({ + "func.func"() ({ + ^0(%0 : !stencil.field<[-1 : i32, -1 : i32, -1 : i32], f64>, %1 : !stencil.field<[-1 : i32, -1 : i32, -1 : i32], f64>): + "func.return"() : () -> () + }) {"sym_name" = "test_funcop_lowering", "function_type" = (!stencil.field<[-1 : i32, -1 : i32, -1 : i32], f64>, !stencil.field<[-1 : i32, -1 : i32, -1 : i32], f64>) -> (), "sym_visibility" = "private"} : () -> () +}) : () -> () + +// CHECK-NEXT: "builtin.module"() ({ +// CHECK-NEXT: "func.func"() ({ +// CHECK-NEXT: ^0(%0 : memref, %1 : memref): +// CHECK-NEXT: "func.return"() : () -> () +// CHECK-NEXT: }) {"sym_name" = "test_funcop_lowering", "function_type" = (memref, memref) -> (), "sym_visibility" = "private"} : () -> () +// CHECK-NEXT: }) : () -> () + diff --git a/tests/xdsl_opt/test_xdsl_opt.py b/tests/xdsl_opt/test_xdsl_opt.py index 4879a92388..50459a117d 100644 --- a/tests/xdsl_opt/test_xdsl_opt.py +++ b/tests/xdsl_opt/test_xdsl_opt.py @@ -12,7 +12,9 @@ def test_opt(): opt = xDSLOptMain(args=[]) assert list(opt.available_frontends.keys()) == ['xdsl', 'mlir'] assert list(opt.available_targets.keys()) == ['xdsl', 'irdl', 'mlir'] - assert list(opt.available_passes.keys()) == ['lower-mpi'] + assert list(opt.available_passes.keys()) == [ + 'lower-mpi', 'convert-stencil-to-ll-mlir' + ] def test_empty_program(): diff --git a/xdsl/dialects/experimental/stencil.py b/xdsl/dialects/experimental/stencil.py index 100670869a..7530bc1258 100644 --- a/xdsl/dialects/experimental/stencil.py +++ b/xdsl/dialects/experimental/stencil.py @@ -1,15 +1,18 @@ from __future__ import annotations from dataclasses import dataclass -from typing import cast, Any +from typing import Annotated, TypeVar, Any, cast -from xdsl.dialects.builtin import (ParametrizedAttribute, ArrayAttr, f32, f64, - IntegerType, IndexType, IntAttr, AnyFloat) +from xdsl.dialects.builtin import (AnyIntegerAttr, ParametrizedAttribute, + ArrayAttr, f32, f64, IntegerType, IntAttr, + AnyFloat) +from xdsl.dialects import builtin from xdsl.ir import Operation, Dialect, MLIRType from xdsl.irdl import (irdl_attr_definition, irdl_op_definition, ParameterDef, AttrConstraint, Attribute, Region, VerifyException, - AnyOf, Annotated, Operand, OpAttr, OpResult, VarOperand, - VarOpResult, OptOpAttr, AttrSizedOperandSegments) + Generic, AnyOf, Annotated, Operand, OpAttr, OpResult, + VarOperand, VarOpResult, OptOpAttr, + AttrSizedOperandSegments) @dataclass @@ -28,14 +31,19 @@ def verify(self, attr: Attribute) -> None: ) +_FieldTypeElement = TypeVar("_FieldTypeElement", bound=Attribute) + + @irdl_attr_definition -class FieldType(ParametrizedAttribute, MLIRType): +class FieldType(Generic[_FieldTypeElement], ParametrizedAttribute, MLIRType): name = "stencil.field" - shape: ParameterDef[ArrayAttr[IntAttr]] + shape: ParameterDef[ArrayAttr[AnyIntegerAttr]] + element_type: ParameterDef[_FieldTypeElement] @staticmethod - def from_shape(shape: list[int] | list[IntAttr]) -> FieldType: + def from_shape( + shape: list[int] | list[IntAttr]) -> FieldType[_FieldTypeElement]: # TODO: why do we need all these casts here, can we tell pyright "trust me" if all(isinstance(elm, IntAttr) for elm in shape): shape = cast(list[IntAttr], shape) @@ -46,18 +54,20 @@ def from_shape(shape: list[int] | list[IntAttr]) -> FieldType: @irdl_attr_definition -class TempType(ParametrizedAttribute, MLIRType): +class TempType(Generic[_FieldTypeElement], ParametrizedAttribute, MLIRType): name = "stencil.temp" - shape: ParameterDef[ArrayAttr[IntAttr]] + shape: ParameterDef[ArrayAttr[AnyIntegerAttr]] + element_type: ParameterDef[_FieldTypeElement] @staticmethod def from_shape( - shape: ArrayAttr[IntAttr] | list[IntAttr] | list[int]) -> TempType: + shape: ArrayAttr[IntAttr] | list[IntAttr] | list[int] + ) -> TempType[_FieldTypeElement]: assert len(shape) > 0 if isinstance(shape, ArrayAttr): - return TempType([shape]) + return TempType.new([shape]) # cast to list shape = cast(list[IntAttr] | list[int], shape) @@ -71,7 +81,7 @@ def from_shape( def __repr__(self): repr: str = "stencil.Temp<[" for size in self.shape.data: - repr += f"{size.data} " + repr += f"{size.value.data} " repr += "]>" return repr @@ -103,27 +113,34 @@ def verify(self, attr: Attribute) -> None: # TODO: How can we inherit from MLIRType and ParametrizedAttribute? @dataclass(frozen=True) -class Stencil_Element(ParametrizedAttribute): +class ElementType(ParametrizedAttribute): name = "stencil.element" element = AnyOf([f32, f64]) -@dataclass(frozen=True) -class Stencil_Index(ParametrizedAttribute): +@irdl_attr_definition +class IndexAttr(ParametrizedAttribute): # TODO: can you have an attr and an op with the same name? name = "stencil.index" - shape = Annotated[ArrayAttr[IntAttr], ArrayLength(2)] + + array: ParameterDef[ArrayAttr[AnyIntegerAttr]] + + def verify(self) -> None: + if len(self.array.data) != 3: + raise VerifyException( + f"Expected 3 indexes for stencil.index, got {len(self.array.data)}." + ) @dataclass(frozen=True) -class Stencil_Loop(ParametrizedAttribute): +class LoopAttr(ParametrizedAttribute): name = "stencil.loop" shape = Annotated[ArrayAttr[IntAttr], ArrayLength(4)] # Operations @irdl_op_definition -class Cast(Operation): +class CastOp(Operation): """ This operation casts dynamically shaped input fields to statically shaped fields. @@ -132,14 +149,14 @@ class Cast(Operation): """ name: str = "stencil.cast" field: Annotated[Operand, FieldType] - lb: OptOpAttr[Stencil_Index] - ub: OptOpAttr[Stencil_Index] + lb: OptOpAttr[IndexAttr] + ub: OptOpAttr[IndexAttr] result: Annotated[OpResult, FieldType] # Operations @irdl_op_definition -class External_Load(Operation): +class ExternalLoadOp(Operation): """ This operation loads from an external field type, e.g. to bring data into the stencil @@ -152,7 +169,7 @@ class External_Load(Operation): @irdl_op_definition -class External_Store(Operation): +class ExternalStoreOp(Operation): """ This operation takes a stencil field and then stores this to an external type @@ -165,7 +182,7 @@ class External_Store(Operation): @irdl_op_definition -class Index(Operation): +class IndexOp(Operation): """ This operation returns the index of the current loop iteration for the chosen direction (0, 1, or 2). @@ -176,12 +193,12 @@ class Index(Operation): """ name: str = "stencil.index" dim: OpAttr[IntegerType] - offset: OpAttr[Stencil_Index] - idx: Annotated[OpResult, IndexType] + offset: OpAttr[IndexAttr] + idx: Annotated[OpResult, builtin.IndexType] @irdl_op_definition -class Access(Operation): +class AccessOp(Operation): """ This operation accesses a temporary element given a constant offset. The offset is specified relative to the current position. @@ -191,12 +208,12 @@ class Access(Operation): """ name: str = "stencil.access" temp: Annotated[Operand, TempType] - offset: OpAttr[ArrayAttr[IntAttr]] + offset: OpAttr[IndexAttr] res: Annotated[OpResult, Attribute] @irdl_op_definition -class DynAccess(Operation): +class DynAccessOp(Operation): """ This operation accesses a temporary element given a dynamic offset. The offset is specified in absolute coordinates. An additional @@ -208,14 +225,14 @@ class DynAccess(Operation): """ name: str = "stencil.dyn_access" temp: Annotated[Operand, TempType] - offset: OpAttr[Stencil_Index] - lb: OpAttr[Stencil_Index] - ub: OpAttr[Stencil_Index] - res: Annotated[OpResult, Stencil_Element] + offset: OpAttr[IndexAttr] + lb: OpAttr[IndexAttr] + ub: OpAttr[IndexAttr] + res: Annotated[OpResult, ElementType] @irdl_op_definition -class Load(Operation): +class LoadOp(Operation): """ This operation takes a field and returns a temporary values. @@ -224,13 +241,13 @@ class Load(Operation): """ name: str = "stencil.load" field: Annotated[Operand, FieldType] - lb: OptOpAttr[Stencil_Index] - ub: OptOpAttr[Stencil_Index] + lb: OptOpAttr[IndexAttr] + ub: OptOpAttr[IndexAttr] res: Annotated[OpResult, TempType] @irdl_op_definition -class Buffer(Operation): +class BufferOp(Operation): """ Prevents fusion of consecutive stencil.apply operations. @@ -239,13 +256,13 @@ class Buffer(Operation): """ name: str = "stencil.buffer" temp: Annotated[Operand, TempType] - lb: OpAttr[Stencil_Index] - ub: OpAttr[Stencil_Index] + lb: OpAttr[IndexAttr] + ub: OpAttr[IndexAttr] res: Annotated[OpResult, TempType] @irdl_op_definition -class Store(Operation): +class StoreOp(Operation): """ This operation takes a temp and writes a field on a user defined range. @@ -255,12 +272,12 @@ class Store(Operation): name: str = "stencil.store" temp: Annotated[Operand, TempType] field: Annotated[Operand, FieldType] - lb: OptOpAttr[Stencil_Index] - ub: OptOpAttr[Stencil_Index] + lb: OptOpAttr[IndexAttr] + ub: OptOpAttr[IndexAttr] @irdl_op_definition -class Apply(Operation): +class ApplyOp(Operation): """ This operation takes a stencil function plus parameters and applies the stencil function to the output temp. @@ -272,15 +289,15 @@ class Apply(Operation): } """ name: str = "stencil.apply" - args: Annotated[VarOperand, FieldType] - lb: OptOpAttr[Stencil_Index] - ub: OptOpAttr[Stencil_Index] + args: Annotated[VarOperand, TempType] + lb: OptOpAttr[IndexAttr] + ub: OptOpAttr[IndexAttr] region: Region - res: Annotated[VarOpResult, FieldType] + res: Annotated[VarOpResult, TempType] @irdl_op_definition -class StoreResult(Operation): +class StoreResultOp(Operation): """ The store_result operation either stores an operand value or nothing. @@ -294,7 +311,7 @@ class StoreResult(Operation): @irdl_op_definition -class Return(Operation): +class ReturnOp(Operation): """ The return operation terminates the stencil apply and writes the results of the stencil operator to the temporary values returned @@ -312,7 +329,7 @@ class Return(Operation): @irdl_op_definition -class Combine(Operation): +class CombineOp(Operation): """ Combines the results computed on a lower with the results computed on an upper domain. The operation combines the domain at a given index/offset @@ -334,8 +351,8 @@ class Combine(Operation): lower_ext: Annotated[VarOperand, TempType] upper_ext: Annotated[VarOperand, TempType] - lb = OptOpAttr[Stencil_Index] - ub = OptOpAttr[Stencil_Index] + lb: OptOpAttr[IndexAttr] + ub: OptOpAttr[IndexAttr] region: Region res: VarOpResult @@ -343,25 +360,25 @@ class Combine(Operation): irdl_options = [AttrSizedOperandSegments()] -Dialect([ - Cast, - External_Load, - External_Store, - Index, - Access, - DynAccess, - Load, - Buffer, - Store, - Apply, - StoreResult, - Return, - Combine, +Stencil = Dialect([ + CastOp, + ExternalLoadOp, + ExternalStoreOp, + IndexOp, + AccessOp, + DynAccessOp, + LoadOp, + BufferOp, + StoreOp, + ApplyOp, + StoreResultOp, + ReturnOp, + CombineOp, ], [ FieldType, TempType, ResultType, - Stencil_Element, - Stencil_Index, - Stencil_Loop, + ElementType, + IndexAttr, + LoopAttr, ]) diff --git a/xdsl/transforms/experimental/ConvertStencilToLLMLIR.py b/xdsl/transforms/experimental/ConvertStencilToLLMLIR.py new file mode 100644 index 0000000000..9b1788f38d --- /dev/null +++ b/xdsl/transforms/experimental/ConvertStencilToLLMLIR.py @@ -0,0 +1,68 @@ +from typing import TypeVar + +from xdsl.pattern_rewriter import (PatternRewriter, PatternRewriteWalker, + RewritePattern, GreedyRewritePatternApplier) +from xdsl.ir import MLContext, Operation +from xdsl.irdl import Attribute +from xdsl.dialects.builtin import ArrayAttr, FunctionType, IntegerAttr, ModuleOp, AnyIntegerAttr, IndexType +from xdsl.dialects.func import FuncOp +from xdsl.dialects.memref import MemRefType + +from xdsl.dialects.experimental.stencil import FieldType, IndexAttr + +_TypeElement = TypeVar("_TypeElement", bound=Attribute) + + +def GetMemRefFromField( + inputFieldType: FieldType[_TypeElement]) -> MemRefType[_TypeElement]: + memref_shape_integer_attr_list: list[AnyIntegerAttr] = [] + for i in inputFieldType.shape.data: + memref_shape_integer_attr_list.append( + IntegerAttr.from_params(i.value.data, i.typ)) + + memref_shape_array_attr = ArrayAttr[AnyIntegerAttr]( + memref_shape_integer_attr_list) + + return MemRefType.from_params(inputFieldType.element_type, + memref_shape_array_attr) + + +def GetMemRefFromFieldWithLBAndUB(memref_element_type: _TypeElement, + lb: IndexAttr, + ub: IndexAttr) -> MemRefType[_TypeElement]: + memref_shape_integer_attr_list: list[AnyIntegerAttr] = [] + for i in range(len(lb.array.data)): + memref_shape_integer_attr_list.append( + IntegerAttr.from_params( + ub.array.data[i].value.data - lb.array.data[i].value.data, + IndexType())) + + memref_shape_array_attr = ArrayAttr(memref_shape_integer_attr_list) + + return MemRefType.from_params(memref_element_type, memref_shape_array_attr) + + +class StencilTypeConversionFuncOp(RewritePattern): + + def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter, /): + if (isinstance(op, FuncOp)): + inputs: list[Attribute] = [] + for arg in op.body.blocks[0].args: + if isinstance(arg.typ, FieldType): + typ: FieldType[Attribute] = arg.typ + memreftyp = GetMemRefFromField(typ) + rewriter.modify_block_argument_type(arg, memreftyp) + inputs.append(memreftyp) + + op.attributes["function_type"] = FunctionType( + [ArrayAttr(inputs), + ArrayAttr(op.function_type.outputs.data)]) + + +def ConvertStencilToLLMLIR(ctx: MLContext, module: ModuleOp): + walker = PatternRewriteWalker(GreedyRewritePatternApplier( + [StencilTypeConversionFuncOp()]), + walk_regions_first=True, + apply_recursively=False, + walk_reverse=True) + walker.rewrite_module(module) diff --git a/xdsl/xdsl_opt_main.py b/xdsl/xdsl_opt_main.py index 0af964608c..62ff727afd 100644 --- a/xdsl/xdsl_opt_main.py +++ b/xdsl/xdsl_opt_main.py @@ -21,6 +21,10 @@ from xdsl.transforms.lower_mpi import lower_mpi from xdsl.dialects.gpu import GPU +from xdsl.dialects.experimental.stencil import Stencil + +from xdsl.transforms.experimental.ConvertStencilToLLMLIR import ConvertStencilToLLMLIR + from xdsl.irdl_mlir_printer import IRDLPrinter from xdsl.utils.exceptions import DiagnosticException @@ -191,6 +195,7 @@ def register_all_dialects(self): self.ctx.register_dialect(Vector) self.ctx.register_dialect(MPI) self.ctx.register_dialect(GPU) + self.ctx.register_dialect(Stencil) def register_all_frontends(self): """ @@ -217,6 +222,8 @@ def register_all_passes(self): Add other/additional passes by overloading this function. """ self.available_passes['lower-mpi'] = lower_mpi + self.available_passes[ + 'convert-stencil-to-ll-mlir'] = ConvertStencilToLLMLIR def register_all_targets(self): """