Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 29 additions & 15 deletions mlir/lib/Dialect/StandardOps/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whitespace here too

// Pull in all enum type definitions and utility function declarations.
#include "mlir/Dialect/StandardOps/IR/OpsEnums.cpp.inc"

Expand Down Expand Up @@ -217,6 +216,21 @@ static LogicalResult foldMemRefCast(Operation *op) {
return success(folded);
}

// //===----------------------------------------------------------------------===//
// // Common cast compatibility check for vector types
// //===----------------------------------------------------------------------===//

// // Type compatibility for vector casts.
static bool areVectorCastSimpleCompatible(
Type a, Type b, function_ref<bool(Type, Type)> areElementsCastCompatible) {
if (auto va = a.dyn_cast<VectorType>())
if (auto vb = b.dyn_cast<VectorType>())
return va.getShape().equals(vb.getShape()) &&
areElementsCastCompatible(va.getElementType(),
vb.getElementType());
return false;
}

//===----------------------------------------------------------------------===//
// AddFOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1764,27 +1778,27 @@ bool FPExtOp::areCastCompatible(Type a, Type b) {
if (auto fa = a.dyn_cast<FloatType>())
if (auto fb = b.dyn_cast<FloatType>())
return fa.getWidth() < fb.getWidth();
if (auto va = a.dyn_cast<VectorType>())
if (auto vb = b.dyn_cast<VectorType>())
return va.getShape().equals(vb.getShape()) &&
areCastCompatible(va.getElementType(), vb.getElementType());
return false;
return areVectorCastSimpleCompatible(a, b, areCastCompatible);
}

//===----------------------------------------------------------------------===//
// FPToSIOp
//===----------------------------------------------------------------------===//

bool FPToSIOp::areCastCompatible(Type a, Type b) {
return a.isa<FloatType>() && b.isSignlessInteger();
if (a.isa<FloatType>() && b.isSignlessInteger())
return true;
return areVectorCastSimpleCompatible(a, b, areCastCompatible);
}

//===----------------------------------------------------------------------===//
// FPToUIOp
//===----------------------------------------------------------------------===//

bool FPToUIOp::areCastCompatible(Type a, Type b) {
return a.isa<FloatType>() && b.isSignlessInteger();
if (a.isa<FloatType>() && b.isSignlessInteger())
return true;
return areVectorCastSimpleCompatible(a, b, areCastCompatible);
}

