diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp index 9a54430482160..59237e91a103f 100644 --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -25,7 +25,6 @@ #include "llvm/ADT/StringSwitch.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/raw_ostream.h" - // Pull in all enum type definitions and utility function declarations. #include "mlir/Dialect/StandardOps/IR/OpsEnums.cpp.inc" @@ -217,6 +216,21 @@ static LogicalResult foldMemRefCast(Operation *op) { return success(folded); } +// //===----------------------------------------------------------------------===// +// // Common cast compatibility check for vector types +// //===----------------------------------------------------------------------===// + +// // Type compatibility for vector casts. +static bool areVectorCastSimpleCompatible( + Type a, Type b, function_ref areElementsCastCompatible) { + if (auto va = a.dyn_cast()) + if (auto vb = b.dyn_cast()) + return va.getShape().equals(vb.getShape()) && + areElementsCastCompatible(va.getElementType(), + vb.getElementType()); + return false; +} + //===----------------------------------------------------------------------===// // AddFOp //===----------------------------------------------------------------------===// @@ -1764,11 +1778,7 @@ bool FPExtOp::areCastCompatible(Type a, Type b) { if (auto fa = a.dyn_cast()) if (auto fb = b.dyn_cast()) return fa.getWidth() < fb.getWidth(); - if (auto va = a.dyn_cast()) - if (auto vb = b.dyn_cast()) - return va.getShape().equals(vb.getShape()) && - areCastCompatible(va.getElementType(), vb.getElementType()); - return false; + return areVectorCastSimpleCompatible(a, b, areCastCompatible); } //===----------------------------------------------------------------------===// @@ -1776,7 +1786,9 @@ bool FPExtOp::areCastCompatible(Type a, Type b) { //===----------------------------------------------------------------------===// bool FPToSIOp::areCastCompatible(Type a, Type b) { - return a.isa() && b.isSignlessInteger(); + if (a.isa() && b.isSignlessInteger()) + return true; + return areVectorCastSimpleCompatible(a, b, areCastCompatible); } //===----------------------------------------------------------------------===// @@ -1784,7 +1796,9 @@ bool FPToSIOp::areCastCompatible(Type a, Type b) { //===----------------------------------------------------------------------===// bool FPToUIOp::areCastCompatible(Type a, Type b) { - return a.isa() && b.isSignlessInteger(); + if (a.isa() && b.isSignlessInteger()) + return true; + return areVectorCastSimpleCompatible(a, b, areCastCompatible); } //===----------------------------------------------------------------------===// @@ -1795,11 +1809,7 @@ bool FPTruncOp::areCastCompatible(Type a, Type b) { if (auto fa = a.dyn_cast()) if (auto fb = b.dyn_cast()) return fa.getWidth() > fb.getWidth(); - if (auto va = a.dyn_cast()) - if (auto vb = b.dyn_cast()) - return va.getShape().equals(vb.getShape()) && - areCastCompatible(va.getElementType(), vb.getElementType()); - return false; + return areVectorCastSimpleCompatible(a, b, areCastCompatible); } //===----------------------------------------------------------------------===// @@ -2239,7 +2249,9 @@ OpFoldResult SignedRemIOp::fold(ArrayRef operands) { // sitofp is applicable from integer types to float types. bool SIToFPOp::areCastCompatible(Type a, Type b) { - return a.isSignlessInteger() && b.isa(); + if (a.isSignlessInteger() && b.isa()) + return true; + return areVectorCastSimpleCompatible(a, b, areCastCompatible); } //===----------------------------------------------------------------------===// @@ -2319,7 +2331,9 @@ OpFoldResult SubIOp::fold(ArrayRef operands) { // uitofp is applicable from integer types to float types. bool UIToFPOp::areCastCompatible(Type a, Type b) { - return a.isSignlessInteger() && b.isa(); + if (a.isSignlessInteger() && b.isa()) + return true; + return areVectorCastSimpleCompatible(a, b, areCastCompatible); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir index 62be4783e364b..bb0363b1cba52 100644 --- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir +++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir @@ -594,6 +594,24 @@ func @sitofp(%arg0 : i32, %arg1 : i64) { return } +// Checking conversion of integer vectors to floating point vector types. +// CHECK-LABEL: @sitofp_vector +func @sitofp_vector(%arg0 : vector<2xi16>, %arg1 : vector<2xi32>, %arg2 : vector<2xi64>) { +// CHECK-NEXT: = llvm.sitofp {{.*}} : !llvm.vec<2 x i16> to !llvm.vec<2 x float> + %0 = sitofp %arg0: vector<2xi16> to vector<2xf32> +// CHECK-NEXT: = llvm.sitofp {{.*}} : !llvm.vec<2 x i16> to !llvm.vec<2 x double> + %1 = sitofp %arg0: vector<2xi16> to vector<2xf64> +// CHECK-NEXT: = llvm.sitofp {{.*}} : !llvm.vec<2 x i32> to !llvm.vec<2 x float> + %2 = sitofp %arg1: vector<2xi32> to vector<2xf32> +// CHECK-NEXT: = llvm.sitofp {{.*}} : !llvm.vec<2 x i32> to !llvm.vec<2 x double> + %3 = sitofp %arg1: vector<2xi32> to vector<2xf64> +// CHECK-NEXT: = llvm.sitofp {{.*}} : !llvm.vec<2 x i64> to !llvm.vec<2 x float> + %4 = sitofp %arg2: vector<2xi64> to vector<2xf32> +// CHECK-NEXT: = llvm.sitofp {{.*}} : !llvm.vec<2 x i64> to !llvm.vec<2 x double> + %5 = sitofp %arg2: vector<2xi64> to vector<2xf64> + return +} + // Checking conversion of unsigned integer types to floating point. // CHECK-LABEL: @uitofp func @uitofp(%arg0 : i32, %arg1 : i64) { @@ -646,6 +664,24 @@ func @fptosi(%arg0 : f32, %arg1 : f64) { return } +// Checking conversion of floating point vectors to integer vector types. +// CHECK-LABEL: @fptosi_vector +func @fptosi_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>, %arg2 : vector<2xf64>) { +// CHECK-NEXT: = llvm.fptosi {{.*}} : !llvm.vec<2 x half> to !llvm.vec<2 x i32> + %0 = fptosi %arg0: vector<2xf16> to vector<2xi32> +// CHECK-NEXT: = llvm.fptosi {{.*}} : !llvm.vec<2 x half> to !llvm.vec<2 x i64> + %1 = fptosi %arg0: vector<2xf16> to vector<2xi64> +// CHECK-NEXT: = llvm.fptosi {{.*}} : !llvm.vec<2 x float> to !llvm.vec<2 x i32> + %2 = fptosi %arg1: vector<2xf32> to vector<2xi32> +// CHECK-NEXT: = llvm.fptosi {{.*}} : !llvm.vec<2 x float> to !llvm.vec<2 x i64> + %3 = fptosi %arg1: vector<2xf32> to vector<2xi64> +// CHECK-NEXT: = llvm.fptosi {{.*}} : !llvm.vec<2 x double> to !llvm.vec<2 x i32> + %4 = fptosi %arg2: vector<2xf64> to vector<2xi32> +// CHECK-NEXT: = llvm.fptosi {{.*}} : !llvm.vec<2 x double> to !llvm.vec<2 x i64> + %5 = fptosi %arg2: vector<2xf64> to vector<2xi64> + return +} + // Checking conversion of floating point to integer types. // CHECK-LABEL: @fptoui func @fptoui(%arg0 : f32, %arg1 : f64) { @@ -660,6 +696,41 @@ func @fptoui(%arg0 : f32, %arg1 : f64) { return } +// Checking conversion of floating point vectors to integer vector types. +// CHECK-LABEL: @fptoui_vector +func @fptoui_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>, %arg2 : vector<2xf64>) { +// CHECK-NEXT: = llvm.fptoui {{.*}} : !llvm.vec<2 x half> to !llvm.vec<2 x i32> + %0 = fptoui %arg0: vector<2xf16> to vector<2xi32> +// CHECK-NEXT: = llvm.fptoui {{.*}} : !llvm.vec<2 x half> to !llvm.vec<2 x i64> + %1 = fptoui %arg0: vector<2xf16> to vector<2xi64> +// CHECK-NEXT: = llvm.fptoui {{.*}} : !llvm.vec<2 x float> to !llvm.vec<2 x i32> + %2 = fptoui %arg1: vector<2xf32> to vector<2xi32> +// CHECK-NEXT: = llvm.fptoui {{.*}} : !llvm.vec<2 x float> to !llvm.vec<2 x i64> + %3 = fptoui %arg1: vector<2xf32> to vector<2xi64> +// CHECK-NEXT: = llvm.fptoui {{.*}} : !llvm.vec<2 x double> to !llvm.vec<2 x i32> + %4 = fptoui %arg2: vector<2xf64> to vector<2xi32> +// CHECK-NEXT: = llvm.fptoui {{.*}} : !llvm.vec<2 x double> to !llvm.vec<2 x i64> + %5 = fptoui %arg2: vector<2xf64> to vector<2xi64> + return +} + +// Checking conversion of integer vectors to floating point vector types. +// CHECK-LABEL: @uitofp_vector +func @uitofp_vector(%arg0 : vector<2xi16>, %arg1 : vector<2xi32>, %arg2 : vector<2xi64>) { +// CHECK-NEXT: = llvm.uitofp {{.*}} : !llvm.vec<2 x i16> to !llvm.vec<2 x float> + %0 = uitofp %arg0: vector<2xi16> to vector<2xf32> +// CHECK-NEXT: = llvm.uitofp {{.*}} : !llvm.vec<2 x i16> to !llvm.vec<2 x double> + %1 = uitofp %arg0: vector<2xi16> to vector<2xf64> +// CHECK-NEXT: = llvm.uitofp {{.*}} : !llvm.vec<2 x i32> to !llvm.vec<2 x float> + %2 = uitofp %arg1: vector<2xi32> to vector<2xf32> +// CHECK-NEXT: = llvm.uitofp {{.*}} : !llvm.vec<2 x i32> to !llvm.vec<2 x double> + %3 = uitofp %arg1: vector<2xi32> to vector<2xf64> +// CHECK-NEXT: = llvm.uitofp {{.*}} : !llvm.vec<2 x i64> to !llvm.vec<2 x float> + %4 = uitofp %arg2: vector<2xi64> to vector<2xf32> +// CHECK-NEXT: = llvm.uitofp {{.*}} : !llvm.vec<2 x i64> to !llvm.vec<2 x double> + %5 = uitofp %arg2: vector<2xi64> to vector<2xf64> + return +} // Checking conversion of integer types to floating point. // CHECK-LABEL: @fptrunc