2525#include " llvm/ADT/StringSwitch.h"
2626#include " llvm/Support/FormatVariadic.h"
2727#include " llvm/Support/raw_ostream.h"
28-
2928// Pull in all enum type definitions and utility function declarations.
3029#include " mlir/Dialect/StandardOps/IR/OpsEnums.cpp.inc"
3130
@@ -217,6 +216,21 @@ static LogicalResult foldMemRefCast(Operation *op) {
217216 return success (folded);
218217}
219218
219+ // //===----------------------------------------------------------------------===//
220+ // // Common cast compatibility check for vector types
221+ // //===----------------------------------------------------------------------===//
222+
223+ // // Type compatibility for vector casts.
224+ static bool areVectorCastSimpleCompatible (
225+ Type a, Type b, function_ref<bool (Type, Type)> areElementsCastCompatible) {
226+ if (auto va = a.dyn_cast <VectorType>())
227+ if (auto vb = b.dyn_cast <VectorType>())
228+ return va.getShape ().equals (vb.getShape ()) &&
229+ areElementsCastCompatible (va.getElementType (),
230+ vb.getElementType ());
231+ return false ;
232+ }
233+
220234// ===----------------------------------------------------------------------===//
221235// AddFOp
222236// ===----------------------------------------------------------------------===//
@@ -1764,27 +1778,27 @@ bool FPExtOp::areCastCompatible(Type a, Type b) {
17641778 if (auto fa = a.dyn_cast <FloatType>())
17651779 if (auto fb = b.dyn_cast <FloatType>())
17661780 return fa.getWidth () < fb.getWidth ();
1767- if (auto va = a.dyn_cast <VectorType>())
1768- if (auto vb = b.dyn_cast <VectorType>())
1769- return va.getShape ().equals (vb.getShape ()) &&
1770- areCastCompatible (va.getElementType (), vb.getElementType ());
1771- return false ;
1781+ return areVectorCastSimpleCompatible (a, b, areCastCompatible);
17721782}
17731783
17741784// ===----------------------------------------------------------------------===//
17751785// FPToSIOp
17761786// ===----------------------------------------------------------------------===//
17771787
17781788bool FPToSIOp::areCastCompatible (Type a, Type b) {
1779- return a.isa <FloatType>() && b.isSignlessInteger ();
1789+ if (a.isa <FloatType>() && b.isSignlessInteger ())
1790+ return true ;
1791+ return areVectorCastSimpleCompatible (a, b, areCastCompatible);
17801792}
17811793
17821794// ===----------------------------------------------------------------------===//
17831795// FPToUIOp
17841796// ===----------------------------------------------------------------------===//
17851797
17861798bool FPToUIOp::areCastCompatible (Type a, Type b) {
1787- return a.isa <FloatType>() && b.isSignlessInteger ();
1799+ if (a.isa <FloatType>() && b.isSignlessInteger ())
1800+ return true ;
1801+ return areVectorCastSimpleCompatible (a, b, areCastCompatible);
17881802}
17891803
17901804// ===----------------------------------------------------------------------===//
@@ -1795,11 +1809,7 @@ bool FPTruncOp::areCastCompatible(Type a, Type b) {
17951809 if (auto fa = a.dyn_cast <FloatType>())
17961810 if (auto fb = b.dyn_cast <FloatType>())
17971811 return fa.getWidth () > fb.getWidth ();
1798- if (auto va = a.dyn_cast <VectorType>())
1799- if (auto vb = b.dyn_cast <VectorType>())
1800- return va.getShape ().equals (vb.getShape ()) &&
1801- areCastCompatible (va.getElementType (), vb.getElementType ());
1802- return false ;
1812+ return areVectorCastSimpleCompatible (a, b, areCastCompatible);
18031813}
18041814
18051815// ===----------------------------------------------------------------------===//
@@ -2239,7 +2249,9 @@ OpFoldResult SignedRemIOp::fold(ArrayRef<Attribute> operands) {
22392249
22402250// sitofp is applicable from integer types to float types.
22412251bool SIToFPOp::areCastCompatible (Type a, Type b) {
2242- return a.isSignlessInteger () && b.isa <FloatType>();
2252+ if (a.isSignlessInteger () && b.isa <FloatType>())
2253+ return true ;
2254+ return areVectorCastSimpleCompatible (a, b, areCastCompatible);
22432255}
22442256
22452257// ===----------------------------------------------------------------------===//
@@ -2319,7 +2331,9 @@ OpFoldResult SubIOp::fold(ArrayRef<Attribute> operands) {
23192331
23202332// uitofp is applicable from integer types to float types.
23212333bool UIToFPOp::areCastCompatible (Type a, Type b) {
2322- return a.isSignlessInteger () && b.isa <FloatType>();
2334+ if (a.isSignlessInteger () && b.isa <FloatType>())
2335+ return true ;
2336+ return areVectorCastSimpleCompatible (a, b, areCastCompatible);
23232337}
23242338
23252339// ===----------------------------------------------------------------------===//
0 commit comments