Skip to content

Commit

Permalink
Define CMath dialect (#127)
Browse files Browse the repository at this point in the history
  • Loading branch information
purvi-H authored Jun 30, 2022
1 parent 88a8241 commit 8cb869a
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 0 deletions.
8 changes: 8 additions & 0 deletions src/xdsl/dialects/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def __post_init__(self):

self.ctx.register_attr(FunctionType)
self.ctx.register_attr(Float32Type)
self.ctx.register_attr(Float64Type)
self.ctx.register_attr(IntegerType)
self.ctx.register_attr(IndexType)

Expand Down Expand Up @@ -382,6 +383,13 @@ class Float32Type(ParametrizedAttribute):
f32 = Float32Type()


class Float64Type(ParametrizedAttribute):
name = "f64"


f64 = Float64Type()


@irdl_attr_definition
class UnitAttr(ParametrizedAttribute):
name: str = "unit"
Expand Down
56 changes: 56 additions & 0 deletions src/xdsl/dialects/cmath.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from __future__ import annotations
from dataclasses import dataclass

from xdsl.dialects.builtin import *
from xdsl.irdl import *
from xdsl.ir import *


@dataclass
class CMath:
ctx: MLContext

def __post_init__(self):
self.ctx.register_attr(ComplexType)

self.ctx.register_op(Norm)
self.ctx.register_op(Mul)


@irdl_attr_definition
class ComplexType(ParametrizedAttribute):
name = "cmath.complex"
data: ParameterDef[Float64Type | Float32Type]


@irdl_op_definition
class Norm(Operation):
name: str = "cmath.norm"

op = OperandDef(
ParamAttrConstraint(ComplexType, [AnyOf([Float32Type, Float64Type])]))
res = ResultDef(AnyOf([Float32Type, Float64Type]))

# TODO replace with trait
def verify_(self) -> None:
if self.op.typ.data != self.res.typ:
raise VerifyException(
"expect all input and output types to be equal")


@irdl_op_definition
class Mul(Operation):
name: str = "cmath.mul"

lhs = OperandDef(
ParamAttrConstraint(ComplexType, [AnyOf([Float32Type, Float64Type])]))
rhs = OperandDef(
ParamAttrConstraint(ComplexType, [AnyOf([Float32Type, Float64Type])]))
res = ResultDef(
ParamAttrConstraint(ComplexType, [AnyOf([Float32Type, Float64Type])]))

# TODO replace with trait
def verify_(self) -> None:
if self.lhs != self.rhs.typ and self.rhs.typ != self.res.typ:
raise VerifyException(
"expect all input and output types to be equal")
2 changes: 2 additions & 0 deletions src/xdsl/xdsl_opt_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from xdsl.dialects.affine import *
from xdsl.dialects.memref import *
from xdsl.dialects.builtin import *
from xdsl.dialects.cmath import *
from xdsl.dialects.cf import *


Expand Down Expand Up @@ -147,6 +148,7 @@ def register_all_dialects(self):
affine = Affine(self.ctx)
scf = Scf(self.ctx)
cf = Cf(self.ctx)
cmath = CMath(self.ctx)

def register_all_frontends(self):
"""
Expand Down
34 changes: 34 additions & 0 deletions tests/filecheck/cmath_ops.xdsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// RUN: xdsl-opt %s | xdsl-opt | filecheck %s

module() {

func.func() ["sym_name" = "conorm", "function_type" = !fun<[!cmath.complex<!f32>, !cmath.complex<!f32>], [!f32]>, "sym_visibility" = "private"] {
^0(%p: !cmath.complex<!f32>, %q: !cmath.complex<!f32>):
%norm_p : !f32 = cmath.norm(%p : !cmath.complex<!f32>)
%norm_q : !f32 = cmath.norm(%q : !cmath.complex<!f32>)
%pq : !f32 = arith.mulf(%norm_p : !f32, %norm_q : !f32)
func.return(%pq : !f32)
}

// CHECK: func.func() ["sym_name" = "conorm",
// CHECK-NEXT: ^{{.*}}({{.*}}: !cmath.complex<!f32>, {{.*}}: !cmath.complex<!f32>):
// CHECK-NEXT: %{{.*}} : !f32 = cmath.norm(%{{.*}} : !cmath.complex<!f32>)
// CHECK-NEXT: %{{.*}} : !f32 = cmath.norm(%{{.*}} : !cmath.complex<!f32>)
// CHECK-NEXT: %{{.*}} : !f32 = arith.mulf(%{{.*}} : !f32, %{{.*}} : !f32)
// CHECK-NEXT: func.return(%{{.*}} : !f32)
// CHECK-NEXT: }

func.func() ["sym_name" = "conorm2", "function_type" = !fun<[!cmath.complex<!f32>], [!f32]>, "sym_visibility" = "private"] {
^1(%a: !cmath.complex<!f32>, %b: !cmath.complex<!f32>):
%ab : !cmath.complex<!f32> = cmath.mul(%a : !cmath.complex<!f32>, %b : !cmath.complex<!f32>)
%conorm : !f32 = cmath.norm(%ab : !cmath.complex<!f32>)
func.return(%conorm : !f32)
}
// CHECK: func.func() ["sym_name" = "conorm2",
// CHECK-NEXT: ^{{.*}}(%{{.*}}: !cmath.complex<!f32>, %{{.*}}: !cmath.complex<!f32>):
// CHECK-NEXT: %{{.*}} : !cmath.complex<!f32> = cmath.mul(%{{.*}} : !cmath.complex<!f32>, %{{.*}} : !cmath.complex<!f32>)
// CHECK-NEXT: %{{.*}} : !f32 = cmath.norm(%{{.*}} : !cmath.complex<!f32>)
// CHECK-NEXT: func.return(%{{.*}} : !f32)
// CHECK-NEXT: }

}

0 comments on commit 8cb869a

Please sign in to comment.