From 9d36d73404a1a9f9bca9bcb276147f4542119337 Mon Sep 17 00:00:00 2001 From: Farzon Lotfi Date: Mon, 4 Nov 2024 14:15:00 -0500 Subject: [PATCH] address pr feedback --- llvm/lib/Target/DirectX/DXILFlattenArrays.cpp | 52 ++++++++----------- llvm/lib/Target/DirectX/DXILFlattenArrays.h | 2 - 2 files changed, 21 insertions(+), 33 deletions(-) diff --git a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp index 20c7401e934e6c5..65b5c2a2764c6ed 100644 --- a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp +++ b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp @@ -5,10 +5,9 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===---------------------------------------------------------------------===// - /// /// \file This file contains a pass to flatten arrays for the DirectX Backend. -// +/// //===----------------------------------------------------------------------===// #include "DXILFlattenArrays.h" @@ -26,10 +25,12 @@ #include #include #include +#include #define DEBUG_TYPE "dxil-flatten-arrays" using namespace llvm; +namespace { class DXILFlattenArraysLegacy : public ModulePass { @@ -75,17 +76,16 @@ class DXILFlattenArraysVisitor bool visitCallInst(CallInst &ICI) { return false; } bool visitFreezeInst(FreezeInst &FI) { return false; } static bool isMultiDimensionalArray(Type *T); - static unsigned getTotalElements(Type *ArrayTy); - static Type *getBaseElementType(Type *ArrayTy); + static std::pair getElementCountAndType(Type *ArrayTy); private: - SmallVector PotentiallyDeadInstrs; + SmallVector PotentiallyDeadInstrs; DenseMap GEPChainMap; bool finish(); - ConstantInt *constFlattenIndices(ArrayRef Indices, + ConstantInt *genConstFlattenIndices(ArrayRef Indices, ArrayRef Dims, IRBuilder<> &Builder); - Value *instructionFlattenIndices(ArrayRef Indices, + Value *genInstructionFlattenIndices(ArrayRef Indices, ArrayRef Dims, IRBuilder<> &Builder); void @@ -99,6 +99,7 @@ class DXILFlattenArraysVisitor bool visitGetElementPtrInstInGEPChainBase(GEPData &GEPInfo, GetElementPtrInst &GEP); }; +} // namespace bool DXILFlattenArraysVisitor::finish() { RecursivelyDeleteTriviallyDeadInstructionsPermissive(PotentiallyDeadInstrs); @@ -111,25 +112,17 @@ bool DXILFlattenArraysVisitor::isMultiDimensionalArray(Type *T) { return false; } -unsigned DXILFlattenArraysVisitor::getTotalElements(Type *ArrayTy) { +std::pair DXILFlattenArraysVisitor::getElementCountAndType(Type *ArrayTy) { unsigned TotalElements = 1; Type *CurrArrayTy = ArrayTy; while (auto *InnerArrayTy = dyn_cast(CurrArrayTy)) { TotalElements *= InnerArrayTy->getNumElements(); CurrArrayTy = InnerArrayTy->getElementType(); } - return TotalElements; -} - -Type *DXILFlattenArraysVisitor::getBaseElementType(Type *ArrayTy) { - Type *CurrArrayTy = ArrayTy; - while (auto *InnerArrayTy = dyn_cast(CurrArrayTy)) { - CurrArrayTy = InnerArrayTy->getElementType(); - } - return CurrArrayTy; + return std::make_pair(TotalElements, CurrArrayTy); } -ConstantInt *DXILFlattenArraysVisitor::constFlattenIndices( +ConstantInt *DXILFlattenArraysVisitor::genConstFlattenIndices( ArrayRef Indices, ArrayRef Dims, IRBuilder<> &Builder) { assert(Indices.size() == Dims.size() && "Indicies and dimmensions should be the same"); @@ -146,7 +139,7 @@ ConstantInt *DXILFlattenArraysVisitor::constFlattenIndices( return Builder.getInt32(FlatIndex); } -Value *DXILFlattenArraysVisitor::instructionFlattenIndices( +Value *DXILFlattenArraysVisitor::genInstructionFlattenIndices( ArrayRef Indices, ArrayRef Dims, IRBuilder<> &Builder) { if (Indices.size() == 1) return Indices[0]; @@ -202,10 +195,10 @@ bool DXILFlattenArraysVisitor::visitAllocaInst(AllocaInst &AI) { ArrayType *ArrType = cast(AI.getAllocatedType()); IRBuilder<> Builder(&AI); - unsigned TotalElements = getTotalElements(ArrType); + auto [TotalElements, BaseType] = getElementCountAndType(ArrType); ArrayType *FattenedArrayType = - ArrayType::get(getBaseElementType(ArrType), TotalElements); + ArrayType::get(BaseType, TotalElements); AllocaInst *FlatAlloca = Builder.CreateAlloca(FattenedArrayType, nullptr, AI.getName() + ".flat"); FlatAlloca->setAlignment(AI.getAlign()); @@ -261,10 +254,10 @@ bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChainBase( IRBuilder<> Builder(&GEP); Value *FlatIndex; if (GEPInfo.AllIndicesAreConstInt) - FlatIndex = constFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder); + FlatIndex = genConstFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder); else FlatIndex = - instructionFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder); + genInstructionFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder); ArrayType *FlattenedArrayType = GEPInfo.ParentArrayType; Value *FlatGEP = @@ -285,9 +278,9 @@ bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) { ArrayType *ArrType = cast(GEP.getSourceElementType()); IRBuilder<> Builder(&GEP); - unsigned TotalElements = getTotalElements(ArrType); + auto [TotalElements, BaseType] = getElementCountAndType(ArrType); ArrayType *FlattenedArrayType = - ArrayType::get(getBaseElementType(ArrType), TotalElements); + ArrayType::get(BaseType, TotalElements); Value *PtrOperand = GEP.getPointerOperand(); @@ -313,7 +306,6 @@ bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) { bool DXILFlattenArraysVisitor::visit(Function &F) { bool MadeChange = false; - ////for (BasicBlock &BB : make_early_inc_range(F)) { ReversePostOrderTraversal RPOT(&F); for (BasicBlock *BB : make_early_inc_range(RPOT)) { for (Instruction &I : make_early_inc_range(*BB)) { @@ -345,8 +337,7 @@ static void collectElements(Constant *Init, collectElements(DataArrayConstant->getElementAsConstant(I), Elements); } } else { - assert( - false && + llvm_unreachable ( "Expected a ConstantArray or ConstantDataArray for array initializer!"); } } @@ -382,10 +373,9 @@ flattenGlobalArrays(Module &M, continue; ArrayType *ArrType = cast(OrigType); - unsigned TotalElements = - DXILFlattenArraysVisitor::getTotalElements(ArrType); + auto [TotalElements, BaseType] = DXILFlattenArraysVisitor::getElementCountAndType(ArrType); ArrayType *FattenedArrayType = ArrayType::get( - DXILFlattenArraysVisitor::getBaseElementType(ArrType), TotalElements); + BaseType, TotalElements); // Create a new global variable with the updated type // Note: Initializer is set via transformInitializer diff --git a/llvm/lib/Target/DirectX/DXILFlattenArrays.h b/llvm/lib/Target/DirectX/DXILFlattenArrays.h index 409f8d198782c90..aae68496af620aa 100644 --- a/llvm/lib/Target/DirectX/DXILFlattenArrays.h +++ b/llvm/lib/Target/DirectX/DXILFlattenArrays.h @@ -9,9 +9,7 @@ #ifndef LLVM_TARGET_DIRECTX_DXILFLATTENARRAYS_H #define LLVM_TARGET_DIRECTX_DXILFLATTENARRAYS_H -#include "DXILResource.h" #include "llvm/IR/PassManager.h" -#include "llvm/Pass.h" namespace llvm {