Skip to content

Commit

Permalink
address pr feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
farzonl committed Nov 4, 2024
1 parent c9f0d39 commit 9d36d73
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 33 deletions.
52 changes: 21 additions & 31 deletions llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===---------------------------------------------------------------------===//

///
/// \file This file contains a pass to flatten arrays for the DirectX Backend.
//
///
//===----------------------------------------------------------------------===//

#include "DXILFlattenArrays.h"
Expand All @@ -26,10 +25,12 @@
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>

#define DEBUG_TYPE "dxil-flatten-arrays"

using namespace llvm;
namespace {

class DXILFlattenArraysLegacy : public ModulePass {

Expand Down Expand Up @@ -75,17 +76,16 @@ class DXILFlattenArraysVisitor
bool visitCallInst(CallInst &ICI) { return false; }
bool visitFreezeInst(FreezeInst &FI) { return false; }
static bool isMultiDimensionalArray(Type *T);
static unsigned getTotalElements(Type *ArrayTy);
static Type *getBaseElementType(Type *ArrayTy);
static std::pair<unsigned, Type *> getElementCountAndType(Type *ArrayTy);

private:
SmallVector<WeakTrackingVH, 32> PotentiallyDeadInstrs;
SmallVector<WeakTrackingVH> PotentiallyDeadInstrs;
DenseMap<GetElementPtrInst *, GEPData> GEPChainMap;
bool finish();
ConstantInt *constFlattenIndices(ArrayRef<Value *> Indices,
ConstantInt *genConstFlattenIndices(ArrayRef<Value *> Indices,
ArrayRef<uint64_t> Dims,
IRBuilder<> &Builder);
Value *instructionFlattenIndices(ArrayRef<Value *> Indices,
Value *genInstructionFlattenIndices(ArrayRef<Value *> Indices,
ArrayRef<uint64_t> Dims,
IRBuilder<> &Builder);
void
Expand All @@ -99,6 +99,7 @@ class DXILFlattenArraysVisitor
bool visitGetElementPtrInstInGEPChainBase(GEPData &GEPInfo,
GetElementPtrInst &GEP);
};
} // namespace

bool DXILFlattenArraysVisitor::finish() {
RecursivelyDeleteTriviallyDeadInstructionsPermissive(PotentiallyDeadInstrs);
Expand All @@ -111,25 +112,17 @@ bool DXILFlattenArraysVisitor::isMultiDimensionalArray(Type *T) {
return false;
}

unsigned DXILFlattenArraysVisitor::getTotalElements(Type *ArrayTy) {
std::pair<unsigned, Type *> DXILFlattenArraysVisitor::getElementCountAndType(Type *ArrayTy) {
unsigned TotalElements = 1;
Type *CurrArrayTy = ArrayTy;
while (auto *InnerArrayTy = dyn_cast<ArrayType>(CurrArrayTy)) {
TotalElements *= InnerArrayTy->getNumElements();
CurrArrayTy = InnerArrayTy->getElementType();
}
return TotalElements;
}

Type *DXILFlattenArraysVisitor::getBaseElementType(Type *ArrayTy) {
Type *CurrArrayTy = ArrayTy;
while (auto *InnerArrayTy = dyn_cast<ArrayType>(CurrArrayTy)) {
CurrArrayTy = InnerArrayTy->getElementType();
}
return CurrArrayTy;
return std::make_pair(TotalElements, CurrArrayTy);
}

ConstantInt *DXILFlattenArraysVisitor::constFlattenIndices(
ConstantInt *DXILFlattenArraysVisitor::genConstFlattenIndices(
ArrayRef<Value *> Indices, ArrayRef<uint64_t> Dims, IRBuilder<> &Builder) {
assert(Indices.size() == Dims.size() &&
"Indicies and dimmensions should be the same");
Expand All @@ -146,7 +139,7 @@ ConstantInt *DXILFlattenArraysVisitor::constFlattenIndices(
return Builder.getInt32(FlatIndex);
}

Value *DXILFlattenArraysVisitor::instructionFlattenIndices(
Value *DXILFlattenArraysVisitor::genInstructionFlattenIndices(
ArrayRef<Value *> Indices, ArrayRef<uint64_t> Dims, IRBuilder<> &Builder) {
if (Indices.size() == 1)
return Indices[0];
Expand Down Expand Up @@ -202,10 +195,10 @@ bool DXILFlattenArraysVisitor::visitAllocaInst(AllocaInst &AI) {

ArrayType *ArrType = cast<ArrayType>(AI.getAllocatedType());
IRBuilder<> Builder(&AI);
unsigned TotalElements = getTotalElements(ArrType);
auto [TotalElements, BaseType] = getElementCountAndType(ArrType);

ArrayType *FattenedArrayType =
ArrayType::get(getBaseElementType(ArrType), TotalElements);
ArrayType::get(BaseType, TotalElements);
AllocaInst *FlatAlloca =
Builder.CreateAlloca(FattenedArrayType, nullptr, AI.getName() + ".flat");
FlatAlloca->setAlignment(AI.getAlign());
Expand Down Expand Up @@ -261,10 +254,10 @@ bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChainBase(
IRBuilder<> Builder(&GEP);
Value *FlatIndex;
if (GEPInfo.AllIndicesAreConstInt)
FlatIndex = constFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder);
FlatIndex = genConstFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder);
else
FlatIndex =
instructionFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder);
genInstructionFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder);

ArrayType *FlattenedArrayType = GEPInfo.ParentArrayType;
Value *FlatGEP =
Expand All @@ -285,9 +278,9 @@ bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {

ArrayType *ArrType = cast<ArrayType>(GEP.getSourceElementType());
IRBuilder<> Builder(&GEP);
unsigned TotalElements = getTotalElements(ArrType);
auto [TotalElements, BaseType] = getElementCountAndType(ArrType);
ArrayType *FlattenedArrayType =
ArrayType::get(getBaseElementType(ArrType), TotalElements);
ArrayType::get(BaseType, TotalElements);

Value *PtrOperand = GEP.getPointerOperand();

Expand All @@ -313,7 +306,6 @@ bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {

bool DXILFlattenArraysVisitor::visit(Function &F) {
bool MadeChange = false;
////for (BasicBlock &BB : make_early_inc_range(F)) {
ReversePostOrderTraversal<Function *> RPOT(&F);
for (BasicBlock *BB : make_early_inc_range(RPOT)) {
for (Instruction &I : make_early_inc_range(*BB)) {
Expand Down Expand Up @@ -345,8 +337,7 @@ static void collectElements(Constant *Init,
collectElements(DataArrayConstant->getElementAsConstant(I), Elements);
}
} else {
assert(
false &&
llvm_unreachable (
"Expected a ConstantArray or ConstantDataArray for array initializer!");
}
}
Expand Down Expand Up @@ -382,10 +373,9 @@ flattenGlobalArrays(Module &M,
continue;

ArrayType *ArrType = cast<ArrayType>(OrigType);
unsigned TotalElements =
DXILFlattenArraysVisitor::getTotalElements(ArrType);
auto [TotalElements, BaseType] = DXILFlattenArraysVisitor::getElementCountAndType(ArrType);
ArrayType *FattenedArrayType = ArrayType::get(
DXILFlattenArraysVisitor::getBaseElementType(ArrType), TotalElements);
BaseType, TotalElements);

// Create a new global variable with the updated type
// Note: Initializer is set via transformInitializer
Expand Down
2 changes: 0 additions & 2 deletions llvm/lib/Target/DirectX/DXILFlattenArrays.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,7 @@
#ifndef LLVM_TARGET_DIRECTX_DXILFLATTENARRAYS_H
#define LLVM_TARGET_DIRECTX_DXILFLATTENARRAYS_H

#include "DXILResource.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Pass.h"

namespace llvm {

Expand Down

0 comments on commit 9d36d73

Please sign in to comment.