Skip to content

Commit

Permalink
transformations: Add linalg-to-csl pass (#3028)
Browse files Browse the repository at this point in the history
A short pass that translates (bufferized) linalg ops to csl ops.

Memrefs -> DsdType should be done in a separate pass.

---------

Co-authored-by: n-io <n-io@users.noreply.github.com>
  • Loading branch information
n-io and n-io authored Aug 15, 2024
1 parent 0697603 commit e7c69de
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 1 deletion.
33 changes: 33 additions & 0 deletions tests/filecheck/transforms/linalg-to-csl.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// RUN: xdsl-opt %s -p linalg-to-csl | filecheck %s

builtin.module {
%0, %1, %2, %3, %4 = "test.op"() : () -> (memref<16xf32>, memref<16xf32>, memref<16xf32>, memref<16xf32>, memref<16xf32>)
linalg.add ins(%1, %2 : memref<16xf32>, memref<16xf32>) outs(%0 : memref<16xf32>)
linalg.sub ins(%0, %3 : memref<16xf32>, memref<16xf32>) outs(%0 : memref<16xf32>)
linalg.mul ins(%0, %4 : memref<16xf32>, memref<16xf32>) outs(%0 : memref<16xf32>)

%5, %6, %7, %8, %9 = "test.op"() : () -> (memref<16xf16>, memref<16xf16>, memref<16xf16>, memref<16xf16>, memref<16xf16>)
linalg.add ins(%6, %7 : memref<16xf16>, memref<16xf16>) outs(%5 : memref<16xf16>)
linalg.sub ins(%5, %8 : memref<16xf16>, memref<16xf16>) outs(%5 : memref<16xf16>)
linalg.mul ins(%5, %9 : memref<16xf16>, memref<16xf16>) outs(%5 : memref<16xf16>)

%10 = arith.constant dense<1.123400e-01> : memref<16xf32>
linalg.add ins(%0, %10 : memref<16xf32>, memref<16xf32>) outs(%0 : memref<16xf32>)
linalg.mul ins(%10, %0 : memref<16xf32>, memref<16xf32>) outs(%0 : memref<16xf32>)
}

//CHECK-NEXT: builtin.module {
//CHECK-NEXT: %0, %1, %2, %3, %4 = "test.op"() : () -> (memref<16xf32>, memref<16xf32>, memref<16xf32>, memref<16xf32>, memref<16xf32>)
//CHECK-NEXT: "csl.fadds"(%0, %1, %2) : (memref<16xf32>, memref<16xf32>, memref<16xf32>) -> ()
//CHECK-NEXT: "csl.fsubs"(%0, %0, %3) : (memref<16xf32>, memref<16xf32>, memref<16xf32>) -> ()
//CHECK-NEXT: "csl.fmuls"(%0, %0, %4) : (memref<16xf32>, memref<16xf32>, memref<16xf32>) -> ()
//CHECK-NEXT: %5, %6, %7, %8, %9 = "test.op"() : () -> (memref<16xf16>, memref<16xf16>, memref<16xf16>, memref<16xf16>, memref<16xf16>)
//CHECK-NEXT: "csl.faddh"(%5, %6, %7) : (memref<16xf16>, memref<16xf16>, memref<16xf16>) -> ()
//CHECK-NEXT: "csl.fsubh"(%5, %5, %8) : (memref<16xf16>, memref<16xf16>, memref<16xf16>) -> ()
//CHECK-NEXT: "csl.fmulh"(%5, %5, %9) : (memref<16xf16>, memref<16xf16>, memref<16xf16>) -> ()
//CHECK-NEXT: %10 = arith.constant dense<1.123400e-01> : memref<16xf32>
//CHECK-NEXT: %11 = arith.constant 1.123400e-01 : f32
//CHECK-NEXT: "csl.fadds"(%0, %0, %11) : (memref<16xf32>, memref<16xf32>, f32) -> ()
//CHECK-NEXT: %12 = arith.constant 1.123400e-01 : f32
//CHECK-NEXT: "csl.fmuls"(%0, %12, %0) : (memref<16xf32>, f32, memref<16xf32>) -> ()
//CHECK-NEXT: }
5 changes: 4 additions & 1 deletion xdsl/dialects/csl/csl.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from xdsl.dialects.builtin import (
AnyFloatAttr,
AnyIntegerAttr,
AnyMemRefType,
ArrayAttr,
BoolAttr,
ContainerType,
Expand Down Expand Up @@ -1008,7 +1009,9 @@ def typcheck(
sig_typ: Attribute | type[Attribute],
) -> bool:
if isinstance(sig_typ, type):
return isinstance(op_typ, sig_typ)
return (
sig_typ == DsdType and isa(op_typ, AnyMemRefType)
) or isinstance(op_typ, sig_typ)
else:
return op_typ == sig_typ

Expand Down
6 changes: 6 additions & 0 deletions xdsl/tools/command_line_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,11 @@ def get_lift_arith_to_linalg():

return LiftArithToLinalg

def get_linalg_to_csl():
from xdsl.transforms.linalg_to_csl import LinalgToCsl

return LinalgToCsl

def get_lower_affine():
from xdsl.transforms import lower_affine

Expand Down Expand Up @@ -409,6 +414,7 @@ def get_stencil_bufferize():
"hls-convert-stencil-to-ll-mlir": get_hls_convert_stencil_to_ll_mlir,
"apply-individual-rewrite": get_individual_rewrite,
"lift-arith-to-linalg": get_lift_arith_to_linalg,
"linalg-to-csl": get_linalg_to_csl,
"lower-affine": get_lower_affine,
"lower-hls": get_lower_hls,
"lower-mpi": get_lower_mpi,
Expand Down
121 changes: 121 additions & 0 deletions xdsl/transforms/linalg_to_csl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
from dataclasses import dataclass

from xdsl.context import MLContext
from xdsl.dialects import arith, linalg
from xdsl.dialects.builtin import (
AnyFloatAttr,
AnyIntegerAttr,
AnyMemRefType,
DenseIntOrFPElementsAttr,
Float16Type,
Float32Type,
ModuleOp,
)
from xdsl.dialects.csl import csl
from xdsl.ir import OpResult, SSAValue
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
GreedyRewritePatternApplier,
PatternRewriter,
PatternRewriteWalker,
RewritePattern,
op_type_rewrite_pattern,
)
from xdsl.rewriter import InsertPoint
from xdsl.utils.hints import isa


class ConvertBinaryLinalgOp(RewritePattern):
"""
Base class for converting binary linalg operations.
"""

def transform_op(
self,
op: linalg.NamedOpBase,
rewriter: PatternRewriter,
f16: type[csl.BuiltinDsdOp],
f32: type[csl.BuiltinDsdOp],
):
if not isa(op.outputs.types[0], AnyMemRefType):
return

match op.outputs.types[0].get_element_type():
case Float16Type():
builtin = f16
case Float32Type():
builtin = f32
case _:
raise ValueError(
f"Unsupported element type {op.outputs.types[0].get_element_type()}"
)

lhs = op.inputs[0]
rhs = op.inputs[1]

# binary functions translated here support mixing scalar and collection operands
# may need revisiting if more functions are translated
if scalar_const := self._get_scalar_const(lhs):
rewriter.insert_op(
const_op := arith.Constant(scalar_const), InsertPoint.before(op)
)
lhs = const_op.result
elif scalar_const := self._get_scalar_const(rhs):
rewriter.insert_op(
const_op := arith.Constant(scalar_const), InsertPoint.before(op)
)
rhs = const_op.result

rewriter.replace_matched_op(builtin(operands=[[op.outputs[0], lhs, rhs]]))

@staticmethod
def _get_scalar_const(op: SSAValue) -> AnyFloatAttr | AnyIntegerAttr | None:
"""Returns the value of a scalar arith.constant, or None if not a constant or not scalar)."""
if (
isinstance(op, OpResult)
and isinstance(op.op, arith.Constant)
and isa(val := op.op.value, DenseIntOrFPElementsAttr)
and val.data.data.count(val.data.data[0]) == len(val.data.data)
):
return val.data.data[0]


class ConvertLinalgAddPass(ConvertBinaryLinalgOp):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: linalg.AddOp, rewriter: PatternRewriter, /):
self.transform_op(op, rewriter, f16=csl.FaddhOp, f32=csl.FaddsOp)


class ConvertLinalgSubPass(ConvertBinaryLinalgOp):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: linalg.SubOp, rewriter: PatternRewriter, /):
self.transform_op(op, rewriter, f16=csl.FsubhOp, f32=csl.FsubsOp)


class ConvertLinalgMulPass(ConvertBinaryLinalgOp):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: linalg.MulOp, rewriter: PatternRewriter, /):
self.transform_op(op, rewriter, f16=csl.FmulhOp, f32=csl.FmulsOp)


@dataclass(frozen=True)
class LinalgToCsl(ModulePass):
"""
Convert linalg ops to csl ops.
The linalg ops are required to be in 'memref mode', i.e., after bufferization has been applied.
"""

name = "linalg-to-csl"

def apply(self, ctx: MLContext, op: ModuleOp) -> None:
module_pass = PatternRewriteWalker(
GreedyRewritePatternApplier(
[
ConvertLinalgAddPass(),
ConvertLinalgSubPass(),
ConvertLinalgMulPass(),
]
),
)
module_pass.rewrite_module(op)

0 comments on commit e7c69de

Please sign in to comment.