Skip to content

Commit 81a1f9c

Browse files
author
Lubomir Litchev
authored
Merge pull request #27 from plaidml/llitchev-vector-cast-fixes
Add support for vector casts for fptosi and fptoui
2 parents a38882d + 8e46cd4 commit 81a1f9c

File tree

2 files changed

+100
-15
lines changed

2 files changed

+100
-15
lines changed

mlir/lib/Dialect/StandardOps/IR/Ops.cpp

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
#include "llvm/ADT/StringSwitch.h"
2626
#include "llvm/Support/FormatVariadic.h"
2727
#include "llvm/Support/raw_ostream.h"
28-
2928
// Pull in all enum type definitions and utility function declarations.
3029
#include "mlir/Dialect/StandardOps/IR/OpsEnums.cpp.inc"
3130

@@ -217,6 +216,21 @@ static LogicalResult foldMemRefCast(Operation *op) {
217216
return success(folded);
218217
}
219218

219+
// //===----------------------------------------------------------------------===//
220+
// // Common cast compatibility check for vector types
221+
// //===----------------------------------------------------------------------===//
222+
223+
// // Type compatibility for vector casts.
224+
static bool areVectorCastSimpleCompatible(
225+
Type a, Type b, function_ref<bool(Type, Type)> areElementsCastCompatible) {
226+
if (auto va = a.dyn_cast<VectorType>())
227+
if (auto vb = b.dyn_cast<VectorType>())
228+
return va.getShape().equals(vb.getShape()) &&
229+
areElementsCastCompatible(va.getElementType(),
230+
vb.getElementType());
231+
return false;
232+
}
233+
220234
//===----------------------------------------------------------------------===//
221235
// AddFOp
222236
//===----------------------------------------------------------------------===//
@@ -1764,27 +1778,27 @@ bool FPExtOp::areCastCompatible(Type a, Type b) {
17641778
if (auto fa = a.dyn_cast<FloatType>())
17651779
if (auto fb = b.dyn_cast<FloatType>())
17661780
return fa.getWidth() < fb.getWidth();
1767-
if (auto va = a.dyn_cast<VectorType>())
1768-
if (auto vb = b.dyn_cast<VectorType>())
1769-
return va.getShape().equals(vb.getShape()) &&
1770-
areCastCompatible(va.getElementType(), vb.getElementType());
1771-
return false;
1781+
return areVectorCastSimpleCompatible(a, b, areCastCompatible);
17721782
}
17731783

17741784
//===----------------------------------------------------------------------===//
17751785
// FPToSIOp
17761786
//===----------------------------------------------------------------------===//
17771787

17781788
bool FPToSIOp::areCastCompatible(Type a, Type b) {
1779-
return a.isa<FloatType>() && b.isSignlessInteger();
1789+
if (a.isa<FloatType>() && b.isSignlessInteger())
1790+
return true;
1791+
return areVectorCastSimpleCompatible(a, b, areCastCompatible);
17801792
}
17811793

17821794
//===----------------------------------------------------------------------===//
17831795
// FPToUIOp
17841796
//===----------------------------------------------------------------------===//
17851797

17861798
bool FPToUIOp::areCastCompatible(Type a, Type b) {
1787-
return a.isa<FloatType>() && b.isSignlessInteger();
1799+
if (a.isa<FloatType>() && b.isSignlessInteger())
1800+
return true;
1801+
return areVectorCastSimpleCompatible(a, b, areCastCompatible);
17881802
}
17891803

17901804
//===----------------------------------------------------------------------===//
@@ -1795,11 +1809,7 @@ bool FPTruncOp::areCastCompatible(Type a, Type b) {
17951809
if (auto fa = a.dyn_cast<FloatType>())
17961810
if (auto fb = b.dyn_cast<FloatType>())
17971811
return fa.getWidth() > fb.getWidth();
1798-
if (auto va = a.dyn_cast<VectorType>())
1799-
if (auto vb = b.dyn_cast<VectorType>())
1800-
return va.getShape().equals(vb.getShape()) &&
1801-
areCastCompatible(va.getElementType(), vb.getElementType());
1802-
return false;
1812+
return areVectorCastSimpleCompatible(a, b, areCastCompatible);
18031813
}
18041814

