Skip to content

Commit 4affeb6

Browse files
kzhuravlyxsamliu
andauthored
convert HIP struct type vector to llvm vector type (#416)
Co-authored-by: Yaxun (Sam) Liu <yaxun.liu@amd.com>
1 parent d39c106 commit 4affeb6

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

llvm/lib/Transforms/Scalar/SROA.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@
8383
#include "llvm/Transforms/Scalar.h"
8484
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
8585
#include "llvm/Transforms/Utils/Local.h"
86+
#include "llvm/TargetParser/Triple.h"
8687
#include "llvm/Transforms/Utils/PromoteMemToReg.h"
8788
#include "llvm/Transforms/Utils/SSAUpdater.h"
8889
#include <algorithm>
@@ -5007,6 +5008,34 @@ AllocaInst *SROA::rewritePartition(AllocaInst &AI, AllocaSlices &AS,
50075008
// FIXME: We might want to defer PHI speculation until after here.
50085009
// FIXME: return nullptr;
50095010
} else {
5011+
// AMDGPU: If the target is AMDGPU and the chosen SliceTy is a HIP vector
5012+
// struct of 2 or 4 identical elements, canonicalize it to an IR vector.
5013+
// This helps SROA treat it as a single value and unlock vector ld/st.
5014+
// We pattern-match struct names starting with "struct.HIP_vector".
5015+
if (Function *F = AI.getFunction()) {
5016+
Triple TT(F->getParent()->getTargetTriple());
5017+
if (TT.isAMDGPU()) {
5018+
if (auto *STy = dyn_cast<StructType>(SliceTy)) {
5019+
StringRef Name = STy->hasName() ? STy->getName() : StringRef();
5020+
if (Name.starts_with("struct.HIP_vector")) {
5021+
unsigned NumElts = STy->getNumElements();
5022+
if ((NumElts == 2 || NumElts == 4) && NumElts > 0) {
5023+
Type *EltTy = STy->getElementType(0);
5024+
bool AllSame = true;
5025+
for (unsigned I = 1; I < NumElts; ++I)
5026+
if (STy->getElementType(I) != EltTy) {
5027+
AllSame = false;
5028+
break;
5029+
}
5030+
if (AllSame && VectorType::isValidElementType(EltTy)) {
5031+
SliceTy = FixedVectorType::get(EltTy, NumElts);
5032+
}
5033+
}
5034+
}
5035+
}
5036+
}
5037+
}
5038+
50105039
// Make sure the alignment is compatible with P.beginOffset().
50115040
const Align Alignment = commonAlignment(AI.getAlign(), P.beginOffset());
50125041
// If we will get at least this much alignment from the type alone, leave

0 commit comments

Comments
 (0)