Skip to content

Commit

Permalink
Merge "merge main into amd-staging" into amd-staging
Browse files Browse the repository at this point in the history
  • Loading branch information
ronlieb committed Dec 23, 2024
2 parents c2c99e2 + c5027e5 commit 96a12bb
Show file tree
Hide file tree
Showing 97 changed files with 4,558 additions and 634 deletions.
68 changes: 44 additions & 24 deletions clang/lib/CodeGen/Targets/AArch64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class AArch64ABIInfo : public ABIInfo {

bool isIllegalVectorType(QualType Ty) const;

bool passAsAggregateType(QualType Ty) const;
bool passAsPureScalableType(QualType Ty, unsigned &NV, unsigned &NP,
SmallVectorImpl<llvm::Type *> &CoerceToSeq) const;

Expand Down Expand Up @@ -337,6 +338,10 @@ ABIArgInfo AArch64ABIInfo::coerceAndExpandPureScalableAggregate(
NSRN += NVec;
NPRN += NPred;

// Handle SVE vector tuples.
if (Ty->isSVESizelessBuiltinType())
return ABIArgInfo::getDirect();

llvm::Type *UnpaddedCoerceToType =
UnpaddedCoerceToSeq.size() == 1
? UnpaddedCoerceToSeq[0]
Expand All @@ -362,7 +367,7 @@ ABIArgInfo AArch64ABIInfo::classifyArgumentType(QualType Ty, bool IsVariadicFn,
if (isIllegalVectorType(Ty))
return coerceIllegalVector(Ty, NSRN, NPRN);

if (!isAggregateTypeForABI(Ty)) {
if (!passAsAggregateType(Ty)) {
// Treat an enum type as its underlying type.
if (const EnumType *EnumTy = Ty->getAs<EnumType>())
Ty = EnumTy->getDecl()->getIntegerType();
Expand Down Expand Up @@ -417,7 +422,7 @@ ABIArgInfo AArch64ABIInfo::classifyArgumentType(QualType Ty, bool IsVariadicFn,
// elsewhere for GNU compatibility.
uint64_t Size = getContext().getTypeSize(Ty);
bool IsEmpty = isEmptyRecord(getContext(), Ty, true);
if (IsEmpty || Size == 0) {
if (!Ty->isSVESizelessBuiltinType() && (IsEmpty || Size == 0)) {
if (!getContext().getLangOpts().CPlusPlus || isDarwinPCS())
return ABIArgInfo::getIgnore();

Expand Down Expand Up @@ -504,7 +509,7 @@ ABIArgInfo AArch64ABIInfo::classifyReturnType(QualType RetTy,
if (RetTy->isVectorType() && getContext().getTypeSize(RetTy) > 128)
return getNaturalAlignIndirect(RetTy);

if (!isAggregateTypeForABI(RetTy)) {
if (!passAsAggregateType(RetTy)) {
// Treat an enum type as its underlying type.
if (const EnumType *EnumTy = RetTy->getAs<EnumType>())
RetTy = EnumTy->getDecl()->getIntegerType();
Expand All @@ -519,7 +524,8 @@ ABIArgInfo AArch64ABIInfo::classifyReturnType(QualType RetTy,
}

uint64_t Size = getContext().getTypeSize(RetTy);
if (isEmptyRecord(getContext(), RetTy, true) || Size == 0)
if (!RetTy->isSVESizelessBuiltinType() &&
(isEmptyRecord(getContext(), RetTy, true) || Size == 0))
return ABIArgInfo::getIgnore();

const Type *Base = nullptr;
Expand Down Expand Up @@ -654,6 +660,15 @@ bool AArch64ABIInfo::isZeroLengthBitfieldPermittedInHomogeneousAggregate()
return true;
}

// Decide whether `Ty` should follow the aggregate classification path.
//
// When the ABI kind is AAPCS and `Ty` is an SVE sizeless builtin type, the
// type is treated as an aggregate only if it is not an SVE count type and its
// builtin vector type info reports more than one vector (i.e. a vector
// tuple). Every other type defers to isAggregateTypeForABI().
bool AArch64ABIInfo::passAsAggregateType(QualType Ty) const {
  if (Kind == AArch64ABIKind::AAPCS && Ty->isSVESizelessBuiltinType()) {
    // isSVESizelessBuiltinType() has already identified `Ty` as a builtin
    // type, so use castAs<> (which asserts on mismatch) instead of getAs<>
    // (which may return null) before dereferencing the result.
    const auto *BT = Ty->castAs<BuiltinType>();
    return !BT->isSVECount() &&
           getContext().getBuiltinVectorTypeInfo(BT).NumVectors > 1;
  }
  return isAggregateTypeForABI(Ty);
}

// Check if a type needs to be passed in registers as a Pure Scalable Type (as
// defined by AAPCS64). Return the number of data vectors and the number of
// predicate vectors in the type, into `NVec` and `NPred`, respectively. Upon
Expand Down Expand Up @@ -719,37 +734,38 @@ bool AArch64ABIInfo::passAsPureScalableType(
return true;
}

const auto *VT = Ty->getAs<VectorType>();
if (!VT)
return false;
if (const auto *VT = Ty->getAs<VectorType>()) {
if (VT->getVectorKind() == VectorKind::SveFixedLengthPredicate) {
++NPred;
if (CoerceToSeq.size() + 1 > 12)
return false;
CoerceToSeq.push_back(convertFixedToScalableVectorType(VT));
return true;
}

if (VT->getVectorKind() == VectorKind::SveFixedLengthPredicate) {
++NPred;
if (CoerceToSeq.size() + 1 > 12)
return false;
CoerceToSeq.push_back(convertFixedToScalableVectorType(VT));
return true;
}
if (VT->getVectorKind() == VectorKind::SveFixedLengthData) {
++NVec;
if (CoerceToSeq.size() + 1 > 12)
return false;
CoerceToSeq.push_back(convertFixedToScalableVectorType(VT));
return true;
}

if (VT->getVectorKind() == VectorKind::SveFixedLengthData) {
++NVec;
if (CoerceToSeq.size() + 1 > 12)
return false;
CoerceToSeq.push_back(convertFixedToScalableVectorType(VT));
return true;
return false;
}

if (!VT->isBuiltinType())
if (!Ty->isBuiltinType())
return false;

switch (cast<BuiltinType>(VT)->getKind()) {
bool isPredicate;
switch (Ty->getAs<BuiltinType>()->getKind()) {
#define SVE_VECTOR_TYPE(Name, MangledName, Id, SingletonId) \
case BuiltinType::Id: \
++NVec; \
isPredicate = false; \
break;
#define SVE_PREDICATE_TYPE(Name, MangledName, Id, SingletonId) \
case BuiltinType::Id: \
++NPred; \
isPredicate = true; \
break;
#define SVE_TYPE(Name, Id, SingletonId)
#include "clang/Basic/AArch64SVEACLETypes.def"
Expand All @@ -761,6 +777,10 @@ bool AArch64ABIInfo::passAsPureScalableType(
getContext().getBuiltinVectorTypeInfo(cast<BuiltinType>(Ty));
assert(Info.NumVectors > 0 && Info.NumVectors <= 4 &&
"Expected 1, 2, 3 or 4 vectors!");
if (isPredicate)
NPred += Info.NumVectors;
else
NVec += Info.NumVectors;
auto VTy = llvm::ScalableVectorType::get(CGT.ConvertType(Info.ElementType),
Info.EC.getKnownMinValue());

Expand Down
4 changes: 2 additions & 2 deletions clang/lib/Headers/avx10_2_512convertintrin.h
Original file line number Diff line number Diff line change
Expand Up @@ -308,13 +308,13 @@ static __inline __m512h __DEFAULT_FN_ATTRS512 _mm512_cvtpbf8_ph(__m256i __A) {
}

static __inline __m512h __DEFAULT_FN_ATTRS512
_mm512_mask_cvtpbf8_ph(__m512h __S, __mmask16 __U, __m256i __A) {
_mm512_mask_cvtpbf8_ph(__m512h __S, __mmask32 __U, __m256i __A) {
return _mm512_castsi512_ph(
_mm512_mask_slli_epi16((__m512i)__S, __U, _mm512_cvtepi8_epi16(__A), 8));
}

static __inline __m512h __DEFAULT_FN_ATTRS512
_mm512_maskz_cvtpbf8_ph(__mmask16 __U, __m256i __A) {
_mm512_maskz_cvtpbf8_ph(__mmask32 __U, __m256i __A) {
return _mm512_castsi512_ph(
_mm512_slli_epi16(_mm512_maskz_cvtepi8_epi16(__U, __A), 8));
}
Expand Down
4 changes: 2 additions & 2 deletions clang/lib/Headers/avx10_2convertintrin.h
Original file line number Diff line number Diff line change
Expand Up @@ -580,13 +580,13 @@ static __inline__ __m256h __DEFAULT_FN_ATTRS256 _mm256_cvtpbf8_ph(__m128i __A) {
}

static __inline__ __m256h __DEFAULT_FN_ATTRS256
_mm256_mask_cvtpbf8_ph(__m256h __S, __mmask8 __U, __m128i __A) {
_mm256_mask_cvtpbf8_ph(__m256h __S, __mmask16 __U, __m128i __A) {
return _mm256_castsi256_ph(
_mm256_mask_slli_epi16((__m256i)__S, __U, _mm256_cvtepi8_epi16(__A), 8));
}

static __inline__ __m256h __DEFAULT_FN_ATTRS256
_mm256_maskz_cvtpbf8_ph(__mmask8 __U, __m128i __A) {
_mm256_maskz_cvtpbf8_ph(__mmask16 __U, __m128i __A) {
return _mm256_castsi256_ph(
_mm256_slli_epi16(_mm256_maskz_cvtepi8_epi16(__U, __A), 8));
}
Expand Down
3 changes: 2 additions & 1 deletion clang/lib/Serialization/ASTReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10642,7 +10642,8 @@ void ASTReader::FinishedDeserializing() {
// We do this now rather than in finishPendingActions because we want to
// be able to walk the complete redeclaration chains of the updated decls.
while (!PendingExceptionSpecUpdates.empty() ||
!PendingDeducedTypeUpdates.empty()) {
!PendingDeducedTypeUpdates.empty() ||
!PendingUndeducedFunctionDecls.empty()) {
auto ESUpdates = std::move(PendingExceptionSpecUpdates);
PendingExceptionSpecUpdates.clear();
for (auto Update : ESUpdates) {
Expand Down
19 changes: 19 additions & 0 deletions clang/test/CodeGen/AArch64/pure-scalable-args.c
Original file line number Diff line number Diff line change
Expand Up @@ -459,3 +459,22 @@ void test_va_arg(int n, ...) {
// CHECK-DARWIN-NEXT: call void @llvm.lifetime.end.p0(i64 8, ptr nonnull %ap)
// CHECK-DARWIN-NEXT: ret void
// CHECK-DARWIN-NEXT: }

// Regression test for incorrect passing of SVE vector tuples
// The whole of `y` needs to be passed indirectly.
void test_tuple_reg_count(svfloat32_t x, svfloat32x2_t y) {
// Block-scope declaration of the callee: seven single-vector parameters
// followed by one two-vector tuple parameter.
void test_tuple_reg_count_callee(svfloat32_t, svfloat32_t, svfloat32_t, svfloat32_t,
svfloat32_t, svfloat32_t, svfloat32_t, svfloat32x2_t);
// Pass `x` seven times, then the tuple `y` as the final argument.
test_tuple_reg_count_callee(x, x, x, x, x, x, x, y);
}
// CHECK-AAPCS: declare void @test_tuple_reg_count_callee(<vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, ptr noundef)
// CHECK-DARWIN: declare void @test_tuple_reg_count_callee(<vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>)

// Regression test for incorrect passing of SVE vector tuples
// The whole of `y` needs to be passed indirectly.
void test_tuple_reg_count_bool(svboolx4_t x, svboolx4_t y) {
// Block-scope declaration of the callee taking two svboolx4_t tuples.
void test_tuple_reg_count_bool_callee(svboolx4_t, svboolx4_t);
// Forward both tuples to the callee unchanged.
test_tuple_reg_count_bool_callee(x, y);
}
// CHECK-AAPCS: declare void @test_tuple_reg_count_bool_callee(<vscale x 16 x i1>, <vscale x 16 x i1>, <vscale x 16 x i1>, <vscale x 16 x i1>, ptr noundef)
// CHECK-DARWIN: declare void @test_tuple_reg_count_bool_callee(<vscale x 16 x i1>, <vscale x 16 x i1>, <vscale x 16 x i1>, <vscale x 16 x i1>, <vscale x 16 x i1>, <vscale x 16 x i1>, <vscale x 16 x i1>, <vscale x 16 x i1>)
4 changes: 2 additions & 2 deletions clang/test/CodeGen/X86/avx10_2_512convert-builtins.c
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ __m512h test_mm512_cvtpbf8_ph(__m256i A) {
return _mm512_cvtpbf8_ph(A);
}

__m512h test_mm512_mask_cvtpbf8_ph(__m512h S, __mmask16 M, __m256i A) {
__m512h test_mm512_mask_cvtpbf8_ph(__m512h S, __mmask32 M, __m256i A) {
// CHECK-LABEL: @test_mm512_mask_cvtpbf8_ph
// CHECK: sext <32 x i8> %{{.*}} to <32 x i16>
// CHECK: @llvm.x86.avx512.pslli.w.512
Expand All @@ -308,7 +308,7 @@ __m512h test_mm512_mask_cvtpbf8_ph(__m512h S, __mmask16 M, __m256i A) {
return _mm512_mask_cvtpbf8_ph(S, M, A);
}

__m512h test_mm512_maskz_cvtpbf8_ph(__mmask16 M, __m256i A) {
__m512h test_mm512_maskz_cvtpbf8_ph(__mmask32 M, __m256i A) {
// CHECK-LABEL: @test_mm512_maskz_cvtpbf8_ph
// CHECK: sext <32 x i8> %{{.*}} to <32 x i16>
// CHECK: select <32 x i1> %{{.*}}, <32 x i16> %{{.*}}, <32 x i16> %{{.*}}
Expand Down
Loading

0 comments on commit 96a12bb

Please sign in to comment.