18051815
//===----------------------------------------------------------------------===//
@@ -2239,7 +2249,9 @@ OpFoldResult SignedRemIOp::fold(ArrayRef<Attribute> operands) {
22392249

22402250
// sitofp is applicable from integer types to float types.
22412251
bool SIToFPOp::areCastCompatible(Type a, Type b) {
2242-
return a.isSignlessInteger() && b.isa<FloatType>();
2252+
if (a.isSignlessInteger() && b.isa<FloatType>())
2253+
return true;
2254+
return areVectorCastSimpleCompatible(a, b, areCastCompatible);
22432255
}
22442256

22452257
//===----------------------------------------------------------------------===//
@@ -2319,7 +2331,9 @@ OpFoldResult SubIOp::fold(ArrayRef<Attribute> operands) {
23192331

23202332
// uitofp is applicable from integer types to float types.
23212333
bool UIToFPOp::areCastCompatible(Type a, Type b) {
2322-
return a.isSignlessInteger() && b.isa<FloatType>();
2334+
if (a.isSignlessInteger() && b.isa<FloatType>())
2335+
return true;
2336+
return areVectorCastSimpleCompatible(a, b, areCastCompatible);
23232337
}
23242338

23252339
//===----------------------------------------------------------------------===//

mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -594,6 +594,24 @@ func @sitofp(%arg0 : i32, %arg1 : i64) {
594594
return
595595
}
596596

597+
// Checking conversion of integer vectors to floating point vector types.
598+
// CHECK-LABEL: @sitofp_vector
599+
func @sitofp_vector(%arg0 : vector<2xi16>, %arg1 : vector<2xi32>, %arg2 : vector<2xi64>) {
600+
// CHECK-NEXT: = llvm.sitofp {{.*}} : !llvm.vec<2 x i16> to !llvm.vec<2 x float>
601+
%0 = sitofp %arg0: vector<2xi16> to vector<2xf32>
602+
// CHECK-NEXT: = llvm.sitofp {{.*}} : !llvm.vec<2 x i16> to !llvm.vec<2 x double>
603+
%1 = sitofp %arg0: vector<2xi16> to vector<2xf64>
604+
// CHECK-NEXT: = llvm.sitofp {{.*}} : !llvm.vec<2 x i32> to !llvm.vec<2 x float>
605+
%2 = sitofp %arg1: vector<2xi32> to vector<2xf32>
606+
// CHECK-NEXT: = llvm.sitofp {{.*}} : !llvm.vec<2 x i32> to !llvm.vec<2 x double>
607+
%3 = sitofp %arg1: vector<2xi32> to vector<2xf64>
608+
// CHECK-NEXT: = llvm.sitofp {{.*}} : !llvm.vec<2 x i64> to !llvm.vec<2 x float>
609+
%4 = sitofp %arg2: vector<2xi64> to vector<2xf32>
610+
// CHECK-NEXT: = llvm.sitofp {{.*}} : !llvm.vec<2 x i64> to !llvm.vec<2 x double>
611+
%5 = sitofp %arg2: vector<2xi64> to vector<2xf64>
612+
return
613+
}
614+
597615
// Checking conversion of unsigned integer types to floating point.
598616
// CHECK-LABEL: @uitofp
599617
func @uitofp(%arg0 : i32, %arg1 : i64) {
@@ -646,6 +664,24 @@ func @fptosi(%arg0 : f32, %arg1 : f64) {
646664
return
647665
}
648666

667+
// Checking conversion of floating point vectors to integer vector types.
668+
// CHECK-LABEL: @fptosi_vector
669+
func @fptosi_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>, %arg2 : vector<2xf64>) {
670+
// CHECK-NEXT: = llvm.fptosi {{.*}} : !llvm.vec<2 x half> to !llvm.vec<2 x i32>
671+
%0 = fptosi %arg0: vector<2xf16> to vector<2xi32>
672+
// CHECK-NEXT: = llvm.fptosi {{.*}} : !llvm.vec<2 x half> to !llvm.vec<2 x i64>
673+
%1 = fptosi %arg0: vector<2xf16> to vector<2xi64>
674+
// CHECK-NEXT: = llvm.fptosi {{.*}} : !llvm.vec<2 x float> to !llvm.vec<2 x i32>
675+
%2 = fptosi %arg1: vector<2xf32> to vector<2xi32>
676+
// CHECK-NEXT: = llvm.fptosi {{.*}} : !llvm.vec<2 x float> to !llvm.vec<2 x i64>
677+
%3 = fptosi %arg1: vector<2xf32> to vector<2xi64>
678+
// CHECK-NEXT: = llvm.fptosi {{.*}} : !llvm.vec<2 x double> to !llvm.vec<2 x i32>
679+
%4 = fptosi %arg2: vector<2xf64> to vector<2xi32>
680+
// CHECK-NEXT: = llvm.fptosi {{.*}} : !llvm.vec<2 x double> to !llvm.vec<2 x i64>
681+
%5 = fptosi %arg2: vector<2xf64> to vector<2xi64>
682+
return
683+
}
684+
649685
// Checking conversion of floating point to integer types.
650686
// CHECK-LABEL: @fptoui
651687
func @fptoui(%arg0 : f32, %arg1 : f64) {
@@ -660,6 +696,41 @@ func @fptoui(%arg0 : f32, %arg1 : f64) {
660696
return
661697
}
662698

699+
// Checking conversion of floating point vectors to integer vector types.
700+
// CHECK-LABEL: @fptoui_vector
701+
func @fptoui_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>, %arg2 : vector<2xf64>) {
702+
// CHECK-NEXT: = llvm.fptoui {{.*}} : !llvm.vec<2 x half> to !llvm.vec<2 x i32>
703+
%0 = fptoui %arg0: vector<2xf16> to vector<2xi32>
704+
// CHECK-NEXT: = llvm.fptoui {{.*}} : !llvm.vec<2 x half> to !llvm.vec<2 x i64>
705+
%1 = fptoui %arg0: vector<2xf16> to vector<2xi64>
706+
// CHECK-NEXT: = llvm.fptoui {{.*}} : !llvm.vec<2 x float> to !llvm.vec<2 x i32>
707+
%2 = fptoui %arg1: vector<2xf32> to vector<2xi32>
708+
// CHECK-NEXT: = llvm.fptoui {{.*}} : !llvm.vec<2 x float> to !llvm.vec<2 x i64>
709+
%3 = fptoui %arg1: vector<2xf32> to vector<2xi64>
710+
// CHECK-NEXT: = llvm.fptoui {{.*}} : !llvm.vec<2 x double> to !llvm.vec<2 x i32>
711+
%4 = fptoui %arg2: vector<2xf64> to vector<2xi32>
712+
// CHECK-NEXT: = llvm.fptoui {{.*}} : !llvm.vec<2 x double> to !llvm.vec<2 x i64>
713+
%5 = fptoui %arg2: vector<2xf64> to vector<2xi64>
714+
return
715+
}
716+
717+
// Checking conversion of integer vectors to floating point vector types.
718+
// CHECK-LABEL: @uitofp_vector
719+
func @uitofp_vector(%arg0 : vector<2xi16>, %arg1 : vector<2xi32>, %arg2 : vector<2xi64>) {
720+
// CHECK-NEXT: = llvm.uitofp {{.*}} : !llvm.vec<2 x i16> to !llvm.vec<2 x float>
721+
%0 = uitofp %arg0: vector<2xi16> to vector<2xf32>
722+
// CHECK-NEXT: = llvm.uitofp {{.*}} : !llvm.vec<2 x i16> to !llvm.vec<2 x double>
723+
%1 = uitofp %arg0: vector<2xi16> to vector<2xf64>
724+
// CHECK-NEXT: = llvm.uitofp {{.*}} : !llvm.vec<2 x i32> to !llvm.vec<2 x float>
725+
%2 = uitofp %arg1: vector<2xi32> to vector<2xf32>
726+
// CHECK-NEXT: = llvm.uitofp {{.*}} : !llvm.vec<2 x i32> to !llvm.vec<2 x double>
727+
%3 = uitofp %arg1: vector<2xi32> to vector<2xf64>
728+
// CHECK-NEXT: = llvm.uitofp {{.*}} : !llvm.vec<2 x i64> to !llvm.vec<2 x float>
729+
%4 = uitofp %arg2: vector<2xi64> to vector<2xf32>
730+
// CHECK-NEXT: = llvm.uitofp {{.*}} : !llvm.vec<2 x i64> to !llvm.vec<2 x double>
731+
%5 = uitofp %arg2: vector<2xi64> to vector<2xf64>
732+
return
733+
}
663734

664735
// Checking conversion of integer types to floating point.
665736
// CHECK-LABEL: @fptrunc

0 commit comments

Comments
 (0)