Skip to content

Commit 68d2338

Browse files
committed
fixup! handle struct and minor fixup
1 parent 6b76dbc commit 68d2338

File tree

4 files changed

+384
-19
lines changed

4 files changed

+384
-19
lines changed

clang/lib/CodeGen/CGCall.cpp

+11-16
Original file line numberDiff line numberDiff line change
@@ -3239,24 +3239,19 @@ void CodeGenFunction::EmitFunctionProlog(const CGFunctionInfo &FI,
32393239
}
32403240
}
32413241

3242-
llvm::StructType *STy =
3243-
dyn_cast<llvm::StructType>(ArgI.getCoerceToType());
3244-
if (ArgI.isDirect() && !ArgI.getCanBeFlattened() && STy &&
3245-
STy->getNumElements() > 1) {
3246-
[[maybe_unused]] llvm::TypeSize StructSize =
3247-
CGM.getDataLayout().getTypeAllocSize(STy);
3248-
[[maybe_unused]] llvm::TypeSize PtrElementSize =
3249-
CGM.getDataLayout().getTypeAllocSize(ConvertTypeForMem(Ty));
3250-
if (STy->containsHomogeneousScalableVectorTypes()) {
3251-
assert(StructSize == PtrElementSize &&
3252-
"Only allow non-fractional movement of structure with"
3253-
"homogeneous scalable vector type");
3254-
3255-
ArgVals.push_back(ParamValue::forDirect(AI));
3256-
break;
3257-
}
3242+
// Struct of fixed-length vectors and struct of array of fixed-length
3243+
// vector in VLS calling convention are coerced to vector tuple
3244+
// type(represented as TargetExtType) and scalable vector type
3245+
// respectively, they're no longer handled as struct.
3246+
if (ArgI.isDirect() && isa<llvm::StructType>(ConvertType(Ty)) &&
3247+
(isa<llvm::TargetExtType>(ArgI.getCoerceToType()) ||
3248+
isa<llvm::ScalableVectorType>(ArgI.getCoerceToType()))) {
3249+
ArgVals.push_back(ParamValue::forDirect(AI));
3250+
break;
32583251
}
32593252

3253+
llvm::StructType *STy =
3254+
dyn_cast<llvm::StructType>(ArgI.getCoerceToType());
32603255
Address Alloca = CreateMemTemp(Ty, getContext().getDeclAlign(Arg),
32613256
Arg->getName());
32623257

clang/lib/CodeGen/Targets/RISCV.cpp

+157-3
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ class RISCVABIInfo : public DefaultABIInfo {
3535
llvm::Type *&Field2Ty,
3636
CharUnits &Field2Off) const;
3737

38+
bool detectVLSCCEligibleStruct(QualType Ty, unsigned ABIVLen,
39+
llvm::Type *&VLSType) const;
40+
3841
public:
3942
RISCVABIInfo(CodeGen::CodeGenTypes &CGT, unsigned XLen, unsigned FLen,
4043
bool EABI)
@@ -361,6 +364,149 @@ ABIArgInfo RISCVABIInfo::coerceAndExpandFPCCEligibleStruct(
361364
return ABIArgInfo::getCoerceAndExpand(CoerceToType, UnpaddedCoerceToType);
362365
}
363366

367+
bool RISCVABIInfo::detectVLSCCEligibleStruct(QualType Ty, unsigned ABIVLen,
368+
llvm::Type *&VLSType) const {
369+
// No riscv_vls_cc attribute.
370+
if (ABIVLen == 1)
371+
return false;
372+
373+
// Legal struct for VLS calling convention should fulfill following rules:
374+
// 1. Struct element should be either "homogeneous fixed-length vectors" or "a
375+
// fixed-length vector array".
376+
// 2. Number of struct elements or array elements should be power of 2.
377+
// 3. Total number of vector registers needed should not exceed 8.
378+
//
379+
// Examples: Assume ABI_VLEN = 128.
380+
// These are legal structs:
381+
// a. Structs with 1, 2, 4 or 8 "same" fixed-length vectors, e.g.
382+
// struct {
383+
// __attribute__((vector_size(16))) int a;
384+
// __attribute__((vector_size(16))) int b;
385+
// }
386+
//
387+
// b. Structs with "single" fixed-length vector array with lengh 1, 2, 4
388+
// or 8, e.g.
389+
// struct {
390+
// __attribute__((vector_size(16))) int a[2];
391+
// }
392+
// These are illegal structs:
393+
// a. Structs with 3 fixed-length vectors, e.g.
394+
// struct {
395+
// __attribute__((vector_size(16))) int a;
396+
// __attribute__((vector_size(16))) int b;
397+
// __attribute__((vector_size(16))) int c;
398+
// }
399+
//
400+
// b. Structs with "multiple" fixed-length vector array, e.g.
401+
// struct {
402+
// __attribute__((vector_size(16))) int a[2];
403+
// __attribute__((vector_size(16))) int b[2];
404+
// }
405+
//
406+
// c. Vector registers needed exceeds 8, e.g.
407+
// struct {
408+
// // Registers needed for single fixed-length element:
409+
// // 64 * 8 / ABI_VLEN = 4
410+
// __attribute__((vector_size(64))) int a;
411+
// __attribute__((vector_size(64))) int b;
412+
// __attribute__((vector_size(64))) int c;
413+
// __attribute__((vector_size(64))) int d;
414+
// }
415+
//
416+
// Struct of 1 fixed-length vector is passed as a scalable vector.
417+
// Struct of >1 fixed-length vectors are passed as vector tuple.
418+
// Struct of 1 array of fixed-length vectors is passed as a scalable vector.
419+
// Otherwise, pass the struct indirectly.
420+
421+
if (llvm::StructType *STy = dyn_cast<llvm::StructType>(CGT.ConvertType(Ty))) {
422+
int NumElts = STy->getStructNumElements();
423+
if (NumElts > 8 || !llvm::isPowerOf2_32(NumElts))
424+
return false;
425+
426+
auto *FirstEltTy = STy->getElementType(0);
427+
if (!STy->containsHomogeneousTypes())
428+
return false;
429+
430+
// Check structure of fixed-length vectors and turn them into vector tuple
431+
// type if legal.
432+
if (auto *FixedVecTy = dyn_cast<llvm::FixedVectorType>(FirstEltTy)) {
433+
if (NumElts == 1) {
434+
// Handle single fixed-length vector.
435+
VLSType = llvm::ScalableVectorType::get(
436+
FixedVecTy->getElementType(),
437+
llvm::divideCeil(FixedVecTy->getNumElements() *
438+
llvm::RISCV::RVVBitsPerBlock,
439+
ABIVLen));
440+
// Check registers needed <= 8.
441+
return llvm::divideCeil(
442+
FixedVecTy->getNumElements() *
443+
FixedVecTy->getElementType()->getScalarSizeInBits(),
444+
ABIVLen) <= 8;
445+
}
446+
// LMUL
447+
// = fixed-length vector size / ABIVLen
448+
// = 8 * I8EltCount / RVVBitsPerBlock
449+
// =>
450+
// I8EltCount
451+
// = (fixed-length vector size * RVVBitsPerBlock) / (ABIVLen * 8)
452+
unsigned I8EltCount = llvm::divideCeil(
453+
FixedVecTy->getNumElements() *
454+
FixedVecTy->getElementType()->getScalarSizeInBits() *
455+
llvm::RISCV::RVVBitsPerBlock,
456+
ABIVLen * 8);
457+
VLSType = llvm::TargetExtType::get(
458+
getVMContext(), "riscv.vector.tuple",
459+
llvm::ScalableVectorType::get(llvm::Type::getInt8Ty(getVMContext()),
460+
I8EltCount),
461+
NumElts);
462+
// Check registers needed <= 8.
463+
return NumElts *
464+
llvm::divideCeil(
465+
FixedVecTy->getNumElements() *
466+
FixedVecTy->getElementType()->getScalarSizeInBits(),
467+
ABIVLen) <=
468+
8;
469+
}
470+
471+
// If elements are not fixed-length vectors, it should be an array.
472+
if (NumElts != 1)
473+
return false;
474+
475+
// Check array of fixed-length vector and turn it into scalable vector type
476+
// if legal.
477+
if (auto *ArrTy = dyn_cast<llvm::ArrayType>(FirstEltTy)) {
478+
int NumArrElt = ArrTy->getNumElements();
479+
if (NumArrElt > 8 || !llvm::isPowerOf2_32(NumArrElt))
480+
return false;
481+
482+
auto *ArrEltTy = dyn_cast<llvm::FixedVectorType>(ArrTy->getElementType());
483+
if (!ArrEltTy)
484+
return false;
485+
486+
// LMUL
487+
// = NumArrElt * fixed-length vector size / ABIVLen
488+
// = fixed-length vector elt size * ScalVecNumElts / RVVBitsPerBlock
489+
// =>
490+
// ScalVecNumElts
491+
// = (NumArrElt * fixed-length vector size * RVVBitsPerBlock) /
492+
// (ABIVLen * fixed-length vector elt size)
493+
// = NumArrElt * num fixed-length vector elt * RVVBitsPerBlock /
494+
// ABIVLen
495+
unsigned ScalVecNumElts = llvm::divideCeil(
496+
NumArrElt * ArrEltTy->getNumElements() * llvm::RISCV::RVVBitsPerBlock,
497+
ABIVLen);
498+
VLSType = llvm::ScalableVectorType::get(ArrEltTy->getElementType(),
499+
ScalVecNumElts);
500+
// Check registers needed <= 8.
501+
return llvm::divideCeil(
502+
ScalVecNumElts *
503+
ArrEltTy->getElementType()->getScalarSizeInBits(),
504+
llvm::RISCV::RVVBitsPerBlock) <= 8;
505+
}
506+
}
507+
return false;
508+
}
509+
364510
// Fixed-length RVV vectors are represented as scalable vectors in function
365511
// args/return and must be coerced from fixed vectors.
366512
ABIArgInfo RISCVABIInfo::coerceVLSVector(QualType Ty,
@@ -410,11 +556,13 @@ ABIArgInfo RISCVABIInfo::coerceVLSVector(QualType Ty,
410556
(EltType->isBFloatTy() && !TI.hasFeature("zvfbfmin")) ||
411557
(EltType->isFloatTy() && !TI.hasFeature("zve32f")) ||
412558
(EltType->isDoubleTy() && !TI.hasFeature("zve64d")) ||
413-
(EltType->isIntegerTy(64) && !TI.hasFeature("zve64x")) ||
414-
EltType->isIntegerTy(128)) {
559+
EltType->isIntegerTy(128))
415560
EltType =
416561
llvm::Type::getIntNTy(getVMContext(), EltType->getScalarSizeInBits());
417-
}
562+
563+
// Check registers needed <= 8.
564+
if ((EltType->getScalarSizeInBits() * NumElts / ABIVLen) > 8)
565+
return getNaturalAlignIndirect(Ty, /*ByVal=*/false);
418566

419567
// Generic vector
420568
// The number of elements needs to be at least 1.
@@ -485,6 +633,12 @@ ABIArgInfo RISCVABIInfo::classifyArgumentType(QualType Ty, bool IsFixed,
485633
}
486634
}
487635

