Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dialects: (csl) add modules #2602

Merged
merged 7 commits into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions tests/filecheck/backend/csl/print_csl.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
// RUN: xdsl-opt -t csl %s | filecheck %s

"csl.module"() <{kind=#csl<module_kind program>}> ({

"memref.global"() {"sym_name" = "A", "type" = memref<24xf32>, "sym_visibility" = "public", "initial_value" = dense<0> : tensor<1xindex>} : () -> ()
"memref.global"() {"sym_name" = "x", "type" = memref<6xf32>, "sym_visibility" = "public", "initial_value" = dense<0> : tensor<1xindex>} : () -> ()
"memref.global"() {"sym_name" = "b", "type" = memref<4xf32>, "sym_visibility" = "public", "initial_value" = dense<0> : tensor<1xindex>} : () -> ()
Expand Down Expand Up @@ -57,6 +59,7 @@ csl.func @initialize() {

csl.return
}
}) {sym_name = "program"} : () -> ()


// CHECK: //unknown op Global("memref.global"() <{"sym_name" = "A", "sym_visibility" = "public", "type" = memref<24xf32>, "initial_value" = dense<0> : tensor<1xindex>}> : () -> ())
Expand Down
23 changes: 19 additions & 4 deletions tests/filecheck/dialects/csl/ops.mlir
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
// RUN: XDSL_ROUNDTRIP

"csl.module"() <{kind = #csl<module_kind program>}> ({

%thing = "csl.import_module"() <{module = "<thing>"}> : () -> !csl.imported_module

csl.func @initialize() {

%lb, %ub = "test.op"() : () -> (i16, i16)

%thing = "csl.import_module"() <{module = "<thing>"}> : () -> !csl.imported_module

"csl.member_call"(%thing, %lb, %ub) <{field = "some_func", operandSegmentSizes = array<i32: 1, 2>}> : (!csl.imported_module, i16, i16) -> ()

%res = "csl.member_call"(%thing, %lb, %ub) <{field = "some_func", operandSegmentSizes = array<i32: 1, 2>}> : (!csl.imported_module, i16, i16) -> (i32)
Expand All @@ -21,12 +23,19 @@ csl.func @initialize() {

csl.return
}
}) {sym_name = "program"} : () -> ()

"csl.module"() <{kind = #csl<module_kind layout>}> ({
csl.layout {
}
}) {sym_name = "layout"} : () -> ()


