From 885666a9c3ef3cad10ae3b026672afbc2894db33 Mon Sep 17 00:00:00 2001 From: Cullen Rhodes Date: Wed, 6 Sep 2023 07:20:59 +0000 Subject: [PATCH] [mlir][llvm] Return failure from type converter for n-D scalable vectors This patch changes vector type conversion to return failure on n-D scalable vector types instead of asserting. This is an alternative approach to #65261 that aims to enable lowering of Vector ops directly to ArmSME intrinsics where possible, and seems more consistent with other type conversions. It's trivial to hit the assert at the moment and it could be interpreted as n-D scalable vector types being a bug, when they're valid types in the Vector dialect. By returning failure it will generally fail more gracefully, particularly for release builds or other builds where assertions are disabled. --- .../Conversion/LLVMCommon/TypeConverter.h | 2 +- .../Conversion/LLVMCommon/TypeConverter.cpp | 19 +++++++++++-------- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h index ed174699314e8..2a4327535c687 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h @@ -239,7 +239,7 @@ class LLVMTypeConverter : public TypeConverter { Type convertMemRefToBarePtr(BaseMemRefType type) const; /// Convert a 1D vector type into an LLVM vector type. - Type convertVectorType(VectorType type) const; + FailureOr convertVectorType(VectorType type) const; /// Options for customizing the llvm lowering. LowerToLLVMOptions options; diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp index a9e7ce9d42848..49e0513e629d9 100644 --- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp +++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp @@ -61,7 +61,12 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx, addConversion([&](MemRefType type) { return convertMemRefType(type); }); addConversion( [&](UnrankedMemRefType type) { return convertUnrankedMemRefType(type); }); - addConversion([&](VectorType type) { return convertVectorType(type); }); + addConversion([&](VectorType type) -> std::optional { + FailureOr llvmType = convertVectorType(type); + if (failed(llvmType)) + return std::nullopt; + return llvmType; + }); // LLVM-compatible types are legal, so add a pass-through conversion. Do this // before the conversions below since conversions are attempted in reverse @@ -490,10 +495,9 @@ Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) const { /// * 1-D `vector` remains as is while, /// * n>1 `vector` convert via an (n-1)-D array type to /// `!llvm.array>>`. -/// As LLVM does not support arrays of scalable vectors, it is assumed that -/// scalable vectors are always 1-D. This condition could be relaxed once the -/// missing functionality is added in LLVM -Type LLVMTypeConverter::convertVectorType(VectorType type) const { +/// Returns failure for n-D scalable vector types as LLVM does not support +/// arrays of scalable vectors. +FailureOr LLVMTypeConverter::convertVectorType(VectorType type) const { auto elementType = convertType(type.getElementType()); if (!elementType) return {}; @@ -503,9 +507,8 @@ Type LLVMTypeConverter::convertVectorType(VectorType type) const { type.getScalableDims().back()); assert(LLVM::isCompatibleVectorType(vectorType) && "expected vector type compatible with the LLVM dialect"); - assert( - (!type.isScalable() || (type.getRank() == 1)) && - "expected 1-D scalable vector (n-D scalable vectors are not supported)"); + if (type.isScalable() && (type.getRank() > 1)) + return failure(); auto shape = type.getShape(); for (int i = shape.size() - 2; i >= 0; --i) vectorType = LLVM::LLVMArrayType::get(vectorType, shape[i]);