Skip to content

Commit 63498f3

Browse files
vkuzyakiigcbot
authored andcommitted
Add array support in GenXPrologEpilogInsertion
GenXPrologEpilogInsertion uses IndexFlattener to read aggregates from a register and to store aggregates on a register. IndexFlattener now supports nested aggregates (both structs and arrays), not just nested structs.
1 parent d0c8f80 commit 63498f3

File tree

3 files changed

+133
-114
lines changed

3 files changed

+133
-114
lines changed

IGC/VectorCompiler/lib/GenXCodeGen/GenXLiveness.cpp

Lines changed: 81 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -196,22 +196,22 @@ LiveRange *GenXLiveness::visitPropagateSLRs(Function *F)
196196
/***********************************************************************
197197
* buildLiveRange : build live range for one value (arg or non-baled inst)
198198
*
199-
* For a struct value, each element's live range is built separately, even
199+
* For an aggregate value, each element's live range is built separately, even
200200
* though they are almost identical. They are not exactly identical,
201201
* differing at the def if it is the return value of a call, and at a use
202202
* that is a call arg.
203203
*/
204204
void GenXLiveness::buildLiveRange(Value *V)
205205
{
206206
LLVM_DEBUG(dbgs() << "Building LiveRange for :" << *V << "\n");
207-
auto ST = dyn_cast<StructType>(V->getType());
208-
if (!ST) {
209-
LLVM_DEBUG(dbgs() << "It is not struct, build for one\n");
207+
Type *Ty = V->getType();
208+
if (!Ty->isAggregateType()) {
209+
LLVM_DEBUG(dbgs() << "It is not aggregate, build for one\n");
210210
buildLiveRange(SimpleValue(V));
211211
return;
212212
}
213-
for (unsigned i = 0, e = IndexFlattener::getNumElements(ST); i != e; ++i) {
214-
LLVM_DEBUG(dbgs() << "Bulding for struct Index " << i << " from " << e
213+
for (unsigned i = 0, e = IndexFlattener::getNumElements(Ty); i != e; ++i) {
214+
LLVM_DEBUG(dbgs() << "Bulding for aggregate Index " << i << " from " << e
215215
<< "\n");
216216
buildLiveRange(SimpleValue(V, i));
217217
}
@@ -638,17 +638,17 @@ LiveRange *GenXLiveness::getOrCreateLiveRange(SimpleValue V, unsigned Cat, unsig
638638

639639
/***********************************************************************
640640
* eraseLiveRange : get rid of live range for a Value, possibly multiple
641-
* ones if it is a struct value
641+
* ones if it is an aggregate value
642642
*/
643643
void GenXLiveness::eraseLiveRange(Value *V)
644644
{
645645
LLVM_DEBUG(dbgs() << "Erasing LiveRange for Value: " << *V << "\n");
646-
auto ST = dyn_cast<StructType>(V->getType());
647-
if (!ST) {
646+
Type *Ty = V->getType();
647+
if (!Ty->isAggregateType()) {
648648
eraseLiveRange(SimpleValue(V));
649649
return;
650650
}
651-
for (unsigned i = 0, e = IndexFlattener::getNumElements(ST); i != e; ++i)
651+
for (unsigned i = 0, e = IndexFlattener::getNumElements(Ty); i != e; ++i)
652652
eraseLiveRange(SimpleValue(V, i));
653653
}
654654

@@ -741,8 +741,8 @@ Value *GenXLiveness::getUnifiedRetIfExist(Function *F) const {
741741
* Cannot be called on a function with void return type.
742742
*
743743
* This also creates the LiveRange for the unified return value, or
744-
* multiple ones if it is struct type, and sets the category to the same as in
745-
* one of the return instructions.
744+
* multiple ones if it is aggregate type, and sets the category to the same as
745+
* in one of the return instructions.
746746
*/
747747
Value *GenXLiveness::createUnifiedRet(Function *F) {
748748
IGC_ASSERT_MESSAGE(!F->isDeclaration(), "must be a function definition");
@@ -760,11 +760,11 @@ Value *GenXLiveness::createUnifiedRet(Function *F) {
760760
Value *RetVal = Ret->getOperand(0);
761761
// Use the categories of its operand to set the categories of the unified
762762
// return value.
763-
for (unsigned StructIdx = 0, NumElements = IndexFlattener::getNumElements(Ty);
764-
StructIdx != NumElements; ++StructIdx) {
765-
int Cat = getOrCreateLiveRange(SimpleValue(RetVal, StructIdx))
763+
for (unsigned AggrIdx = 0, NumElements = IndexFlattener::getNumElements(Ty);
764+
AggrIdx != NumElements; ++AggrIdx) {
765+
int Cat = getOrCreateLiveRange(SimpleValue(RetVal, AggrIdx))
766766
->getOrDefaultCategory();
767-
SimpleValue SV(URet, StructIdx);
767+
SimpleValue SV{URet, AggrIdx};
768768
getOrCreateLiveRange(SV)->setCategory(Cat);
769769
}
770770

@@ -1205,7 +1205,7 @@ bool GenXLiveness::wrapsAround(Value *V1, Value *V2)
12051205
}
12061206

12071207
/***********************************************************************
1208-
* insertCopy : insert a copy of a non-struct value
1208+
* insertCopy : insert a copy of a non-aggregate value
12091209
*
12101210
* Enter: InputVal = value to copy
12111211
* LR = live range to add the new value to (0 to avoid adjusting
@@ -1508,13 +1508,13 @@ Value *GenXLiveness::getAddressBase(Value *Addr)
15081508
"base register not found for address");
15091509
Value *BaseV = i->second;
15101510
LiveRange *LR = getLiveRange(BaseV);
1511-
// Find a SimpleValue in the live range that is not a struct member.
1511+
// Find a SimpleValue in the live range that is not an aggregate member.
15121512
for (auto vi = LR->value_begin(), ve = LR->value_end(); vi != ve; ++vi) {
15131513
Value *V = vi->getValue();
1514-
if (!isa<StructType>(V->getType()))
1514+
if (!V->getType()->isAggregateType())
15151515
return V;
15161516
}
1517-
IGC_ASSERT_EXIT_MESSAGE(0, "non-struct value not found");
1517+
IGC_ASSERT_EXIT_MESSAGE(0, "non-aggregate value not found");
15181518
}
15191519

15201520
/***********************************************************************
@@ -1836,59 +1836,79 @@ void LiveRange::printSegments(raw_ostream &OS) const
18361836
}
18371837
}
18381838

1839+
// Returns the type of an aggregate's element at specific index. This is a
1840+
// generalization for structures and arrays.
1841+
static Type *getElementTypeOfAggregate(Type *AggrTy, unsigned Index) {
1842+
IGC_ASSERT_MESSAGE(AggrTy->isAggregateType(), "unexpected type");
1843+
if (isa<StructType>(AggrTy))
1844+
return cast<StructType>(AggrTy)->getTypeAtIndex(Index);
1845+
IGC_ASSERT_MESSAGE(Index < cast<ArrayType>(AggrTy)->getNumElements(),
1846+
"invalid array index");
1847+
return cast<ArrayType>(AggrTy)->getElementType();
1848+
}
1849+
1850+
// Returns the number of elements of an aggregate. This is a generalization for
1851+
// structures and arrays.
1852+
static unsigned getNumElementsOfAggregate(Type *AggrTy) {
1853+
IGC_ASSERT_MESSAGE(AggrTy->isAggregateType(), "unexpected type");
1854+
if (isa<StructType>(AggrTy))
1855+
return cast<StructType>(AggrTy)->getNumElements();
1856+
return cast<ArrayType>(AggrTy)->getNumElements();
1857+
}
1858+
18391859
/***********************************************************************
1840-
* IndexFlattener::flatten : convert struct indices into a flattened index
1860+
* IndexFlattener::flatten : convert aggregate indices into a flattened index
18411861
*
18421862
* This has a special case of Indices having a single element that is the
1843-
* number of elements in ST, which returns the total number of flattened
1844-
* indices in the struct.
1863+
* number of elements in AggrTy, which returns the total number of flattened
1864+
* indices in the aggregate.
18451865
*
1846-
* This involves scanning through the struct layout each time it is called.
1866+
* This involves scanning through the aggregate layout each time it is called.
18471867
* If it is used a lot, it might benefit from some cacheing of the results.
18481868
*/
1849-
unsigned IndexFlattener::flatten(StructType *ST, ArrayRef<unsigned> Indices)
1850-
{
1869+
unsigned IndexFlattener::flatten(Type *AggrTy, ArrayRef<unsigned> Indices) {
1870+
IGC_ASSERT_MESSAGE(AggrTy->isAggregateType(), "unexpected type");
18511871
if (!Indices.size())
18521872
return 0;
18531873
unsigned Flattened = 0;
18541874
unsigned i = 0;
18551875
for (; i != Indices[0]; ++i) {
1856-
Type *ElTy = ST->getElementType(i);
1857-
if (auto ElST = dyn_cast<StructType>(ElTy))
1858-
Flattened += flatten(ElST, ElST->getNumElements());
1876+
Type *ElTy = getElementTypeOfAggregate(AggrTy, i);
1877+
if (ElTy->isAggregateType())
1878+
Flattened += flatten(ElTy, getNumElementsOfAggregate(ElTy));
18591879
else
18601880
++Flattened;
18611881
}
1862-
if (i == ST->getNumElements())
1882+
if (i == getNumElementsOfAggregate(AggrTy))
18631883
return Flattened; // handle special case noted at the top
1864-
Type *ElTy = ST->getElementType(i);
1865-
if (auto ElST = dyn_cast<StructType>(ElTy))
1866-
Flattened += flatten(ElST, Indices.slice(1));
1884+
Type *ElTy = getElementTypeOfAggregate(AggrTy, i);
1885+
if (ElTy->isAggregateType())
1886+
Flattened += flatten(ElTy, Indices.slice(1));
18671887
return Flattened;
18681888
}
18691889

18701890
/***********************************************************************
1871-
* IndexFlattener::unflatten : convert flattened index into struct indices
1891+
* IndexFlattener::unflatten : convert flattened index into aggregate indices
18721892
*
18731893
* Enter: Indices = vector to put unflattened indices into
18741894
*
18751895
* Return: number left over from flattened index if it goes off the end
1876-
* of the struct (used internally when recursing). If this is
1896+
* of the aggregate (used internally when recursing). If this is
18771897
* non-zero, nothing has been pushed into Indices
18781898
*
1879-
* This involves scanning through the struct layout each time it is called.
1899+
* This involves scanning through the aggregate layout each time it is called.
18801900
* If it is used a lot, it might benefit from some cacheing of the results.
18811901
*/
1882-
unsigned IndexFlattener::unflatten(StructType *ST, unsigned Flattened,
1883-
SmallVectorImpl<unsigned> *Indices)
1884-
{
1902+
unsigned IndexFlattener::unflatten(Type *AggrTy, unsigned Flattened,
1903+
SmallVectorImpl<unsigned> *Indices) {
1904+
IGC_ASSERT_MESSAGE(AggrTy->isAggregateType(), "unexpected type");
18851905
++Flattened;
1886-
for (unsigned i = 0, e = ST->getNumElements(); i != e; ++i) {
1906+
for (unsigned i = 0, e = getNumElementsOfAggregate(AggrTy); i != e; ++i) {
18871907
--Flattened;
1888-
Type *ElTy = ST->getElementType(i);
1889-
if (auto ElST = dyn_cast<StructType>(ElTy)) {
1908+
Type *ElTy = getElementTypeOfAggregate(AggrTy, i);
1909+
if (ElTy->isAggregateType()) {
18901910
Indices->push_back(i);
1891-
Flattened = unflatten(ElST, Flattened, Indices);
1911+
Flattened = unflatten(ElTy, Flattened, Indices);
18921912
if (!Flattened)
18931913
return 0;
18941914
Indices->pop_back();
@@ -1901,36 +1921,42 @@ unsigned IndexFlattener::unflatten(StructType *ST, unsigned Flattened,
19011921
}
19021922

19031923
/***********************************************************************
1904-
* IndexFlattener::getElementType : get type of struct element from
1924+
* IndexFlattener::getElementType : get type of aggregate element from
19051925
* flattened index
19061926
*
1907-
* Enter: Ty = type, possibly struct type
1908-
* FlattenedIndex = flattened index in the struct, 0 if not struct
1927+
* Enter: Ty = type, possibly aggregate type
1928+
* FlattenedIndex = flattened index in the aggregate, 0 if not
1929+
* aggregate
19091930
*
19101931
* Return: type of that element
19111932
*/
19121933
Type *IndexFlattener::getElementType(Type *Ty, unsigned FlattenedIndex)
19131934
{
1914-
auto ST = dyn_cast<StructType>(Ty);
1915-
if (!ST)
1935+
if (!Ty->isAggregateType())
19161936
return Ty;
19171937
SmallVector<unsigned, 4> Indices;
1918-
IndexFlattener::unflatten(ST, FlattenedIndex, &Indices);
1919-
IGC_ASSERT(IndexFlattener::flatten(ST, Indices) == FlattenedIndex);
1938+
IndexFlattener::unflatten(Ty, FlattenedIndex, &Indices);
1939+
IGC_ASSERT(IndexFlattener::flatten(Ty, Indices) == FlattenedIndex);
19201940
Type *T = 0;
19211941
for (unsigned i = 0;;) {
1922-
T = ST->getElementType(Indices[i]);
1942+
T = getElementTypeOfAggregate(Ty, Indices[i]);
19231943
if (++i == Indices.size())
19241944
return T;
1925-
ST = cast<StructType>(T);
1945+
Ty = T;
19261946
}
19271947
}
19281948

1949+
unsigned IndexFlattener::getNumElements(Type *Ty) {
1950+
if (Ty->isAggregateType())
1951+
return flatten(Ty, getNumElementsOfAggregate(Ty));
1952+
return !Ty->isVoidTy();
1953+
}
1954+
19291955
/***********************************************************************
19301956
* IndexFlattener::flattenArg : flatten an arg in a function or call
19311957
*
19321958
* This calculates the total number of flattened indices used up by previous
1933-
* args. If all previous args are not struct type, then this just returns the
1959+
* args. If all previous args are not aggregate type, then this just returns the
19341960
* arg index.
19351961
*/
19361962
unsigned IndexFlattener::flattenArg(FunctionType *FT, unsigned ArgIndex)
@@ -1962,13 +1988,13 @@ void SimpleValue::dump() const
19621988
void SimpleValue::print(raw_ostream &OS) const
19631989
{
19641990
OS << V->getName();
1965-
if (Index || isa<StructType>(V->getType()))
1991+
if (Index || V->getType()->isAggregateType())
19661992
OS << "#" << Index;
19671993
}
19681994
void SimpleValue::printName(raw_ostream &OS) const
19691995
{
19701996
OS << V->getName();
1971-
if (Index || isa<StructType>(V->getType()))
1997+
if (Index || V->getType()->isAggregateType())
19721998
OS << "#" << Index;
19731999
}
19742000

0 commit comments

Comments
 (0)