diff --git a/llvm/include/llvm/Analysis/VectorUtils.h b/llvm/include/llvm/Analysis/VectorUtils.h index c1016dd7bdddb..7f8a0c9c0af7b 100644 --- a/llvm/include/llvm/Analysis/VectorUtils.h +++ b/llvm/include/llvm/Analysis/VectorUtils.h @@ -18,6 +18,7 @@ #include "llvm/Analysis/LoopAccessAnalysis.h" #include "llvm/IR/Module.h" #include "llvm/IR/VFABIDemangler.h" +#include "llvm/IR/VectorTypeUtils.h" #include "llvm/Support/CheckedArithmetic.h" namespace llvm { @@ -127,19 +128,6 @@ namespace Intrinsic { typedef unsigned ID; } -/// A helper function for converting Scalar types to vector types. If -/// the incoming type is void, we return void. If the EC represents a -/// scalar, we return the scalar type. -inline Type *ToVectorTy(Type *Scalar, ElementCount EC) { - if (Scalar->isVoidTy() || Scalar->isMetadataTy() || EC.isScalar()) - return Scalar; - return VectorType::get(Scalar, EC); -} - -inline Type *ToVectorTy(Type *Scalar, unsigned VF) { - return ToVectorTy(Scalar, ElementCount::getFixed(VF)); -} - /// Identify if the intrinsic is trivially vectorizable. /// This method returns true if the intrinsic's argument types are all scalars /// for the scalar form of the intrinsic and all vectors (or scalars handled by diff --git a/llvm/include/llvm/IR/VectorTypeUtils.h b/llvm/include/llvm/IR/VectorTypeUtils.h new file mode 100644 index 0000000000000..f30bf9ee9240b --- /dev/null +++ b/llvm/include/llvm/IR/VectorTypeUtils.h @@ -0,0 +1,94 @@ +//===------- VectorTypeUtils.h - Vector type utility functions -*- C++ -*-====// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_IR_VECTORTYPEUTILS_H +#define LLVM_IR_VECTORTYPEUTILS_H + +#include "llvm/IR/DerivedTypes.h" + +namespace llvm { + +/// A helper function for converting Scalar types to vector types. If +/// the incoming type is void, we return void. If the EC represents a +/// scalar, we return the scalar type. +inline Type *ToVectorTy(Type *Scalar, ElementCount EC) { + if (Scalar->isVoidTy() || Scalar->isMetadataTy() || EC.isScalar()) + return Scalar; + return VectorType::get(Scalar, EC); +} + +inline Type *ToVectorTy(Type *Scalar, unsigned VF) { + return ToVectorTy(Scalar, ElementCount::getFixed(VF)); +} + +/// A helper for converting structs of scalar types to structs of vector types. +/// Note: +/// - If \p EC is scalar, \p StructTy is returned unchanged +/// - Only unpacked literal struct types are supported +Type *toVectorizedStructTy(StructType *StructTy, ElementCount EC); + +/// A helper for converting structs of vector types to structs of scalar types. +/// Note: Only unpacked literal struct types are supported. +Type *toScalarizedStructTy(StructType *StructTy); + +/// Returns true if `StructTy` is an unpacked literal struct where all elements +/// are vectors of matching element count. This does not include empty structs. +bool isVectorizedStructTy(StructType *StructTy); + +/// A helper for converting to vectorized types. For scalar types, this is +/// equivalent to calling `ToVectorTy`. For struct types, this returns a new +/// struct where each element type has been widened to a vector type. +/// Note: +/// - If the incoming type is void, we return void +/// - If \p EC is scalar, \p Ty is returned unchanged +/// - Only unpacked literal struct types are supported +inline Type *toVectorizedTy(Type *Ty, ElementCount EC) { + if (StructType *StructTy = dyn_cast(Ty)) + return toVectorizedStructTy(StructTy, EC); + return ToVectorTy(Ty, EC); +} + +/// A helper for converting vectorized types to scalarized (non-vector) types. +/// For vector types, this is equivalent to calling .getScalarType(). For struct +/// types, this returns a new struct where each element type has been converted +/// to a scalar type. Note: Only unpacked literal struct types are supported. +inline Type *toScalarizedTy(Type *Ty) { + if (StructType *StructTy = dyn_cast(Ty)) + return toScalarizedStructTy(StructTy); + return Ty->getScalarType(); +} + +/// Returns true if `Ty` is a vector type or a struct of vector types where all +/// vector types share the same VF. +inline bool isVectorizedTy(Type *Ty) { + if (StructType *StructTy = dyn_cast(Ty)) + return isVectorizedStructTy(StructTy); + return Ty->isVectorTy(); +} + +/// Returns the types contained in `Ty`. For struct types, it returns the +/// elements, all other types are returned directly. +inline ArrayRef getContainedTypes(Type *const &Ty) { + if (auto *StructTy = dyn_cast(Ty)) + return StructTy->elements(); + return ArrayRef(&Ty, 1); +} + +/// Returns the number of vector elements for a vectorized type. +inline ElementCount getVectorizedTypeVF(Type *Ty) { + assert(isVectorizedTy(Ty) && "expected vectorized type"); + return cast(getContainedTypes(Ty).front())->getElementCount(); +} + +inline bool isUnpackedStructLiteral(StructType *StructTy) { + return StructTy->isLiteral() && !StructTy->isPacked(); +} + +} // namespace llvm + +#endif diff --git a/llvm/lib/IR/CMakeLists.txt b/llvm/lib/IR/CMakeLists.txt index 544f4ea9223d0..5f6254b231318 100644 --- a/llvm/lib/IR/CMakeLists.txt +++ b/llvm/lib/IR/CMakeLists.txt @@ -73,6 +73,7 @@ add_llvm_component_library(LLVMCore Value.cpp ValueSymbolTable.cpp VectorBuilder.cpp + VectorTypeUtils.cpp Verifier.cpp VFABIDemangler.cpp RuntimeLibcalls.cpp diff --git a/llvm/lib/IR/VFABIDemangler.cpp b/llvm/lib/IR/VFABIDemangler.cpp index 897583084bf38..62f96b10cea4a 100644 --- a/llvm/lib/IR/VFABIDemangler.cpp +++ b/llvm/lib/IR/VFABIDemangler.cpp @@ -11,6 +11,7 @@ #include "llvm/ADT/SmallString.h" #include "llvm/ADT/StringSwitch.h" #include "llvm/IR/Module.h" +#include "llvm/IR/VectorTypeUtils.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include @@ -346,12 +347,20 @@ getScalableECFromSignature(const FunctionType *Signature, const VFISAKind ISA, // Also check the return type if not void. Type *RetTy = Signature->getReturnType(); if (!RetTy->isVoidTy()) { - std::optional ReturnEC = getElementCountForTy(ISA, RetTy); - // If we have an unknown scalar element type we can't find a reasonable VF. - if (!ReturnEC) + // If the return type is a struct, only allow unpacked struct literals. + StructType *StructTy = dyn_cast(RetTy); + if (StructTy && !isUnpackedStructLiteral(StructTy)) return std::nullopt; - if (ElementCount::isKnownLT(*ReturnEC, MinEC)) - MinEC = *ReturnEC; + + for (Type *RetTy : getContainedTypes(RetTy)) { + std::optional ReturnEC = getElementCountForTy(ISA, RetTy); + // If we have an unknown scalar element type we can't find a reasonable + // VF. + if (!ReturnEC) + return std::nullopt; + if (ElementCount::isKnownLT(*ReturnEC, MinEC)) + MinEC = *ReturnEC; + } } // The SVE Vector function call ABI bases the VF on the widest element types @@ -566,7 +575,7 @@ FunctionType *VFABI::createFunctionType(const VFInfo &Info, auto *RetTy = ScalarFTy->getReturnType(); if (!RetTy->isVoidTy()) - RetTy = VectorType::get(RetTy, VF); + RetTy = toVectorizedTy(RetTy, VF); return FunctionType::get(RetTy, VecTypes, false); } diff --git a/llvm/lib/IR/VectorTypeUtils.cpp b/llvm/lib/IR/VectorTypeUtils.cpp new file mode 100644 index 0000000000000..e6e265414a2b8 --- /dev/null +++ b/llvm/lib/IR/VectorTypeUtils.cpp @@ -0,0 +1,54 @@ +//===------- VectorTypeUtils.cpp - Vector type utility functions ----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/IR/VectorTypeUtils.h" +#include "llvm/ADT/SmallVectorExtras.h" + +using namespace llvm; + +/// A helper for converting structs of scalar types to structs of vector types. +/// Note: Only unpacked literal struct types are supported. +Type *llvm::toVectorizedStructTy(StructType *StructTy, ElementCount EC) { + if (EC.isScalar()) + return StructTy; + assert(isUnpackedStructLiteral(StructTy) && + "expected unpacked struct literal"); + assert(all_of(StructTy->elements(), VectorType::isValidElementType) && + "expected all element types to be valid vector element types"); + return StructType::get( + StructTy->getContext(), + map_to_vector(StructTy->elements(), [&](Type *ElTy) -> Type * { + return VectorType::get(ElTy, EC); + })); +} + +/// A helper for converting structs of vector types to structs of scalar types. +/// Note: Only unpacked literal struct types are supported. +Type *llvm::toScalarizedStructTy(StructType *StructTy) { + assert(isUnpackedStructLiteral(StructTy) && + "expected unpacked struct literal"); + return StructType::get( + StructTy->getContext(), + map_to_vector(StructTy->elements(), [](Type *ElTy) -> Type * { + return ElTy->getScalarType(); + })); +} + +/// Returns true if `StructTy` is an unpacked literal struct where all elements +/// are vectors of matching element count. This does not include empty structs. +bool llvm::isVectorizedStructTy(StructType *StructTy) { + if (!isUnpackedStructLiteral(StructTy)) + return false; + auto ElemTys = StructTy->elements(); + if (ElemTys.empty() || !ElemTys.front()->isVectorTy()) + return false; + ElementCount VF = cast(ElemTys.front())->getElementCount(); + return all_of(ElemTys, [&](Type *Ty) { + return Ty->isVectorTy() && cast(Ty)->getElementCount() == VF; + }); +} diff --git a/llvm/unittests/IR/CMakeLists.txt b/llvm/unittests/IR/CMakeLists.txt index ed93ee547d223..b3dfe3d72fd38 100644 --- a/llvm/unittests/IR/CMakeLists.txt +++ b/llvm/unittests/IR/CMakeLists.txt @@ -51,6 +51,7 @@ add_llvm_unittest(IRTests ValueMapTest.cpp ValueTest.cpp VectorBuilderTest.cpp + VectorTypeUtilsTest.cpp VectorTypesTest.cpp VerifierTest.cpp VFABIDemanglerTest.cpp diff --git a/llvm/unittests/IR/VFABIDemanglerTest.cpp b/llvm/unittests/IR/VFABIDemanglerTest.cpp index 07bff16df4933..e30e0f865f719 100644 --- a/llvm/unittests/IR/VFABIDemanglerTest.cpp +++ b/llvm/unittests/IR/VFABIDemanglerTest.cpp @@ -40,7 +40,9 @@ class VFABIParserTest : public ::testing::Test { VFInfo Info; /// Reset the data needed for the test. void reset(const StringRef ScalarFTyStr) { - M = parseAssemblyString("declare void @dummy()", Err, Ctx); + M = parseAssemblyString("%dummy_named_struct = type { double, double }\n" + "declare void @dummy()", + Err, Ctx); EXPECT_NE(M.get(), nullptr) << "Loading an invalid module.\n " << Err.getMessage() << "\n"; Type *Ty = parseType(ScalarFTyStr, Err, *(M)); @@ -753,6 +755,87 @@ TEST_F(VFABIParserTest, ParseVoidReturnTypeSVE) { EXPECT_EQ(VectorName, "vector_foo"); } +TEST_F(VFABIParserTest, ParseWideStructReturnTypeSVE) { + EXPECT_TRUE( + invokeParser("_ZGVsMxv_foo(vector_foo)", "{double, double}(float)")); + EXPECT_EQ(ISA, VFISAKind::SVE); + EXPECT_TRUE(isMasked()); + ElementCount NXV2 = ElementCount::getScalable(2); + FunctionType *FTy = FunctionType::get( + StructType::get(VectorType::get(Type::getDoubleTy(Ctx), NXV2), + VectorType::get(Type::getDoubleTy(Ctx), NXV2)), + { + VectorType::get(Type::getFloatTy(Ctx), NXV2), + VectorType::get(Type::getInt1Ty(Ctx), NXV2), + }, + false); + EXPECT_EQ(getFunctionType(), FTy); + EXPECT_EQ(Parameters.size(), 2U); + EXPECT_EQ(Parameters[0], VFParameter({0, VFParamKind::Vector})); + EXPECT_EQ(Parameters[1], VFParameter({1, VFParamKind::GlobalPredicate})); + EXPECT_EQ(VF, NXV2); + EXPECT_EQ(ScalarName, "foo"); + EXPECT_EQ(VectorName, "vector_foo"); +} + +TEST_F(VFABIParserTest, ParseWideStructMixedReturnTypeSVE) { + EXPECT_TRUE(invokeParser("_ZGVsMxv_foo(vector_foo)", "{float, i64}(float)")); + EXPECT_EQ(ISA, VFISAKind::SVE); + EXPECT_TRUE(isMasked()); + ElementCount NXV2 = ElementCount::getScalable(2); + FunctionType *FTy = FunctionType::get( + StructType::get(VectorType::get(Type::getFloatTy(Ctx), NXV2), + VectorType::get(Type::getInt64Ty(Ctx), NXV2)), + { + VectorType::get(Type::getFloatTy(Ctx), NXV2), + VectorType::get(Type::getInt1Ty(Ctx), NXV2), + }, + false); + EXPECT_EQ(getFunctionType(), FTy); + EXPECT_EQ(Parameters.size(), 2U); + EXPECT_EQ(Parameters[0], VFParameter({0, VFParamKind::Vector})); + EXPECT_EQ(Parameters[1], VFParameter({1, VFParamKind::GlobalPredicate})); + EXPECT_EQ(VF, NXV2); + EXPECT_EQ(ScalarName, "foo"); + EXPECT_EQ(VectorName, "vector_foo"); +} + +TEST_F(VFABIParserTest, ParseWideStructReturnTypeNEON) { + EXPECT_TRUE( + invokeParser("_ZGVnN4v_foo(vector_foo)", "{float, float}(float)")); + EXPECT_EQ(ISA, VFISAKind::AdvancedSIMD); + EXPECT_FALSE(isMasked()); + ElementCount V4 = ElementCount::getFixed(4); + FunctionType *FTy = FunctionType::get( + StructType::get(VectorType::get(Type::getFloatTy(Ctx), V4), + VectorType::get(Type::getFloatTy(Ctx), V4)), + { + VectorType::get(Type::getFloatTy(Ctx), V4), + }, + false); + EXPECT_EQ(getFunctionType(), FTy); + EXPECT_EQ(Parameters.size(), 1U); + EXPECT_EQ(Parameters[0], VFParameter({0, VFParamKind::Vector})); + EXPECT_EQ(VF, V4); + EXPECT_EQ(ScalarName, "foo"); + EXPECT_EQ(VectorName, "vector_foo"); +} + +TEST_F(VFABIParserTest, ParseUnsupportedStructReturnTypesSVE) { + // Struct with array element type. + EXPECT_FALSE( + invokeParser("_ZGVsMxv_foo(vector_foo)", "{double, [4 x float]}(float)")); + // Nested struct type. + EXPECT_FALSE( + invokeParser("_ZGVsMxv_foo(vector_foo)", "{{float, float}}(float)")); + // Packed struct type. + EXPECT_FALSE( + invokeParser("_ZGVsMxv_foo(vector_foo)", "<{double, float}>(float)")); + // Named struct type. + EXPECT_FALSE( + invokeParser("_ZGVsMxv_foo(vector_foo)", "%dummy_named_struct(float)")); +} + // Make sure we reject unsupported parameter types. TEST_F(VFABIParserTest, ParseUnsupportedElementTypeSVE) { EXPECT_FALSE(invokeParser("_ZGVsMxv_foo(vector_foo)", "void(i128)")); diff --git a/llvm/unittests/IR/VectorTypeUtilsTest.cpp b/llvm/unittests/IR/VectorTypeUtilsTest.cpp new file mode 100644 index 0000000000000..c77f183e921de --- /dev/null +++ b/llvm/unittests/IR/VectorTypeUtilsTest.cpp @@ -0,0 +1,149 @@ +//===------- VectorTypeUtilsTest.cpp - Vector utils tests -----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/IR/VectorTypeUtils.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/LLVMContext.h" +#include "gtest/gtest.h" + +using namespace llvm; + +namespace { + +class VectorTypeUtilsTest : public ::testing::Test {}; + +TEST(VectorTypeUtilsTest, TestToVectorizedTy) { + LLVMContext C; + + Type *ITy = Type::getInt32Ty(C); + Type *FTy = Type::getFloatTy(C); + Type *HomogeneousStructTy = StructType::get(FTy, FTy, FTy); + Type *MixedStructTy = StructType::get(FTy, ITy); + Type *VoidTy = Type::getVoidTy(C); + + for (ElementCount VF : + {ElementCount::getFixed(4), ElementCount::getScalable(2)}) { + Type *IntVec = toVectorizedTy(ITy, VF); + EXPECT_TRUE(isa(IntVec)); + EXPECT_EQ(IntVec, VectorType::get(ITy, VF)); + + Type *FloatVec = toVectorizedTy(FTy, VF); + EXPECT_TRUE(isa(FloatVec)); + EXPECT_EQ(FloatVec, VectorType::get(FTy, VF)); + + Type *WideHomogeneousStructTy = toVectorizedTy(HomogeneousStructTy, VF); + EXPECT_TRUE(isa(WideHomogeneousStructTy)); + EXPECT_TRUE( + cast(WideHomogeneousStructTy)->containsHomogeneousTypes()); + EXPECT_TRUE(cast(WideHomogeneousStructTy)->getNumElements() == + 3); + EXPECT_TRUE(cast(WideHomogeneousStructTy)->getElementType(0) == + VectorType::get(FTy, VF)); + + Type *WideMixedStructTy = toVectorizedTy(MixedStructTy, VF); + EXPECT_TRUE(isa(WideMixedStructTy)); + EXPECT_TRUE(cast(WideMixedStructTy)->getNumElements() == 2); + EXPECT_TRUE(cast(WideMixedStructTy)->getElementType(0) == + VectorType::get(FTy, VF)); + EXPECT_TRUE(cast(WideMixedStructTy)->getElementType(1) == + VectorType::get(ITy, VF)); + + EXPECT_EQ(toVectorizedTy(VoidTy, VF), VoidTy); + } + + ElementCount ScalarVF = ElementCount::getFixed(1); + for (Type *Ty : {ITy, FTy, HomogeneousStructTy, MixedStructTy, VoidTy}) { + EXPECT_EQ(toVectorizedTy(Ty, ScalarVF), Ty); + } +} + +TEST(VectorTypeUtilsTest, TestToScalarizedTy) { + LLVMContext C; + + Type *ITy = Type::getInt32Ty(C); + Type *FTy = Type::getFloatTy(C); + Type *HomogeneousStructTy = StructType::get(FTy, FTy, FTy); + Type *MixedStructTy = StructType::get(FTy, ITy); + Type *VoidTy = Type::getVoidTy(C); + + for (ElementCount VF : {ElementCount::getFixed(1), ElementCount::getFixed(4), + ElementCount::getScalable(2)}) { + for (Type *Ty : {ITy, FTy, HomogeneousStructTy, MixedStructTy, VoidTy}) { + // toScalarizedTy should be the inverse of toVectorizedTy. + EXPECT_EQ(toScalarizedTy(toVectorizedTy(Ty, VF)), Ty); + }; + } +} + +TEST(VectorTypeUtilsTest, TestGetContainedTypes) { + LLVMContext C; + + Type *ITy = Type::getInt32Ty(C); + Type *FTy = Type::getFloatTy(C); + Type *HomogeneousStructTy = StructType::get(FTy, FTy, FTy); + Type *MixedStructTy = StructType::get(FTy, ITy); + Type *VoidTy = Type::getVoidTy(C); + + EXPECT_EQ(getContainedTypes(ITy), ArrayRef({ITy})); + EXPECT_EQ(getContainedTypes(FTy), ArrayRef({FTy})); + EXPECT_EQ(getContainedTypes(VoidTy), ArrayRef({VoidTy})); + EXPECT_EQ(getContainedTypes(HomogeneousStructTy), + ArrayRef({FTy, FTy, FTy})); + EXPECT_EQ(getContainedTypes(MixedStructTy), ArrayRef({FTy, ITy})); +} + +TEST(VectorTypeUtilsTest, TestIsVectorizedTy) { + LLVMContext C; + + Type *ITy = Type::getInt32Ty(C); + Type *FTy = Type::getFloatTy(C); + Type *NarrowStruct = StructType::get(FTy, ITy); + Type *VoidTy = Type::getVoidTy(C); + + EXPECT_FALSE(isVectorizedTy(ITy)); + EXPECT_FALSE(isVectorizedTy(NarrowStruct)); + EXPECT_FALSE(isVectorizedTy(VoidTy)); + + ElementCount VF = ElementCount::getFixed(4); + EXPECT_TRUE(isVectorizedTy(toVectorizedTy(ITy, VF))); + EXPECT_TRUE(isVectorizedTy(toVectorizedTy(NarrowStruct, VF))); + + Type *MixedVFStruct = + StructType::get(VectorType::get(ITy, ElementCount::getFixed(2)), + VectorType::get(ITy, ElementCount::getFixed(4))); + EXPECT_FALSE(isVectorizedTy(MixedVFStruct)); + + // Currently only literals types are considered wide. + Type *NamedWideStruct = StructType::create("Named", VectorType::get(ITy, VF), + VectorType::get(ITy, VF)); + EXPECT_FALSE(isVectorizedTy(NamedWideStruct)); + + // Currently only unpacked types are considered wide. + Type *PackedWideStruct = StructType::get( + C, ArrayRef{VectorType::get(ITy, VF), VectorType::get(ITy, VF)}, + /*isPacked=*/true); + EXPECT_FALSE(isVectorizedTy(PackedWideStruct)); +} + +TEST(VectorTypeUtilsTest, TestGetVectorizedTypeVF) { + LLVMContext C; + + Type *ITy = Type::getInt32Ty(C); + Type *FTy = Type::getFloatTy(C); + Type *HomogeneousStructTy = StructType::get(FTy, FTy, FTy); + Type *MixedStructTy = StructType::get(FTy, ITy); + + for (ElementCount VF : + {ElementCount::getFixed(4), ElementCount::getScalable(2)}) { + for (Type *Ty : {ITy, FTy, HomogeneousStructTy, MixedStructTy}) { + EXPECT_EQ(getVectorizedTypeVF(toVectorizedTy(Ty, VF)), VF); + }; + } +} + +} // namespace