diff --git a/src/xdsl/dialects/builtin.py b/src/xdsl/dialects/builtin.py index 6951135e16..e9081a57f5 100644 --- a/src/xdsl/dialects/builtin.py +++ b/src/xdsl/dialects/builtin.py @@ -29,6 +29,7 @@ def __post_init__(self): self.ctx.register_attr(FunctionType) self.ctx.register_attr(Float32Type) + self.ctx.register_attr(Float64Type) self.ctx.register_attr(IntegerType) self.ctx.register_attr(IndexType) @@ -382,6 +383,13 @@ class Float32Type(ParametrizedAttribute): f32 = Float32Type() +class Float64Type(ParametrizedAttribute): + name = "f64" + + +f64 = Float64Type() + + @irdl_attr_definition class UnitAttr(ParametrizedAttribute): name: str = "unit" diff --git a/src/xdsl/dialects/cmath.py b/src/xdsl/dialects/cmath.py new file mode 100644 index 0000000000..51686f7831 --- /dev/null +++ b/src/xdsl/dialects/cmath.py @@ -0,0 +1,56 @@ +from __future__ import annotations +from dataclasses import dataclass + +from xdsl.dialects.builtin import * +from xdsl.irdl import * +from xdsl.ir import * + + +@dataclass +class CMath: + ctx: MLContext + + def __post_init__(self): + self.ctx.register_attr(ComplexType) + + self.ctx.register_op(Norm) + self.ctx.register_op(Mul) + + +@irdl_attr_definition +class ComplexType(ParametrizedAttribute): + name = "cmath.complex" + data: ParameterDef[Float64Type | Float32Type] + + +@irdl_op_definition +class Norm(Operation): + name: str = "cmath.norm" + + op = OperandDef( + ParamAttrConstraint(ComplexType, [AnyOf([Float32Type, Float64Type])])) + res = ResultDef(AnyOf([Float32Type, Float64Type])) + + # TODO replace with trait + def verify_(self) -> None: + if self.op.typ.data != self.res.typ: + raise VerifyException( + "expect all input and output types to be equal") + + +@irdl_op_definition +class Mul(Operation): + name: str = "cmath.mul" + + lhs = OperandDef( + ParamAttrConstraint(ComplexType, [AnyOf([Float32Type, Float64Type])])) + rhs = OperandDef( + ParamAttrConstraint(ComplexType, [AnyOf([Float32Type, Float64Type])])) + res = ResultDef( + ParamAttrConstraint(ComplexType, [AnyOf([Float32Type, Float64Type])])) + + # TODO replace with trait + def verify_(self) -> None: + if self.lhs != self.rhs.typ and self.rhs.typ != self.res.typ: + raise VerifyException( + "expect all input and output types to be equal") diff --git a/src/xdsl/xdsl_opt_main.py b/src/xdsl/xdsl_opt_main.py index 41e325c3e2..2cd1a6597c 100644 --- a/src/xdsl/xdsl_opt_main.py +++ b/src/xdsl/xdsl_opt_main.py @@ -11,6 +11,7 @@ from xdsl.dialects.affine import * from xdsl.dialects.memref import * from xdsl.dialects.builtin import * +from xdsl.dialects.cmath import * from xdsl.dialects.cf import * @@ -147,6 +148,7 @@ def register_all_dialects(self): affine = Affine(self.ctx) scf = Scf(self.ctx) cf = Cf(self.ctx) + cmath = CMath(self.ctx) def register_all_frontends(self): """ diff --git a/tests/filecheck/cmath_ops.xdsl b/tests/filecheck/cmath_ops.xdsl new file mode 100644 index 0000000000..84343996da --- /dev/null +++ b/tests/filecheck/cmath_ops.xdsl @@ -0,0 +1,34 @@ +// RUN: xdsl-opt %s | xdsl-opt | filecheck %s + +module() { + + func.func() ["sym_name" = "conorm", "function_type" = !fun<[!cmath.complex, !cmath.complex], [!f32]>, "sym_visibility" = "private"] { + ^0(%p: !cmath.complex, %q: !cmath.complex): + %norm_p : !f32 = cmath.norm(%p : !cmath.complex) + %norm_q : !f32 = cmath.norm(%q : !cmath.complex) + %pq : !f32 = arith.mulf(%norm_p : !f32, %norm_q : !f32) + func.return(%pq : !f32) + } + + // CHECK: func.func() ["sym_name" = "conorm", + // CHECK-NEXT: ^{{.*}}({{.*}}: !cmath.complex, {{.*}}: !cmath.complex): + // CHECK-NEXT: %{{.*}} : !f32 = cmath.norm(%{{.*}} : !cmath.complex) + // CHECK-NEXT: %{{.*}} : !f32 = cmath.norm(%{{.*}} : !cmath.complex) + // CHECK-NEXT: %{{.*}} : !f32 = arith.mulf(%{{.*}} : !f32, %{{.*}} : !f32) + // CHECK-NEXT: func.return(%{{.*}} : !f32) + // CHECK-NEXT: } + + func.func() ["sym_name" = "conorm2", "function_type" = !fun<[!cmath.complex], [!f32]>, "sym_visibility" = "private"] { + ^1(%a: !cmath.complex, %b: !cmath.complex): + %ab : !cmath.complex = cmath.mul(%a : !cmath.complex, %b : !cmath.complex) + %conorm : !f32 = cmath.norm(%ab : !cmath.complex) + func.return(%conorm : !f32) + } + // CHECK: func.func() ["sym_name" = "conorm2", + // CHECK-NEXT: ^{{.*}}(%{{.*}}: !cmath.complex, %{{.*}}: !cmath.complex): + // CHECK-NEXT: %{{.*}} : !cmath.complex = cmath.mul(%{{.*}} : !cmath.complex, %{{.*}} : !cmath.complex) + // CHECK-NEXT: %{{.*}} : !f32 = cmath.norm(%{{.*}} : !cmath.complex) + // CHECK-NEXT: func.return(%{{.*}} : !f32) + // CHECK-NEXT: } + +}