636+
if (IsFixed && Ty->isStructureOrClassType()) {
637+
llvm::Type *VLSType = nullptr;
638+
if (detectVLSCCEligibleStruct(Ty, ABIVLen, VLSType))
639+
return ABIArgInfo::getDirect(VLSType);
640+
}
641+
488642
uint64_t NeededAlign = getContext().getTypeAlign(Ty);
489643
// Determine the number of GPRs needed to pass the current argument
490644
// according to the ABI. 2*XLen-aligned varargs are passed in "aligned"

clang/test/CodeGen/RISCV/riscv-vector-callingconv-llvm-ir.c

+108
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,111 @@ void __attribute__((riscv_vls_cc(1024))) test_vls_least_element(__attribute__((v
6767

6868
// CHECK-LLVM: define dso_local riscv_vls_cc void @test_vls_least_element_c23(<vscale x 1 x i32> noundef %arg.coerce)
6969
[[riscv::vls_cc(1024)]] void test_vls_least_element_c23(__attribute__((vector_size(8))) int arg) {}
70+
71+
72+
struct st_i32x4{
73+
__attribute__((vector_size(16))) int i32;
74+
};
75+
76+
struct st_i32x4_arr1{
77+
__attribute__((vector_size(16))) int i32[1];
78+
};
79+
80+
struct st_i32x4_arr4{
81+
__attribute__((vector_size(16))) int i32[4];
82+
};
83+
84+
struct st_i32x4_arr8{
85+
__attribute__((vector_size(16))) int i32[8];
86+
};
87+
88+
89+
struct st_i32x4x2{
90+
__attribute__((vector_size(16))) int i32_1;
91+
__attribute__((vector_size(16))) int i32_2;
92+
};
93+
94+
struct st_i32x8x2{
95+
__attribute__((vector_size(32))) int i32_1;
96+
__attribute__((vector_size(32))) int i32_2;
97+
};
98+
99+
struct st_i32x64x2{
100+
__attribute__((vector_size(256))) int i32_1;
101+
__attribute__((vector_size(256))) int i32_2;
102+
};
103+
104+
struct st_i32x4x8{
105+
__attribute__((vector_size(16))) int i32_1;
106+
__attribute__((vector_size(16))) int i32_2;
107+
__attribute__((vector_size(16))) int i32_3;
108+
__attribute__((vector_size(16))) int i32_4;
109+
__attribute__((vector_size(16))) int i32_5;
110+
__attribute__((vector_size(16))) int i32_6;
111+
__attribute__((vector_size(16))) int i32_7;
112+
__attribute__((vector_size(16))) int i32_8;
113+
};
114+
115+
struct st_i32x4x9{
116+
__attribute__((vector_size(16))) int i32_1;
117+
__attribute__((vector_size(16))) int i32_2;
118+
__attribute__((vector_size(16))) int i32_3;
119+
__attribute__((vector_size(16))) int i32_4;
120+
__attribute__((vector_size(16))) int i32_5;
121+
__attribute__((vector_size(16))) int i32_6;
122+
__attribute__((vector_size(16))) int i32_7;
123+
__attribute__((vector_size(16))) int i32_8;
124+
__attribute__((vector_size(16))) int i32_9;
125+
};
126+
127+
typedef int __attribute__((vector_size(256))) int32x64_t;
128+
129+
// CHECK-LLVM: define dso_local riscv_vls_cc void @test_too_large(ptr noundef %0)
130+
void __attribute__((riscv_vls_cc)) test_too_large(int32x64_t arg) {}
131+
// CHECK-LLVM: define dso_local riscv_vls_cc void @test_too_large_256(<vscale x 16 x i32> noundef %arg.coerce)
132+
void __attribute__((riscv_vls_cc(256))) test_too_large_256(int32x64_t arg) {}
133+
134+
// CHECK-LLVM: define dso_local riscv_vls_cc void @test_st_i32x4(<vscale x 2 x i32> %arg)
135+
void __attribute__((riscv_vls_cc)) test_st_i32x4(struct st_i32x4 arg) {}
136+
// CHECK-LLVM: define dso_local riscv_vls_cc void @test_st_i32x4_256(<vscale x 1 x i32> %arg)
137+
void __attribute__((riscv_vls_cc(256))) test_st_i32x4_256(struct st_i32x4 arg) {}
138+
139+
// CHECK-LLVM: define dso_local riscv_vls_cc void @test_st_i32x4_arr1(<vscale x 2 x i32> %arg)
140+
void __attribute__((riscv_vls_cc)) test_st_i32x4_arr1(struct st_i32x4_arr1 arg) {}
141+
// CHECK-LLVM: define dso_local riscv_vls_cc void @test_st_i32x4_arr1_256(<vscale x 1 x i32> %arg)
142+
void __attribute__((riscv_vls_cc(256))) test_st_i32x4_arr1_256(struct st_i32x4_arr1 arg) {}
143+
144+
// CHECK-LLVM: define dso_local riscv_vls_cc void @test_st_i32x4_arr4(<vscale x 8 x i32> %arg)
145+
void __attribute__((riscv_vls_cc)) test_st_i32x4_arr4(struct st_i32x4_arr4 arg) {}
146+
// CHECK-LLVM: define dso_local riscv_vls_cc void @test_st_i32x4_arr4_256(<vscale x 4 x i32> %arg)
147+
void __attribute__((riscv_vls_cc(256))) test_st_i32x4_arr4_256(struct st_i32x4_arr4 arg) {}
148+
149+
// CHECK-LLVM: define dso_local riscv_vls_cc void @test_st_i32x4_arr8(<vscale x 16 x i32> %arg)
150+
void __attribute__((riscv_vls_cc)) test_st_i32x4_arr8(struct st_i32x4_arr8 arg) {}
151+
// CHECK-LLVM: define dso_local riscv_vls_cc void @test_st_i32x4_arr8_256(<vscale x 8 x i32> %arg)
152+
void __attribute__((riscv_vls_cc(256))) test_st_i32x4_arr8_256(struct st_i32x4_arr8 arg) {}
153+
154+
// CHECK-LLVM: define dso_local riscv_vls_cc void @test_st_i32x4x2(target("riscv.vector.tuple", <vscale x 8 x i8>, 2) %arg)
155+
void __attribute__((riscv_vls_cc)) test_st_i32x4x2(struct st_i32x4x2 arg) {}
156+
// CHECK-LLVM: define dso_local riscv_vls_cc void @test_st_i32x4x2_256(target("riscv.vector.tuple", <vscale x 4 x i8>, 2) %arg)
157+
void __attribute__((riscv_vls_cc(256))) test_st_i32x4x2_256(struct st_i32x4x2 arg) {}
158+
159+
// CHECK-LLVM: define dso_local riscv_vls_cc void @test_st_i32x8x2(target("riscv.vector.tuple", <vscale x 16 x i8>, 2) %arg)
160+
void __attribute__((riscv_vls_cc)) test_st_i32x8x2(struct st_i32x8x2 arg) {}
161+
// CHECK-LLVM: define dso_local riscv_vls_cc void @test_st_i32x8x2_256(target("riscv.vector.tuple", <vscale x 8 x i8>, 2) %arg)
162+
void __attribute__((riscv_vls_cc(256))) test_st_i32x8x2_256(struct st_i32x8x2 arg) {}
163+
164+
// CHECK-LLVM: define dso_local riscv_vls_cc void @test_st_i32x64x2(ptr noundef %arg)
165+
void __attribute__((riscv_vls_cc)) test_st_i32x64x2(struct st_i32x64x2 arg) {}
166+
// CHECK-LLVM: define dso_local riscv_vls_cc void @test_st_i32x64x2_256(ptr noundef %arg)
167+
void __attribute__((riscv_vls_cc(256))) test_st_i32x64x2_256(struct st_i32x64x2 arg) {}
168+
169+
// CHECK-LLVM: define dso_local riscv_vls_cc void @test_st_i32x4x8(target("riscv.vector.tuple", <vscale x 8 x i8>, 8) %arg)
170+
void __attribute__((riscv_vls_cc)) test_st_i32x4x8(struct st_i32x4x8 arg) {}
171+
// CHECK-LLVM: define dso_local riscv_vls_cc void @test_st_i32x4x8_256(target("riscv.vector.tuple", <vscale x 4 x i8>, 8) %arg)
172+
void __attribute__((riscv_vls_cc(256))) test_st_i32x4x8_256(struct st_i32x4x8 arg) {}
173+
174+
// CHECK-LLVM: define dso_local riscv_vls_cc void @test_st_i32x4x9(ptr noundef %arg)
175+
void __attribute__((riscv_vls_cc)) test_st_i32x4x9(struct st_i32x4x9 arg) {}
176+
// CHECK-LLVM: define dso_local riscv_vls_cc void @test_st_i32x4x9_256(ptr noundef %arg)
177+
void __attribute__((riscv_vls_cc(256))) test_st_i32x4x9_256(struct st_i32x4x9 arg) {}

0 commit comments

Comments
 (0)