Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update dialects #81

Merged
merged 5 commits into from
Mar 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
# xDSL

TODO

## Prerequisits

To install all required dependencies, execute the following command:
Expand All @@ -24,6 +22,22 @@ pytest
lit tests/filecheck
```

## Generating executables through MLIR

xDSL can generate executables using MLIR as the backend. To use this functionality, make sure to install the [MLIR Python Bindings](https://mlir.llvm.org/docs/Bindings/Python/). Given an input file `input.xdsl`, that contains IR with only the mirrored dialects found in `src/xdsl/dialects` (arith, memref, func, cf, scf, and builtin), run:

```
### Prints MLIR generic from to tmp.mlir
./src/xdsl/xdsl_opt.py -t mlir -o tmp.mlir `input.xdsl`

mlir-opt --convert-scf-to-cf --convert-cf-to-llvm --convert-func-to-llvm --convert-arith-to-llvm --convert-memref-to-llvm --reconcile-unrealized-casts tmp.mlir | mlir-translate --mlir-to-llvmir > tmp.ll
```

The generated `tmp.ll` file contains LLVMIR, so it can be directly passed to a compiler like clang.
Notice that a `main` function is required for clang to build. Refer to `tests/filecheck/arith_ops.test` for an example.
The functionality is tested with MLIR git commit hash: 74992f4a5bb79e2084abdef406ef2e5aa2024368


## Formatting

All python code used in xDSL use yapf to format the code in a uniform manner.
Expand Down
40 changes: 0 additions & 40 deletions src/xdsl/dialects/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ def __post_init__(self):
self.ctx.register_attr(IndexType)

self.ctx.register_op(ModuleOp)
self.ctx.register_op(FuncOp)


@irdl_attr_definition
Expand Down Expand Up @@ -344,45 +343,6 @@ def from_attrs(inputs: ArrayAttr, outputs: ArrayAttr) -> Attribute:
return FunctionType([inputs, outputs])


@irdl_op_definition
class FuncOp(Operation):
name: str = "builtin.func"

body = RegionDef()
sym_name = AttributeDef(StringAttr)
type = AttributeDef(FunctionType)
sym_visibility = AttributeDef(StringAttr)

@staticmethod
def from_callable(
name: str, input_types: List[Attribute],
return_types: List[Attribute],
func: Callable[[BlockArgument, ...], List[Operation]]) -> FuncOp:
type_attr = FunctionType.from_lists(input_types, return_types)
op = FuncOp.build(attributes={
"sym_name": name,
"type": type_attr,
"sym_visibility": "private"
},
regions=[
Region.from_block_list(
[Block.from_callable(input_types, func)])
])
return op

@staticmethod
def from_region(name: str, input_types: List[Attribute],
return_types: List[Attribute], region: Region) -> FuncOp:
type_attr = FunctionType.from_lists(input_types, return_types)
op = FuncOp.build(attributes={
"sym_name": name,
"type": type_attr,
"sym_visibility": "private"
},
regions=[region])
return op


@irdl_op_definition
class ModuleOp(Operation):
name: str = "module"
Expand Down
84 changes: 84 additions & 0 deletions src/xdsl/dialects/func.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from __future__ import annotations
from dataclasses import dataclass

from xdsl.irdl import *
from xdsl.ir import *
from xdsl.dialects.builtin import *


@dataclass
class Func:
ctx: MLContext

def __post_init__(self):
self.ctx.register_op(FuncOp)
self.ctx.register_op(Call)
self.ctx.register_op(Return)


@irdl_op_definition
class FuncOp(Operation):
name: str = "func.func"

body = RegionDef()
sym_name = AttributeDef(StringAttr)
function_type = AttributeDef(FunctionType)
sym_visibility = AttributeDef(StringAttr)

@staticmethod
def from_callable(
name: str, input_types: List[Attribute],
return_types: List[Attribute],
func: Callable[[BlockArgument, ...], List[Operation]]) -> FuncOp:
type_attr = FunctionType.from_lists(input_types, return_types)
op = FuncOp.build(attributes={
"sym_name": name,
"function_type": type_attr,
"sym_visibility": "private"
},
regions=[
Region.from_block_list(
[Block.from_callable(input_types, func)])
])
return op

@staticmethod
def from_region(name: str, input_types: List[Attribute],
return_types: List[Attribute], region: Region) -> FuncOp:
type_attr = FunctionType.from_lists(input_types, return_types)
op = FuncOp.build(attributes={
"sym_name": name,
"function_type": type_attr,
"sym_visibility": "private"
},
regions=[region])
return op


@irdl_op_definition
class Call(Operation):
name: str = "func.call"
arguments = VarOperandDef(AnyAttr())
callee = AttributeDef(FlatSymbolRefAttr)

# Note: naming this results triggers an ArgumentError
res = VarResultDef(AnyAttr())
# TODO how do we verify that the types are correct?

@staticmethod
def get(callee: Union[str, FlatSymbolRefAttr],
operands: List[Union[SSAValue, Operation]],
return_types: List[Attribute]) -> Call:
return Call.build(operands=operands,
result_types=return_types,
attributes={"callee": callee})


@irdl_op_definition
class Return(Operation):
name: str = "func.return"
arguments = VarOperandDef(AnyAttr())

@staticmethod
def get(*ops: Union[Operation, SSAValue]) -> Return:
return Return.build(operands=[[op for op in ops]])
43 changes: 0 additions & 43 deletions src/xdsl/dialects/std.py

This file was deleted.

117 changes: 96 additions & 21 deletions src/xdsl/xdsl_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,48 +2,123 @@

import argparse
import sys
import os
from xdsl.ir import *
from xdsl.parser import *
from xdsl.printer import *
from xdsl.dialects.std import *
from xdsl.dialects.func import *
from xdsl.dialects.scf import *
from xdsl.dialects.arith import *
from xdsl.dialects.affine import *
from xdsl.dialects.memref import *
from xdsl.dialects.builtin import *


class xDSLOptMain:
ctx: MLContext
args: argparse.Namespace

def __init__(self, args: argparse.Namespace):
self.ctx = MLContext()
self.args = args

def register_all_dialects(self):
"""Register all dialects that can be used."""
builtin = Builtin(self.ctx)
func = Func(self.ctx)
arith = Arith(self.ctx)
memref = MemRef(self.ctx)
affine = Affine(self.ctx)
scf = Scf(self.ctx)

def parse_frontend(self) -> ModuleOp:
"""Parse the input file."""
if self.args.input_file is None:
f = sys.stdin
file_extension = '.xdsl'
else:
f = open(self.args.input_file, mode='r')
_, file_extension = os.path.splitext(self.args.input_file)

if file_extension == '.xdsl' or file_extension == '.test':
input_str = f.read()
parser = Parser(self.ctx, input_str)
module = parser.parse_op()
if not self.args.disable_verify:
module.verify()
if not (isinstance(module, ModuleOp)):
raise Exception(
"Expected module or program as toplevel operation")
return module

raise Exception(f"Unrecognized file extension '{file_extension}'")

def output_resulting_program(self, prog: ModuleOp) -> str:
"""Get the resulting program."""
output = StringIO()
if self.args.target == 'xdsl':
printer = Printer(stream=output)
printer.print_op(prog)
return output.getvalue()
if self.args.target == 'mlir':
try:
from xdsl.mlir_converter import MLIRConverter
except ImportError as ex:
raise Exception(
"Can only emit mlir if the mlir bindings are present"
) from ex
converter = MLIRConverter(self.ctx)
mlir_module = converter.convert_module(prog)
print(mlir_module, file=output)
return output.getvalue()
raise Exception(f"Unknown target {self.args.target}")

def print_to_output_stream(self, contents: str):
"""Print the contents in the expected stream."""
if self.args.output_file is None:
print(contents)
else:
output_stream = open(self.args.output_file, 'w')
output_stream.write(contents)


arg_parser = argparse.ArgumentParser(
description='MLIR modular optimizer driver')
arg_parser.add_argument("-f",
arg_parser.add_argument("input_file",
type=str,
required=False,
nargs="?",
help="path to input file")

arg_parser.add_argument("-t",
"--target",
type=str,
required=False,
choices=["xdsl", "mlir"],
help="target",
default="xdsl")

arg_parser.add_argument("--disable-verify", default=False, action='store_true')
arg_parser.add_argument("-o",
"--output-file",
type=str,
required=False,
help="path to output file")


def __main__(input_str: str):
ctx = MLContext()
builtin = Builtin(ctx)
std = Std(ctx)
arith = Arith(ctx)
affine = Affine(ctx)
scf = Scf(ctx)
def __main__(args: argparse.Namespace):
xdsl_main = xDSLOptMain(args)

parser = Parser(ctx, input_str)
module = parser.parse_op()
module.verify()
xdsl_main.register_all_dialects()
module = xdsl_main.parse_frontend()

printer = Printer()
printer.print_op(module)
contents = xdsl_main.output_resulting_program(module)
xdsl_main.print_to_output_stream(contents)


def main():
args = arg_parser.parse_args()
if not args.f:
input_str = sys.stdin.read()
else:
f = open(args.f, mode='r')
input_str = f.read()

__main__(input_str)
__main__(args)


if __name__ == "__main__":
Expand Down
8 changes: 4 additions & 4 deletions tests/affine_test.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from xdsl.dialects.builtin import *
from xdsl.dialects.std import *
from xdsl.dialects.func import *
from xdsl.dialects.arith import *
from xdsl.printer import Printer
from xdsl.dialects.affine import *


def get_example_affine_program(ctx: MLContext, builtin: Builtin, std: Std,
def get_example_affine_program(ctx: MLContext, builtin: Builtin, func: Func,
affine: Affine) -> Operation:

def affine_mm(arg0: BlockArgument, arg1: BlockArgument,
Expand Down Expand Up @@ -35,15 +35,15 @@ def affine_mm(arg0: BlockArgument, arg1: BlockArgument,
def test_affine():
ctx = MLContext()
builtin = Builtin(ctx)
std = Std(ctx)
func = Func(ctx)
arith = Arith(ctx)
affine = Affine(ctx)

test_empty = new_op("test_empty", 0, 0, 0)
ctx.register_op(test_empty)
op = test_empty()

f = get_example_affine_program(ctx, builtin, std, affine)
f = get_example_affine_program(ctx, builtin, func, affine)
f.verify()
printer = Printer()
printer.print_op(f)
Loading