// CHECK-NEXT: builtin.module {
// CHECK-NEXT: csl.func @initialize() {
// CHECK-NEXT: "csl.module"() <{"kind" = #csl<module_kind program>}> ({
// CHECK-NEXT: %thing = "csl.import_module"() <{"module" = "<thing>"}> : () -> !csl.imported_module
// CHECK-NEXT: csl.func @initialize() {
// CHECK-NEXT: %lb, %ub = "test.op"() : () -> (i16, i16)
// CHECK-NEXT: %thing = "csl.import_module"() <{"module" = "<thing>"}> : () -> !csl.imported_module
// CHECK-NEXT: "csl.member_call"(%thing, %lb, %ub) <{"field" = "some_func", "operandSegmentSizes" = array<i32: 1, 2>}> : (!csl.imported_module, i16, i16) -> ()
// CHECK-NEXT: %res = "csl.member_call"(%thing, %lb, %ub) <{"field" = "some_func", "operandSegmentSizes" = array<i32: 1, 2>}> : (!csl.imported_module, i16, i16) -> i32
// CHECK-NEXT: %0 = "csl.member_access"(%thing) <{"field" = "some_field"}> : (!csl.imported_module) -> !csl.comptime_struct
Expand All @@ -37,4 +46,10 @@ csl.func @initialize() {
// CHECK-NEXT: %col = "test.op"() : () -> !csl.color
// CHECK-NEXT: csl.return
// CHECK-NEXT: }
// CHECK-NEXT: }) {"sym_name" = "program"} : () -> ()
// CHECK-NEXT: "csl.module"() <{"kind" = #csl<module_kind layout>}> ({
// CHECK-NEXT: csl.layout {
// CHECK-NEXT: ^0:
// CHECK-NEXT: }
// CHECK-NEXT: }) {"sym_name" = "layout"} : () -> ()
// CHECK-NEXT: }
4 changes: 3 additions & 1 deletion xdsl/backend/csl/print_csl.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,4 +214,6 @@ def print_to_csl(prog: ModuleOp, output: IO[str]):
Takes a module op and prints it to the given output stream.
"""
ctx = CslPrintContext(output)
ctx.print_block(prog.body.block)
for mod in prog.body.block.ops:
assert isinstance(mod, csl.CslModuleOp)
ctx.print_block(mod.body.block)
112 changes: 110 additions & 2 deletions xdsl/dialects/csl.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from __future__ import annotations

from collections.abc import Sequence
from dataclasses import dataclass
from typing import TypeAlias

from xdsl.dialects import func
Expand All @@ -18,6 +19,7 @@
ContainerType,
DictionaryAttr,
FunctionType,
ModuleOp,
StringAttr,
)
from xdsl.dialects.utils import parse_func_op_like, print_func_op_like
Expand All @@ -37,6 +39,7 @@
IRDLOperation,
ParameterDef,
ParametrizedAttribute,
attr_def,
irdl_attr_definition,
irdl_op_definition,
operand_def,
Expand All @@ -50,7 +53,14 @@
)
from xdsl.parser import Parser
from xdsl.printer import Printer
from xdsl.traits import HasParent, IsTerminator, SymbolOpInterface
from xdsl.traits import (
HasParent,
IsolatedFromAbove,
IsTerminator,
NoTerminator,
OpTrait,
SymbolOpInterface,
)
from xdsl.utils.exceptions import VerifyException
from xdsl.utils.str_enum import StrEnum

Expand All @@ -65,6 +75,44 @@ class PtrConst(StrEnum):
VAR = "var"


class ModuleKind(StrEnum):
LAYOUT = "layout"
PROGRAM = "program"


@dataclass(frozen=True)
class InModuleKind(OpTrait):
"""
Constrain an op to a particular module kind

Optionally specify if the op has to be a direct child of CslModuleOp
(default is yes).
"""

def __init__(self, kind: ModuleKind, *, direct_child: bool = True):
super().__init__((kind, direct_child))

def verify(self, op: Operation) -> None:
kind: ModuleKind = self.parameters[0]
direct_child: bool = self.parameters[1]

direct = "direct" if direct_child else "indirect"
parent_module = op.parent_op()
if not direct_child:
while parent_module is not None and not isinstance(
parent_module, CslModuleOp
):
parent_module = parent_module.parent_op()
if not isinstance(parent_module, CslModuleOp):
raise VerifyException(
f"'{op.name}' expexts {direct} parent to be {CslModuleOp.name}, got {parent_module}"
)
if parent_module.kind.data != kind:
raise VerifyException(
f"'{op.name}' expexts {direct} parent to be {CslModuleOp.name} of kind {kind.value}"
)


@irdl_attr_definition
class ComptimeStructType(ParametrizedAttribute, TypeAttribute):
"""
Expand Down Expand Up @@ -104,6 +152,13 @@ class PtrConstAttr(EnumAttribute[PtrConst], SpacedOpaqueSyntaxAttribute):
name = "csl.ptr_const"


@irdl_attr_definition
class ModuleKindAttr(EnumAttribute[ModuleKind], SpacedOpaqueSyntaxAttribute):
"""Attribute representing the kind of CSL module, either layout or program"""

name = "csl.module_kind"


@irdl_attr_definition
class PtrType(ParametrizedAttribute, TypeAttribute, ContainerType[Attribute]):
"""
Expand Down Expand Up @@ -131,6 +186,29 @@ class ColorType(ParametrizedAttribute, TypeAttribute):
name = "csl.color"


@irdl_op_definition
class CslModuleOp(IRDLOperation):
"""
Separates layout module from program module
"""

# TODO(dk949): This should also probably handle csl `param`s

name = "csl.module"
body: Region = region_def("single_block")
kind = prop_def(ModuleKindAttr)
sym_name: StringAttr = attr_def(StringAttr)

traits = frozenset(
[
HasParent(ModuleOp),
IsolatedFromAbove(),
NoTerminator(),
SymbolOpInterface(),
]
)


@irdl_op_definition
class ImportModuleConstOp(IRDLOperation):
"""
Expand All @@ -139,6 +217,8 @@ class ImportModuleConstOp(IRDLOperation):

name = "csl.import_module"

traits = frozenset([HasParent(CslModuleOp)])

module = prop_def(StringAttr)

params = opt_operand_def(StructLike)
Expand Down Expand Up @@ -197,7 +277,9 @@ class FuncOp(IRDLOperation):
arg_attrs = opt_prop_def(ArrayAttr[DictionaryAttr])
res_attrs = opt_prop_def(ArrayAttr[DictionaryAttr])

traits = frozenset([SymbolOpInterface(), func.FuncOpCallableInterface()])
traits = frozenset(
[HasParent(CslModuleOp), SymbolOpInterface(), func.FuncOpCallableInterface()]
)

def __init__(
self,
Expand Down Expand Up @@ -307,6 +389,29 @@ def verify_(self) -> None:
)


@irdl_op_definition
class LayoutOp(IRDLOperation):
name = "csl.layout"

body: Region = region_def()

traits = frozenset([NoTerminator(), InModuleKind(ModuleKind.LAYOUT)])

def __init__(self, ops: Sequence[Operation] | Region):
if not isinstance(ops, Region):
ops = Region(Block(ops))
if len(ops.blocks) == 0:
ops = Region(Block([]))
super().__init__(regions=[ops])

@classmethod
def parse(cls, parser: Parser) -> LayoutOp:
return cls(parser.parse_region())

def print(self, printer: Printer):
printer.print(" ", self.body)


CSL = Dialect(
"csl",
[
Expand All @@ -315,6 +420,8 @@ def verify_(self) -> None:
ImportModuleConstOp,
MemberCallOp,
MemberAccessOp,
CslModuleOp,
LayoutOp,
],
[
ComptimeStructType,
Expand All @@ -323,5 +430,6 @@ def verify_(self) -> None:
PtrConstAttr,
PtrType,
ColorType,
ModuleKindAttr,
],
)
Loading