//===----------------------------------------------------------------------===//
Expand All @@ -1795,11 +1809,7 @@ bool FPTruncOp::areCastCompatible(Type a, Type b) {
if (auto fa = a.dyn_cast<FloatType>())
if (auto fb = b.dyn_cast<FloatType>())
return fa.getWidth() > fb.getWidth();
if (auto va = a.dyn_cast<VectorType>())
if (auto vb = b.dyn_cast<VectorType>())
return va.getShape().equals(vb.getShape()) &&
areCastCompatible(va.getElementType(), vb.getElementType());
return false;
return areVectorCastSimpleCompatible(a, b, areCastCompatible);
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2239,7 +2249,9 @@ OpFoldResult SignedRemIOp::fold(ArrayRef<Attribute> operands) {

// sitofp is applicable from integer types to float types.
bool SIToFPOp::areCastCompatible(Type a, Type b) {
return a.isSignlessInteger() && b.isa<FloatType>();
if (a.isSignlessInteger() && b.isa<FloatType>())
return true;
return areVectorCastSimpleCompatible(a, b, areCastCompatible);
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2319,7 +2331,9 @@ OpFoldResult SubIOp::fold(ArrayRef<Attribute> operands) {

// uitofp is applicable from integer types to float types.
bool UIToFPOp::areCastCompatible(Type a, Type b) {
return a.isSignlessInteger() && b.isa<FloatType>();
if (a.isSignlessInteger() && b.isa<FloatType>())
return true;
return areVectorCastSimpleCompatible(a, b, areCastCompatible);
}

//===----------------------------------------------------------------------===//
Expand Down
71 changes: 71 additions & 0 deletions mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,24 @@ func @sitofp(%arg0 : i32, %arg1 : i64) {
return
}

// Checking conversion of integer vectors to floating point vector types.
// CHECK-LABEL: @sitofp_vector
func @sitofp_vector(%arg0 : vector<2xi16>, %arg1 : vector<2xi32>, %arg2 : vector<2xi64>) {
// CHECK-NEXT: = llvm.sitofp {{.*}} : !llvm.vec<2 x i16> to !llvm.vec<2 x float>
%0 = sitofp %arg0: vector<2xi16> to vector<2xf32>
// CHECK-NEXT: = llvm.sitofp {{.*}} : !llvm.vec<2 x i16> to !llvm.vec<2 x double>
%1 = sitofp %arg0: vector<2xi16> to vector<2xf64>
// CHECK-NEXT: = llvm.sitofp {{.*}} : !llvm.vec<2 x i32> to !llvm.vec<2 x float>
%2 = sitofp %arg1: vector<2xi32> to vector<2xf32>
// CHECK-NEXT: = llvm.sitofp {{.*}} : !llvm.vec<2 x i32> to !llvm.vec<2 x double>
%3 = sitofp %arg1: vector<2xi32> to vector<2xf64>
// CHECK-NEXT: = llvm.sitofp {{.*}} : !llvm.vec<2 x i64> to !llvm.vec<2 x float>
%4 = sitofp %arg2: vector<2xi64> to vector<2xf32>
// CHECK-NEXT: = llvm.sitofp {{.*}} : !llvm.vec<2 x i64> to !llvm.vec<2 x double>
%5 = sitofp %arg2: vector<2xi64> to vector<2xf64>
return
}

// Checking conversion of unsigned integer types to floating point.
// CHECK-LABEL: @uitofp
func @uitofp(%arg0 : i32, %arg1 : i64) {
Expand Down Expand Up @@ -646,6 +664,24 @@ func @fptosi(%arg0 : f32, %arg1 : f64) {
return
}

// Checking conversion of floating point vectors to integer vector types.
// CHECK-LABEL: @fptosi_vector
func @fptosi_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>, %arg2 : vector<2xf64>) {
// CHECK-NEXT: = llvm.fptosi {{.*}} : !llvm.vec<2 x half> to !llvm.vec<2 x i32>
%0 = fptosi %arg0: vector<2xf16> to vector<2xi32>
// CHECK-NEXT: = llvm.fptosi {{.*}} : !llvm.vec<2 x half> to !llvm.vec<2 x i64>
%1 = fptosi %arg0: vector<2xf16> to vector<2xi64>
// CHECK-NEXT: = llvm.fptosi {{.*}} : !llvm.vec<2 x float> to !llvm.vec<2 x i32>
%2 = fptosi %arg1: vector<2xf32> to vector<2xi32>
// CHECK-NEXT: = llvm.fptosi {{.*}} : !llvm.vec<2 x float> to !llvm.vec<2 x i64>
%3 = fptosi %arg1: vector<2xf32> to vector<2xi64>
// CHECK-NEXT: = llvm.fptosi {{.*}} : !llvm.vec<2 x double> to !llvm.vec<2 x i32>
%4 = fptosi %arg2: vector<2xf64> to vector<2xi32>
// CHECK-NEXT: = llvm.fptosi {{.*}} : !llvm.vec<2 x double> to !llvm.vec<2 x i64>
%5 = fptosi %arg2: vector<2xf64> to vector<2xi64>
return
}

// Checking conversion of floating point to integer types.
// CHECK-LABEL: @fptoui
func @fptoui(%arg0 : f32, %arg1 : f64) {
Expand All @@ -660,6 +696,41 @@ func @fptoui(%arg0 : f32, %arg1 : f64) {
return
}

// Checking conversion of floating point vectors to integer vector types.
// CHECK-LABEL: @fptoui_vector
func @fptoui_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>, %arg2 : vector<2xf64>) {
// CHECK-NEXT: = llvm.fptoui {{.*}} : !llvm.vec<2 x half> to !llvm.vec<2 x i32>
%0 = fptoui %arg0: vector<2xf16> to vector<2xi32>
// CHECK-NEXT: = llvm.fptoui {{.*}} : !llvm.vec<2 x half> to !llvm.vec<2 x i64>
%1 = fptoui %arg0: vector<2xf16> to vector<2xi64>
// CHECK-NEXT: = llvm.fptoui {{.*}} : !llvm.vec<2 x float> to !llvm.vec<2 x i32>
%2 = fptoui %arg1: vector<2xf32> to vector<2xi32>
// CHECK-NEXT: = llvm.fptoui {{.*}} : !llvm.vec<2 x float> to !llvm.vec<2 x i64>
%3 = fptoui %arg1: vector<2xf32> to vector<2xi64>
// CHECK-NEXT: = llvm.fptoui {{.*}} : !llvm.vec<2 x double> to !llvm.vec<2 x i32>
%4 = fptoui %arg2: vector<2xf64> to vector<2xi32>
// CHECK-NEXT: = llvm.fptoui {{.*}} : !llvm.vec<2 x double> to !llvm.vec<2 x i64>
%5 = fptoui %arg2: vector<2xf64> to vector<2xi64>
return
}

// Checking conversion of integer vectors to floating point vector types.
// CHECK-LABEL: @uitofp_vector
func @uitofp_vector(%arg0 : vector<2xi16>, %arg1 : vector<2xi32>, %arg2 : vector<2xi64>) {
// CHECK-NEXT: = llvm.uitofp {{.*}} : !llvm.vec<2 x i16> to !llvm.vec<2 x float>
%0 = uitofp %arg0: vector<2xi16> to vector<2xf32>
// CHECK-NEXT: = llvm.uitofp {{.*}} : !llvm.vec<2 x i16> to !llvm.vec<2 x double>
%1 = uitofp %arg0: vector<2xi16> to vector<2xf64>
// CHECK-NEXT: = llvm.uitofp {{.*}} : !llvm.vec<2 x i32> to !llvm.vec<2 x float>
%2 = uitofp %arg1: vector<2xi32> to vector<2xf32>
// CHECK-NEXT: = llvm.uitofp {{.*}} : !llvm.vec<2 x i32> to !llvm.vec<2 x double>
%3 = uitofp %arg1: vector<2xi32> to vector<2xf64>
// CHECK-NEXT: = llvm.uitofp {{.*}} : !llvm.vec<2 x i64> to !llvm.vec<2 x float>
%4 = uitofp %arg2: vector<2xi64> to vector<2xf32>
// CHECK-NEXT: = llvm.uitofp {{.*}} : !llvm.vec<2 x i64> to !llvm.vec<2 x double>
%5 = uitofp %arg2: vector<2xi64> to vector<2xf64>
return
}

// Checking conversion of integer types to floating point.
// CHECK-LABEL: @fptrunc
Expand Down