-
Notifications
You must be signed in to change notification settings - Fork 12k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[DirectX] Flatten arrays #114332
[DirectX] Flatten arrays #114332
Conversation
@llvm/pr-subscribers-backend-directx Author: Farzon Lotfi (farzonl) Changes
completes 89646 Patch is 42.60 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/114332.diff 11 Files Affected:
diff --git a/llvm/lib/Target/DirectX/CMakeLists.txt b/llvm/lib/Target/DirectX/CMakeLists.txt
index 5d1dc50fdb0dde..a726071e0dcecd 100644
--- a/llvm/lib/Target/DirectX/CMakeLists.txt
+++ b/llvm/lib/Target/DirectX/CMakeLists.txt
@@ -22,6 +22,7 @@ add_llvm_target(DirectXCodeGen
DXContainerGlobals.cpp
DXILDataScalarization.cpp
DXILFinalizeLinkage.cpp
+ DXILFlattenArrays.cpp
DXILIntrinsicExpansion.cpp
DXILOpBuilder.cpp
DXILOpLowering.cpp
diff --git a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
new file mode 100644
index 00000000000000..20c7401e934e6c
--- /dev/null
+++ b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
@@ -0,0 +1,458 @@
+//===- DXILFlattenArrays.cpp - Flattens DXIL Arrays-----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===---------------------------------------------------------------------===//
+
+///
+/// \file This file contains a pass to flatten arrays for the DirectX Backend.
+//
+//===----------------------------------------------------------------------===//
+
+#include "DXILFlattenArrays.h"
+#include "DirectX.h"
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Analysis/DXILResource.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstVisitor.h"
+#include "llvm/IR/ReplaceConstant.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Transforms/Utils/Local.h"
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+
+#define DEBUG_TYPE "dxil-flatten-arrays"
+
+using namespace llvm;
+
+class DXILFlattenArraysLegacy : public ModulePass {
+
+public:
+ bool runOnModule(Module &M) override;
+ DXILFlattenArraysLegacy() : ModulePass(ID) {}
+
+ void getAnalysisUsage(AnalysisUsage &AU) const override;
+ static char ID; // Pass identification.
+};
+
+struct GEPData {
+ ArrayType *ParentArrayType;
+ Value *ParendOperand;
+ SmallVector<Value *> Indices;
+ SmallVector<uint64_t> Dims;
+ bool AllIndicesAreConstInt;
+};
+
+class DXILFlattenArraysVisitor
+ : public InstVisitor<DXILFlattenArraysVisitor, bool> {
+public:
+ DXILFlattenArraysVisitor() {}
+ bool visit(Function &F);
+ // InstVisitor methods. They return true if the instruction was scalarized,
+ // false if nothing changed.
+ bool visitGetElementPtrInst(GetElementPtrInst &GEPI);
+ bool visitAllocaInst(AllocaInst &AI);
+ bool visitInstruction(Instruction &I) { return false; }
+ bool visitSelectInst(SelectInst &SI) { return false; }
+ bool visitICmpInst(ICmpInst &ICI) { return false; }
+ bool visitFCmpInst(FCmpInst &FCI) { return false; }
+ bool visitUnaryOperator(UnaryOperator &UO) { return false; }
+ bool visitBinaryOperator(BinaryOperator &BO) { return false; }
+ bool visitCastInst(CastInst &CI) { return false; }
+ bool visitBitCastInst(BitCastInst &BCI) { return false; }
+ bool visitInsertElementInst(InsertElementInst &IEI) { return false; }
+ bool visitExtractElementInst(ExtractElementInst &EEI) { return false; }
+ bool visitShuffleVectorInst(ShuffleVectorInst &SVI) { return false; }
+ bool visitPHINode(PHINode &PHI) { return false; }
+ bool visitLoadInst(LoadInst &LI);
+ bool visitStoreInst(StoreInst &SI);
+ bool visitCallInst(CallInst &ICI) { return false; }
+ bool visitFreezeInst(FreezeInst &FI) { return false; }
+ static bool isMultiDimensionalArray(Type *T);
+ static unsigned getTotalElements(Type *ArrayTy);
+ static Type *getBaseElementType(Type *ArrayTy);
+
+private:
+ SmallVector<WeakTrackingVH, 32> PotentiallyDeadInstrs;
+ DenseMap<GetElementPtrInst *, GEPData> GEPChainMap;
+ bool finish();
+ ConstantInt *constFlattenIndices(ArrayRef<Value *> Indices,
+ ArrayRef<uint64_t> Dims,
+ IRBuilder<> &Builder);
+ Value *instructionFlattenIndices(ArrayRef<Value *> Indices,
+ ArrayRef<uint64_t> Dims,
+ IRBuilder<> &Builder);
+ void
+ recursivelyCollectGEPs(GetElementPtrInst &CurrGEP,
+ ArrayType *FlattenedArrayType, Value *PtrOperand,
+ unsigned &GEPChainUseCount,
+ SmallVector<Value *> Indices = SmallVector<Value *>(),
+ SmallVector<uint64_t> Dims = SmallVector<uint64_t>(),
+ bool AllIndicesAreConstInt = true);
+ bool visitGetElementPtrInstInGEPChain(GetElementPtrInst &GEP);
+ bool visitGetElementPtrInstInGEPChainBase(GEPData &GEPInfo,
+ GetElementPtrInst &GEP);
+};
+
+bool DXILFlattenArraysVisitor::finish() {
+ RecursivelyDeleteTriviallyDeadInstructionsPermissive(PotentiallyDeadInstrs);
+ return true;
+}
+
+bool DXILFlattenArraysVisitor::isMultiDimensionalArray(Type *T) {
+ if (ArrayType *ArrType = dyn_cast<ArrayType>(T))
+ return isa<ArrayType>(ArrType->getElementType());
+ return false;
+}
+
+unsigned DXILFlattenArraysVisitor::getTotalElements(Type *ArrayTy) {
+ unsigned TotalElements = 1;
+ Type *CurrArrayTy = ArrayTy;
+ while (auto *InnerArrayTy = dyn_cast<ArrayType>(CurrArrayTy)) {
+ TotalElements *= InnerArrayTy->getNumElements();
+ CurrArrayTy = InnerArrayTy->getElementType();
+ }
+ return TotalElements;
+}
+
+Type *DXILFlattenArraysVisitor::getBaseElementType(Type *ArrayTy) {
+ Type *CurrArrayTy = ArrayTy;
+ while (auto *InnerArrayTy = dyn_cast<ArrayType>(CurrArrayTy)) {
+ CurrArrayTy = InnerArrayTy->getElementType();
+ }
+ return CurrArrayTy;
+}
+
+ConstantInt *DXILFlattenArraysVisitor::constFlattenIndices(
+ ArrayRef<Value *> Indices, ArrayRef<uint64_t> Dims, IRBuilder<> &Builder) {
+ assert(Indices.size() == Dims.size() &&
+ "Indicies and dimmensions should be the same");
+ unsigned FlatIndex = 0;
+ unsigned Multiplier = 1;
+
+ for (int I = Indices.size() - 1; I >= 0; --I) {
+ unsigned DimSize = Dims[I];
+ ConstantInt *CIndex = dyn_cast<ConstantInt>(Indices[I]);
+ assert(CIndex && "This function expects all indicies to be ConstantInt");
+ FlatIndex += CIndex->getZExtValue() * Multiplier;
+ Multiplier *= DimSize;
+ }
+ return Builder.getInt32(FlatIndex);
+}
+
+Value *DXILFlattenArraysVisitor::instructionFlattenIndices(
+ ArrayRef<Value *> Indices, ArrayRef<uint64_t> Dims, IRBuilder<> &Builder) {
+ if (Indices.size() == 1)
+ return Indices[0];
+
+ Value *FlatIndex = Builder.getInt32(0);
+ unsigned Multiplier = 1;
+
+ for (int I = Indices.size() - 1; I >= 0; --I) {
+ unsigned DimSize = Dims[I];
+ Value *VMultiplier = Builder.getInt32(Multiplier);
+ Value *ScaledIndex = Builder.CreateMul(Indices[I], VMultiplier);
+ FlatIndex = Builder.CreateAdd(FlatIndex, ScaledIndex);
+ Multiplier *= DimSize;
+ }
+ return FlatIndex;
+}
+
+bool DXILFlattenArraysVisitor::visitLoadInst(LoadInst &LI) {
+ unsigned NumOperands = LI.getNumOperands();
+ for (unsigned I = 0; I < NumOperands; ++I) {
+ Value *CurrOpperand = LI.getOperand(I);
+ ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand);
+ if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
+ convertUsersOfConstantsToInstructions(CE,
+ /*RestrictToFunc=*/nullptr,
+ /*RemoveDeadConstants=*/false,
+ /*IncludeSelf=*/true);
+ return false;
+ }
+ }
+ return false;
+}
+
+bool DXILFlattenArraysVisitor::visitStoreInst(StoreInst &SI) {
+ unsigned NumOperands = SI.getNumOperands();
+ for (unsigned I = 0; I < NumOperands; ++I) {
+ Value *CurrOpperand = SI.getOperand(I);
+ ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand);
+ if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
+ convertUsersOfConstantsToInstructions(CE,
+ /*RestrictToFunc=*/nullptr,
+ /*RemoveDeadConstants=*/false,
+ /*IncludeSelf=*/true);
+ return false;
+ }
+ }
+ return false;
+}
+
+bool DXILFlattenArraysVisitor::visitAllocaInst(AllocaInst &AI) {
+ if (!isMultiDimensionalArray(AI.getAllocatedType()))
+ return false;
+
+ ArrayType *ArrType = cast<ArrayType>(AI.getAllocatedType());
+ IRBuilder<> Builder(&AI);
+ unsigned TotalElements = getTotalElements(ArrType);
+
+ ArrayType *FattenedArrayType =
+ ArrayType::get(getBaseElementType(ArrType), TotalElements);
+ AllocaInst *FlatAlloca =
+ Builder.CreateAlloca(FattenedArrayType, nullptr, AI.getName() + ".flat");
+ FlatAlloca->setAlignment(AI.getAlign());
+ AI.replaceAllUsesWith(FlatAlloca);
+ AI.eraseFromParent();
+ return true;
+}
+
+void DXILFlattenArraysVisitor::recursivelyCollectGEPs(
+ GetElementPtrInst &CurrGEP, ArrayType *FlattenedArrayType,
+ Value *PtrOperand, unsigned &GEPChainUseCount, SmallVector<Value *> Indices,
+ SmallVector<uint64_t> Dims, bool AllIndicesAreConstInt) {
+ Value *LastIndex = CurrGEP.getOperand(CurrGEP.getNumOperands() - 1);
+ AllIndicesAreConstInt &= isa<ConstantInt>(LastIndex);
+ Indices.push_back(LastIndex);
+ assert(isa<ArrayType>(CurrGEP.getSourceElementType()));
+ Dims.push_back(
+ cast<ArrayType>(CurrGEP.getSourceElementType())->getNumElements());
+ bool IsMultiDimArr = isMultiDimensionalArray(CurrGEP.getSourceElementType());
+ if (!IsMultiDimArr) {
+ assert(GEPChainUseCount < FlattenedArrayType->getNumElements());
+ GEPChainMap.insert(
+ {&CurrGEP,
+ {std::move(FlattenedArrayType), PtrOperand, std::move(Indices),
+ std::move(Dims), AllIndicesAreConstInt}});
+ return;
+ }
+ bool GepUses = false;
+ for (auto *User : CurrGEP.users()) {
+ if (GetElementPtrInst *NestedGEP = dyn_cast<GetElementPtrInst>(User)) {
+ recursivelyCollectGEPs(*NestedGEP, FlattenedArrayType, PtrOperand,
+ ++GEPChainUseCount, Indices, Dims,
+ AllIndicesAreConstInt);
+ GepUses = true;
+ }
+ }
+ // This case is just incase the gep chain doesn't end with a 1d array.
+ if (IsMultiDimArr && GEPChainUseCount > 0 && !GepUses) {
+ GEPChainMap.insert(
+ {&CurrGEP,
+ {std::move(FlattenedArrayType), PtrOperand, std::move(Indices),
+ std::move(Dims), AllIndicesAreConstInt}});
+ }
+}
+
+bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChain(
+ GetElementPtrInst &GEP) {
+ GEPData GEPInfo = GEPChainMap.at(&GEP);
+ return visitGetElementPtrInstInGEPChainBase(GEPInfo, GEP);
+}
+bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChainBase(
+ GEPData &GEPInfo, GetElementPtrInst &GEP) {
+ IRBuilder<> Builder(&GEP);
+ Value *FlatIndex;
+ if (GEPInfo.AllIndicesAreConstInt)
+ FlatIndex = constFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder);
+ else
+ FlatIndex =
+ instructionFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder);
+
+ ArrayType *FlattenedArrayType = GEPInfo.ParentArrayType;
+ Value *FlatGEP =
+ Builder.CreateGEP(FlattenedArrayType, GEPInfo.ParendOperand, FlatIndex,
+ GEP.getName() + ".flat", GEP.isInBounds());
+
+ GEP.replaceAllUsesWith(FlatGEP);
+ GEP.eraseFromParent();
+ return true;
+}
+
+bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {
+ auto It = GEPChainMap.find(&GEP);
+ if (It != GEPChainMap.end())
+ return visitGetElementPtrInstInGEPChain(GEP);
+ if (!isMultiDimensionalArray(GEP.getSourceElementType()))
+ return false;
+
+ ArrayType *ArrType = cast<ArrayType>(GEP.getSourceElementType());
+ IRBuilder<> Builder(&GEP);
+ unsigned TotalElements = getTotalElements(ArrType);
+ ArrayType *FlattenedArrayType =
+ ArrayType::get(getBaseElementType(ArrType), TotalElements);
+
+ Value *PtrOperand = GEP.getPointerOperand();
+
+ unsigned GEPChainUseCount = 0;
+ recursivelyCollectGEPs(GEP, FlattenedArrayType, PtrOperand, GEPChainUseCount);
+
+ // NOTE: hasNUses(0) is not the same as GEPChainUseCount == 0.
+ // Here recursion is used to get the length of the GEP chain.
+ // Handle zero uses here because there won't be an update via
+ // a child in the chain later.
+ if (GEPChainUseCount == 0) {
+ SmallVector<Value *> Indices({GEP.getOperand(GEP.getNumOperands() - 1)});
+ SmallVector<uint64_t> Dims({ArrType->getNumElements()});
+ bool AllIndicesAreConstInt = isa<ConstantInt>(Indices[0]);
+ GEPData GEPInfo{std::move(FlattenedArrayType), PtrOperand,
+ std::move(Indices), std::move(Dims), AllIndicesAreConstInt};
+ return visitGetElementPtrInstInGEPChainBase(GEPInfo, GEP);
+ }
+
+ PotentiallyDeadInstrs.emplace_back(&GEP);
+ return false;
+}
+
+bool DXILFlattenArraysVisitor::visit(Function &F) {
+ bool MadeChange = false;
+ ////for (BasicBlock &BB : make_early_inc_range(F)) {
+ ReversePostOrderTraversal<Function *> RPOT(&F);
+ for (BasicBlock *BB : make_early_inc_range(RPOT)) {
+ for (Instruction &I : make_early_inc_range(*BB)) {
+ if (InstVisitor::visit(I) && I.getType()->isVoidTy()) {
+ I.eraseFromParent();
+ MadeChange = true;
+ }
+ }
+ }
+ finish();
+ return MadeChange;
+}
+
+static void collectElements(Constant *Init,
+ SmallVectorImpl<Constant *> &Elements) {
+ // Base case: If Init is not an array, add it directly to the vector.
+ if (!isa<ArrayType>(Init->getType())) {
+ Elements.push_back(Init);
+ return;
+ }
+
+ // Recursive case: Process each element in the array.
+ if (auto *ArrayConstant = dyn_cast<ConstantArray>(Init)) {
+ for (unsigned I = 0; I < ArrayConstant->getNumOperands(); ++I) {
+ collectElements(ArrayConstant->getOperand(I), Elements);
+ }
+ } else if (auto *DataArrayConstant = dyn_cast<ConstantDataArray>(Init)) {
+ for (unsigned I = 0; I < DataArrayConstant->getNumElements(); ++I) {
+ collectElements(DataArrayConstant->getElementAsConstant(I), Elements);
+ }
+ } else {
+ assert(
+ false &&
+ "Expected a ConstantArray or ConstantDataArray for array initializer!");
+ }
+}
+
+static Constant *transformInitializer(Constant *Init, Type *OrigType,
+ ArrayType *FlattenedType,
+ LLVMContext &Ctx) {
+ // Handle ConstantAggregateZero (zero-initialized constants)
+ if (isa<ConstantAggregateZero>(Init))
+ return ConstantAggregateZero::get(FlattenedType);
+
+ // Handle UndefValue (undefined constants)
+ if (isa<UndefValue>(Init))
+ return UndefValue::get(FlattenedType);
+
+ if (!isa<ArrayType>(OrigType))
+ return Init;
+
+ SmallVector<Constant *> FlattenedElements;
+ collectElements(Init, FlattenedElements);
+ assert(FlattenedType->getNumElements() == FlattenedElements.size() &&
+ "The number of collected elements should match the FlattenedType");
+ return ConstantArray::get(FlattenedType, FlattenedElements);
+}
+
+static void
+flattenGlobalArrays(Module &M,
+ DenseMap<GlobalVariable *, GlobalVariable *> &GlobalMap) {
+ LLVMContext &Ctx = M.getContext();
+ for (GlobalVariable &G : M.globals()) {
+ Type *OrigType = G.getValueType();
+ if (!DXILFlattenArraysVisitor::isMultiDimensionalArray(OrigType))
+ continue;
+
+ ArrayType *ArrType = cast<ArrayType>(OrigType);
+ unsigned TotalElements =
+ DXILFlattenArraysVisitor::getTotalElements(ArrType);
+ ArrayType *FattenedArrayType = ArrayType::get(
+ DXILFlattenArraysVisitor::getBaseElementType(ArrType), TotalElements);
+
+ // Create a new global variable with the updated type
+ // Note: Initializer is set via transformInitializer
+ GlobalVariable *NewGlobal =
+ new GlobalVariable(M, FattenedArrayType, G.isConstant(), G.getLinkage(),
+ /*Initializer=*/nullptr, G.getName() + ".1dim", &G,
+ G.getThreadLocalMode(), G.getAddressSpace(),
+ G.isExternallyInitialized());
+
+ // Copy relevant attributes
+ NewGlobal->setUnnamedAddr(G.getUnnamedAddr());
+ if (G.getAlignment() > 0) {
+ NewGlobal->setAlignment(G.getAlign());
+ }
+
+ if (G.hasInitializer()) {
+ Constant *Init = G.getInitializer();
+ Constant *NewInit =
+ transformInitializer(Init, OrigType, FattenedArrayType, Ctx);
+ NewGlobal->setInitializer(NewInit);
+ }
+ GlobalMap[&G] = NewGlobal;
+ }
+}
+
+static bool flattenArrays(Module &M) {
+ bool MadeChange = false;
+ DXILFlattenArraysVisitor Impl;
+ DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
+ flattenGlobalArrays(M, GlobalMap);
+ for (auto &F : make_early_inc_range(M.functions())) {
+ if (F.isIntrinsic())
+ continue;
+ MadeChange |= Impl.visit(F);
+ }
+ for (auto &[Old, New] : GlobalMap) {
+ Old->replaceAllUsesWith(New);
+ Old->eraseFromParent();
+ MadeChange = true;
+ }
+ return MadeChange;
+}
+
+PreservedAnalyses DXILFlattenArrays::run(Module &M, ModuleAnalysisManager &) {
+ bool MadeChanges = flattenArrays(M);
+ if (!MadeChanges)
+ return PreservedAnalyses::all();
+ PreservedAnalyses PA;
+ PA.preserve<DXILResourceAnalysis>();
+ return PA;
+}
+
+bool DXILFlattenArraysLegacy::runOnModule(Module &M) {
+ return flattenArrays(M);
+}
+
+void DXILFlattenArraysLegacy::getAnalysisUsage(AnalysisUsage &AU) const {
+ AU.addPreserved<DXILResourceWrapperPass>();
+}
+
+char DXILFlattenArraysLegacy::ID = 0;
+
+INITIALIZE_PASS_BEGIN(DXILFlattenArraysLegacy, DEBUG_TYPE,
+ "DXIL Array Flattener", false, false)
+INITIALIZE_PASS_END(DXILFlattenArraysLegacy, DEBUG_TYPE, "DXIL Array Flattener",
+ false, false)
+
+ModulePass *llvm::createDXILFlattenArraysLegacyPass() {
+ return new DXILFlattenArraysLegacy();
+}
diff --git a/llvm/lib/Target/DirectX/DXILFlattenArrays.h b/llvm/lib/Target/DirectX/DXILFlattenArrays.h
new file mode 100644
index 00000000000000..409f8d198782c9
--- /dev/null
+++ b/llvm/lib/Target/DirectX/DXILFlattenArrays.h
@@ -0,0 +1,25 @@
+//===- DXILFlattenArrays.h - Perform flattening of DXIL Arrays -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===---------------------------------------------------------------------===//
+
+#ifndef LLVM_TARGET_DIRECTX_DXILFLATTENARRAYS_H
+#define LLVM_TARGET_DIRECTX_DXILFLATTENARRAYS_H
+
+#include "DXILResource.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/Pass.h"
+
+namespace llvm {
+
+/// A pass that transforms multidimensional arrays into one-dimensional arrays.
+class DXILFlattenArrays : public PassInfoMixin<DXILFlattenArrays> {
+public:
+ PreservedAnalyses run(Module &M, ModuleAnalysisManager &);
+};
+} // namespace llvm
+
+#endif // LLVM_TARGET_DIRECTX_DXILFLATTENARRAYS_H
diff --git a/llvm/lib/Target/DirectX/DirectX.h b/llvm/lib/Target/DirectX/DirectX.h
index 3221779be2f311..3454f16ecd5955 100644
--- a/llvm/lib/Target/DirectX/DirectX.h
+++ b/llvm/lib/Target/DirectX/DirectX.h
@@ -40,6 +40,12 @@ void initializeDXILDataScalarizationLegacyPass(PassRegistry &);
/// Pass to scalarize llvm global data into a DXIL legal form
ModulePass *createDXILDataScalarizationLegacyPass();
+/// Initializer for DXIL Array Flatten Pass
+void initializeDXILFlattenArraysLegacyPass(PassRegistry &);
+
+/// Pass to flatten arrays into a one dimensional DXIL legal form
+ModulePass *createDXILFlattenArraysLegacyPass();
+
/// Initializer for DXILOpLowering
void initializeDXILOpLoweringLegacyPass(PassRegistry &);
diff --git a/llvm/lib/Target/DirectX/DirectXPassRegistry.def b/llvm/lib/Target/DirectX/DirectXPassRegistry.def
index ae729a1082b867..a0f864ed39375f 100644
--- a/llvm/lib/Target/DirectX/DirectXPassRegistry.def
+++ b/llvm/lib/Target/DirectX/DirectXPassRegistry.def
@@ -24,6 +24,7 @@ MODULE_ANALYSIS("dxil-resource-md", DXILResourceMDAnalysis())
#define MODULE_PASS(NAME, CREATE_PASS)
#endif
MODULE_PASS("dxil-data-scalarization", DXILDataScalarization())
+MODULE_PASS("dxil-flatten-arrays", DXILFlattenArrays())
MODULE_PASS("dxil-intrinsic-expansion", DXILIntrinsicExpansion(...
[truncated]
|
bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChainBase( | ||
GEPData &GEPInfo, GetElementPtrInst &GEP) { | ||
IRBuilder<> Builder(&GEP); | ||
Value *FlatIndex; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
in DXC all the other indices were 0. We never seemed to use any index but the last one, so I dropped the other indices. Alternatively, I could preserve the other indices and just update the last index.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks good to me, but of course an expert will need to take a look.
#include "DXILResource.h" | ||
#include "llvm/IR/PassManager.h" | ||
#include "llvm/Pass.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you only need llvm/IR/PassManager.h
in the header here. Pass.h
is for the legacy pass manager, and DXILResource.h
is only used in the cpp file
; CHECK: alloca [9 x i32], align 4 | ||
; CHECK-NOT: alloca [3 x [3 x i32]], align 4 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This CHECK-NOT:
isn't necessarily doing what you want here. If the unflattened alloca was left over before the replacement instead of after it wouldn't catch that. Might be more reliable to use CHECK-NEXT:
and check the whole function:
; CHECK-LABEL: alloca_2d_test
; CHECK-NEXT: alloca [9 x i32], align 4
; CHECK-NEXT: ret void
; CHECK @staticArray | ||
; CHECK-NOT: @staticArray.scalarized | ||
; CHECK: @groushared3dArrayofVectors.scalarized = local_unnamed_addr addrspace(3) global [3 x [3 x [3 x [4 x i32]]]] zeroinitializer, align 16 | ||
; CHECK-NOT: @groushared3dArrayofVectors | ||
; CHECK-NOT: @staticArray.scalarized.1dim | ||
; CHECK-NOT: @staticArray.1dim | ||
; DATACHECK: @groushared3dArrayofVectors.scalarized = local_unnamed_addr addrspace(3) global [3 x [3 x [3 x [4 x i32]]]] zeroinitializer, align 16 | ||
; CHECK: @groushared3dArrayofVectors.scalarized.1dim = local_unnamed_addr addrspace(3) global [108 x i32] zeroinitializer, align 16 | ||
; DATACHECK-NOT: @groushared3dArrayofVectors | ||
; CHECK-NOT: @groushared3dArrayofVectors.scalarized |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There's a similar problem here where the CHECK-NOT:
s aren't really going to be reliable. This one might be a bit harder to deal with though since we don't have good markers for the start and end of the globals here. This might be a moot point though - see my next comment.
; RUN: opt -S -passes='dxil-data-scalarization,function(scalarizer<load-store>),dxil-op-lower' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s | ||
; RUN: opt -S -passes='dxil-data-scalarization,function(scalarizer<load-store>),dxil-op-lower' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s -check-prefix DATACHECK | ||
; RUN: opt -S -passes='dxil-data-scalarization,dxil-flatten-arrays,function(scalarizer<load-store>),dxil-op-lower' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure that we really need to add the flatten-arrays
versions to these scalar tests - this is specifically testing that the scalarizer does the right thing for nested arrays of vectors.
We're already testing that flatten-arrays will do the right thing and flatten these types of arrays in the dedicated test, so all we're getting here by running a second version of the test is a couple of extra test cases with arrays that happen to be flat. We could simply add those types of arrays to the test directly if we're worried about coverage, and then we don't need to have the complexity of multiple run lines that do rather different things.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The problem is two fold I want this to be a scalarizer test case, but we are using the llc
run line. llc
was important to find bugs were my passes could have created invalid llvmir for later passes. Dropping a flatten-arrays pass also means dropping the llc run line. So I saw three paths forward.
- Drop llc and make this a scalarizer only test. Downside if later passes break things the CI won't catch it.
- Drop the scalarizer specifics from the test. Downside is we don't have scalarizer only tests
- Do what I'm currently doing to maintain both llc and data scalarization specifics in one test.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So I think what we should do is add a test file for scalarization that checks that it does the right thing for non-flattened arrays. This will only have an opt
run line.
Then we have two options for the existing scalarization tests that have llc
run lines:
- Update them to just have flattened arrays, since that's what we'll actually see in practice
- Update the
opt
lines to include dxil-flatten-arrays
I think I have a preference for (1) since it keeps this to testing just one thing. That makes it easier to debug if/when the tests fail, and also avoids the combinatorial explosion of having tests for arbitrary sets of passes that work together.
ecc8428
to
9d36d73
Compare
✅ With the latest revision this PR passed the C/C++ code formatter. |
9d36d73
to
7ec67b9
Compare
0852e43
to
4fffccc
Compare
4fffccc
to
964d113
Compare
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/27/builds/1992 Here is the relevant piece of the build log for the reference
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/146/builds/1595 Here is the relevant piece of the build log for the reference
|
DXILFlattenArrays.cpp
recursivelyCollectGEPs
.fixes 89646