Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DirectX] Flatten arrays #114332

Merged
merged 4 commits into from
Nov 13, 2024
Merged

[DirectX] Flatten arrays #114332

merged 4 commits into from
Nov 13, 2024

Conversation

farzonl
Copy link
Member

@farzonl farzonl commented Oct 31, 2024

  • Relevant piece is DXILFlattenArrays.cpp
  • Loads and Store Instruction visits are just for finding GetElementPtrConstantExpr and splitting them.
  • Allocas needed to be replaced with flattened allocas.
  • Global arrays were similar to allocas. Only interesting piece here is around initializers.
  • Most of the work went into building correct GEP chains. The approach here was a recursive strategy via recursivelyCollectGEPs.
  • All intermediary GEPs get marked for deletion and only the leaf GEPs get updated with the new index.

fixes 89646

@farzonl farzonl self-assigned this Oct 31, 2024
@farzonl farzonl requested a review from bogner October 31, 2024 00:03
@llvmbot
Copy link

llvmbot commented Oct 31, 2024

@llvm/pr-subscribers-backend-directx

Author: Farzon Lotfi (farzonl)

Changes
  • Relevant piece is DXILFlattenArrays.cpp
  • Loads and Store Instruction visits are just for finding GetElementPtrConstantExpr and splitting them.
  • Allocas needed to be replaced with flattened allocas.
  • Global arrays were similar to allocas. Only interesting piece here is around initializers.
  • Most of the work went into building correct GEP chains. The approach here was a recursive strategy via recursivelyCollectGEPs.
  • All intermediary GEPs get marked for deletion and only the leaf GEPs get updated with the new index.

completes 89646


Patch is 42.60 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/114332.diff

11 Files Affected:

  • (modified) llvm/lib/Target/DirectX/CMakeLists.txt (+1)
  • (added) llvm/lib/Target/DirectX/DXILFlattenArrays.cpp (+458)
  • (added) llvm/lib/Target/DirectX/DXILFlattenArrays.h (+25)
  • (modified) llvm/lib/Target/DirectX/DirectX.h (+6)
  • (modified) llvm/lib/Target/DirectX/DirectXPassRegistry.def (+1)
  • (modified) llvm/lib/Target/DirectX/DirectXTargetMachine.cpp (+3)
  • (added) llvm/test/CodeGen/DirectX/flatten-array.ll (+194)
  • (modified) llvm/test/CodeGen/DirectX/llc-pipeline.ll (+1)
  • (modified) llvm/test/CodeGen/DirectX/scalar-data.ll (+8-3)
  • (modified) llvm/test/CodeGen/DirectX/scalar-load.ll (+15-8)
  • (modified) llvm/test/CodeGen/DirectX/scalar-store.ll (+7-4)
diff --git a/llvm/lib/Target/DirectX/CMakeLists.txt b/llvm/lib/Target/DirectX/CMakeLists.txt
index 5d1dc50fdb0dde..a726071e0dcecd 100644
--- a/llvm/lib/Target/DirectX/CMakeLists.txt
+++ b/llvm/lib/Target/DirectX/CMakeLists.txt
@@ -22,6 +22,7 @@ add_llvm_target(DirectXCodeGen
   DXContainerGlobals.cpp
   DXILDataScalarization.cpp
   DXILFinalizeLinkage.cpp
+  DXILFlattenArrays.cpp
   DXILIntrinsicExpansion.cpp
   DXILOpBuilder.cpp
   DXILOpLowering.cpp
diff --git a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
new file mode 100644
index 00000000000000..20c7401e934e6c
--- /dev/null
+++ b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
@@ -0,0 +1,458 @@
+//===- DXILFlattenArrays.cpp - Flattens DXIL Arrays-----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===---------------------------------------------------------------------===//
+
+///
+/// \file This file contains a pass to flatten arrays for the DirectX Backend.
+//
+//===----------------------------------------------------------------------===//
+
+#include "DXILFlattenArrays.h"
+#include "DirectX.h"
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Analysis/DXILResource.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstVisitor.h"
+#include "llvm/IR/ReplaceConstant.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Transforms/Utils/Local.h"
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+
+#define DEBUG_TYPE "dxil-flatten-arrays"
+
+using namespace llvm;
+
+class DXILFlattenArraysLegacy : public ModulePass {
+
+public:
+  bool runOnModule(Module &M) override;
+  DXILFlattenArraysLegacy() : ModulePass(ID) {}
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override;
+  static char ID; // Pass identification.
+};
+
+struct GEPData {
+  ArrayType *ParentArrayType;
+  Value *ParendOperand;
+  SmallVector<Value *> Indices;
+  SmallVector<uint64_t> Dims;
+  bool AllIndicesAreConstInt;
+};
+
+class DXILFlattenArraysVisitor
+    : public InstVisitor<DXILFlattenArraysVisitor, bool> {
+public:
+  DXILFlattenArraysVisitor() {}
+  bool visit(Function &F);
+  // InstVisitor methods.  They return true if the instruction was scalarized,
+  // false if nothing changed.
+  bool visitGetElementPtrInst(GetElementPtrInst &GEPI);
+  bool visitAllocaInst(AllocaInst &AI);
+  bool visitInstruction(Instruction &I) { return false; }
+  bool visitSelectInst(SelectInst &SI) { return false; }
+  bool visitICmpInst(ICmpInst &ICI) { return false; }
+  bool visitFCmpInst(FCmpInst &FCI) { return false; }
+  bool visitUnaryOperator(UnaryOperator &UO) { return false; }
+  bool visitBinaryOperator(BinaryOperator &BO) { return false; }
+  bool visitCastInst(CastInst &CI) { return false; }
+  bool visitBitCastInst(BitCastInst &BCI) { return false; }
+  bool visitInsertElementInst(InsertElementInst &IEI) { return false; }
+  bool visitExtractElementInst(ExtractElementInst &EEI) { return false; }
+  bool visitShuffleVectorInst(ShuffleVectorInst &SVI) { return false; }
+  bool visitPHINode(PHINode &PHI) { return false; }
+  bool visitLoadInst(LoadInst &LI);
+  bool visitStoreInst(StoreInst &SI);
+  bool visitCallInst(CallInst &ICI) { return false; }
+  bool visitFreezeInst(FreezeInst &FI) { return false; }
+  static bool isMultiDimensionalArray(Type *T);
+  static unsigned getTotalElements(Type *ArrayTy);
+  static Type *getBaseElementType(Type *ArrayTy);
+
+private:
+  SmallVector<WeakTrackingVH, 32> PotentiallyDeadInstrs;
+  DenseMap<GetElementPtrInst *, GEPData> GEPChainMap;
+  bool finish();
+  ConstantInt *constFlattenIndices(ArrayRef<Value *> Indices,
+                                   ArrayRef<uint64_t> Dims,
+                                   IRBuilder<> &Builder);
+  Value *instructionFlattenIndices(ArrayRef<Value *> Indices,
+                                   ArrayRef<uint64_t> Dims,
+                                   IRBuilder<> &Builder);
+  void
+  recursivelyCollectGEPs(GetElementPtrInst &CurrGEP,
+                         ArrayType *FlattenedArrayType, Value *PtrOperand,
+                         unsigned &GEPChainUseCount,
+                         SmallVector<Value *> Indices = SmallVector<Value *>(),
+                         SmallVector<uint64_t> Dims = SmallVector<uint64_t>(),
+                         bool AllIndicesAreConstInt = true);
+  bool visitGetElementPtrInstInGEPChain(GetElementPtrInst &GEP);
+  bool visitGetElementPtrInstInGEPChainBase(GEPData &GEPInfo,
+                                            GetElementPtrInst &GEP);
+};
+
+bool DXILFlattenArraysVisitor::finish() {
+  RecursivelyDeleteTriviallyDeadInstructionsPermissive(PotentiallyDeadInstrs);
+  return true;
+}
+
+bool DXILFlattenArraysVisitor::isMultiDimensionalArray(Type *T) {
+  if (ArrayType *ArrType = dyn_cast<ArrayType>(T))
+    return isa<ArrayType>(ArrType->getElementType());
+  return false;
+}
+
+unsigned DXILFlattenArraysVisitor::getTotalElements(Type *ArrayTy) {
+  unsigned TotalElements = 1;
+  Type *CurrArrayTy = ArrayTy;
+  while (auto *InnerArrayTy = dyn_cast<ArrayType>(CurrArrayTy)) {
+    TotalElements *= InnerArrayTy->getNumElements();
+    CurrArrayTy = InnerArrayTy->getElementType();
+  }
+  return TotalElements;
+}
+
+Type *DXILFlattenArraysVisitor::getBaseElementType(Type *ArrayTy) {
+  Type *CurrArrayTy = ArrayTy;
+  while (auto *InnerArrayTy = dyn_cast<ArrayType>(CurrArrayTy)) {
+    CurrArrayTy = InnerArrayTy->getElementType();
+  }
+  return CurrArrayTy;
+}
+
+ConstantInt *DXILFlattenArraysVisitor::constFlattenIndices(
+    ArrayRef<Value *> Indices, ArrayRef<uint64_t> Dims, IRBuilder<> &Builder) {
+  assert(Indices.size() == Dims.size() &&
+         "Indicies and dimmensions should be the same");
+  unsigned FlatIndex = 0;
+  unsigned Multiplier = 1;
+
+  for (int I = Indices.size() - 1; I >= 0; --I) {
+    unsigned DimSize = Dims[I];
+    ConstantInt *CIndex = dyn_cast<ConstantInt>(Indices[I]);
+    assert(CIndex && "This function expects all indicies to be ConstantInt");
+    FlatIndex += CIndex->getZExtValue() * Multiplier;
+    Multiplier *= DimSize;
+  }
+  return Builder.getInt32(FlatIndex);
+}
+
+Value *DXILFlattenArraysVisitor::instructionFlattenIndices(
+    ArrayRef<Value *> Indices, ArrayRef<uint64_t> Dims, IRBuilder<> &Builder) {
+  if (Indices.size() == 1)
+    return Indices[0];
+
+  Value *FlatIndex = Builder.getInt32(0);
+  unsigned Multiplier = 1;
+
+  for (int I = Indices.size() - 1; I >= 0; --I) {
+    unsigned DimSize = Dims[I];
+    Value *VMultiplier = Builder.getInt32(Multiplier);
+    Value *ScaledIndex = Builder.CreateMul(Indices[I], VMultiplier);
+    FlatIndex = Builder.CreateAdd(FlatIndex, ScaledIndex);
+    Multiplier *= DimSize;
+  }
+  return FlatIndex;
+}
+
+bool DXILFlattenArraysVisitor::visitLoadInst(LoadInst &LI) {
+  unsigned NumOperands = LI.getNumOperands();
+  for (unsigned I = 0; I < NumOperands; ++I) {
+    Value *CurrOpperand = LI.getOperand(I);
+    ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand);
+    if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
+      convertUsersOfConstantsToInstructions(CE,
+                                            /*RestrictToFunc=*/nullptr,
+                                            /*RemoveDeadConstants=*/false,
+                                            /*IncludeSelf=*/true);
+      return false;
+    }
+  }
+  return false;
+}
+
+bool DXILFlattenArraysVisitor::visitStoreInst(StoreInst &SI) {
+  unsigned NumOperands = SI.getNumOperands();
+  for (unsigned I = 0; I < NumOperands; ++I) {
+    Value *CurrOpperand = SI.getOperand(I);
+    ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand);
+    if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
+      convertUsersOfConstantsToInstructions(CE,
+                                            /*RestrictToFunc=*/nullptr,
+                                            /*RemoveDeadConstants=*/false,
+                                            /*IncludeSelf=*/true);
+      return false;
+    }
+  }
+  return false;
+}
+
+bool DXILFlattenArraysVisitor::visitAllocaInst(AllocaInst &AI) {
+  if (!isMultiDimensionalArray(AI.getAllocatedType()))
+    return false;
+
+  ArrayType *ArrType = cast<ArrayType>(AI.getAllocatedType());
+  IRBuilder<> Builder(&AI);
+  unsigned TotalElements = getTotalElements(ArrType);
+
+  ArrayType *FattenedArrayType =
+      ArrayType::get(getBaseElementType(ArrType), TotalElements);
+  AllocaInst *FlatAlloca =
+      Builder.CreateAlloca(FattenedArrayType, nullptr, AI.getName() + ".flat");
+  FlatAlloca->setAlignment(AI.getAlign());
+  AI.replaceAllUsesWith(FlatAlloca);
+  AI.eraseFromParent();
+  return true;
+}
+
+void DXILFlattenArraysVisitor::recursivelyCollectGEPs(
+    GetElementPtrInst &CurrGEP, ArrayType *FlattenedArrayType,
+    Value *PtrOperand, unsigned &GEPChainUseCount, SmallVector<Value *> Indices,
+    SmallVector<uint64_t> Dims, bool AllIndicesAreConstInt) {
+  Value *LastIndex = CurrGEP.getOperand(CurrGEP.getNumOperands() - 1);
+  AllIndicesAreConstInt &= isa<ConstantInt>(LastIndex);
+  Indices.push_back(LastIndex);
+  assert(isa<ArrayType>(CurrGEP.getSourceElementType()));
+  Dims.push_back(
+      cast<ArrayType>(CurrGEP.getSourceElementType())->getNumElements());
+  bool IsMultiDimArr = isMultiDimensionalArray(CurrGEP.getSourceElementType());
+  if (!IsMultiDimArr) {
+    assert(GEPChainUseCount < FlattenedArrayType->getNumElements());
+    GEPChainMap.insert(
+        {&CurrGEP,
+         {std::move(FlattenedArrayType), PtrOperand, std::move(Indices),
+          std::move(Dims), AllIndicesAreConstInt}});
+    return;
+  }
+  bool GepUses = false;
+  for (auto *User : CurrGEP.users()) {
+    if (GetElementPtrInst *NestedGEP = dyn_cast<GetElementPtrInst>(User)) {
+      recursivelyCollectGEPs(*NestedGEP, FlattenedArrayType, PtrOperand,
+                             ++GEPChainUseCount, Indices, Dims,
+                             AllIndicesAreConstInt);
+      GepUses = true;
+    }
+  }
+  // This case is just incase the gep chain doesn't end with a 1d array.
+  if (IsMultiDimArr && GEPChainUseCount > 0 && !GepUses) {
+    GEPChainMap.insert(
+        {&CurrGEP,
+         {std::move(FlattenedArrayType), PtrOperand, std::move(Indices),
+          std::move(Dims), AllIndicesAreConstInt}});
+  }
+}
+
+bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChain(
+    GetElementPtrInst &GEP) {
+  GEPData GEPInfo = GEPChainMap.at(&GEP);
+  return visitGetElementPtrInstInGEPChainBase(GEPInfo, GEP);
+}
+bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChainBase(
+    GEPData &GEPInfo, GetElementPtrInst &GEP) {
+  IRBuilder<> Builder(&GEP);
+  Value *FlatIndex;
+  if (GEPInfo.AllIndicesAreConstInt)
+    FlatIndex = constFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder);
+  else
+    FlatIndex =
+        instructionFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder);
+
+  ArrayType *FlattenedArrayType = GEPInfo.ParentArrayType;
+  Value *FlatGEP =
+      Builder.CreateGEP(FlattenedArrayType, GEPInfo.ParendOperand, FlatIndex,
+                        GEP.getName() + ".flat", GEP.isInBounds());
+
+  GEP.replaceAllUsesWith(FlatGEP);
+  GEP.eraseFromParent();
+  return true;
+}
+
+bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {
+  auto It = GEPChainMap.find(&GEP);
+  if (It != GEPChainMap.end())
+    return visitGetElementPtrInstInGEPChain(GEP);
+  if (!isMultiDimensionalArray(GEP.getSourceElementType()))
+    return false;
+
+  ArrayType *ArrType = cast<ArrayType>(GEP.getSourceElementType());
+  IRBuilder<> Builder(&GEP);
+  unsigned TotalElements = getTotalElements(ArrType);
+  ArrayType *FlattenedArrayType =
+      ArrayType::get(getBaseElementType(ArrType), TotalElements);
+
+  Value *PtrOperand = GEP.getPointerOperand();
+
+  unsigned GEPChainUseCount = 0;
+  recursivelyCollectGEPs(GEP, FlattenedArrayType, PtrOperand, GEPChainUseCount);
+
+  // NOTE: hasNUses(0) is not the same as GEPChainUseCount == 0.
+  // Here recursion is used to get the length of the GEP chain.
+  // Handle zero uses here because there won't be an update via
+  // a child in the chain later.
+  if (GEPChainUseCount == 0) {
+    SmallVector<Value *> Indices({GEP.getOperand(GEP.getNumOperands() - 1)});
+    SmallVector<uint64_t> Dims({ArrType->getNumElements()});
+    bool AllIndicesAreConstInt = isa<ConstantInt>(Indices[0]);
+    GEPData GEPInfo{std::move(FlattenedArrayType), PtrOperand,
+                    std::move(Indices), std::move(Dims), AllIndicesAreConstInt};
+    return visitGetElementPtrInstInGEPChainBase(GEPInfo, GEP);
+  }
+
+  PotentiallyDeadInstrs.emplace_back(&GEP);
+  return false;
+}
+
+bool DXILFlattenArraysVisitor::visit(Function &F) {
+  bool MadeChange = false;
+  ////for (BasicBlock &BB : make_early_inc_range(F)) {
+  ReversePostOrderTraversal<Function *> RPOT(&F);
+  for (BasicBlock *BB : make_early_inc_range(RPOT)) {
+    for (Instruction &I : make_early_inc_range(*BB)) {
+      if (InstVisitor::visit(I) && I.getType()->isVoidTy()) {
+        I.eraseFromParent();
+        MadeChange = true;
+      }
+    }
+  }
+  finish();
+  return MadeChange;
+}
+
+static void collectElements(Constant *Init,
+                            SmallVectorImpl<Constant *> &Elements) {
+  // Base case: If Init is not an array, add it directly to the vector.
+  if (!isa<ArrayType>(Init->getType())) {
+    Elements.push_back(Init);
+    return;
+  }
+
+  // Recursive case: Process each element in the array.
+  if (auto *ArrayConstant = dyn_cast<ConstantArray>(Init)) {
+    for (unsigned I = 0; I < ArrayConstant->getNumOperands(); ++I) {
+      collectElements(ArrayConstant->getOperand(I), Elements);
+    }
+  } else if (auto *DataArrayConstant = dyn_cast<ConstantDataArray>(Init)) {
+    for (unsigned I = 0; I < DataArrayConstant->getNumElements(); ++I) {
+      collectElements(DataArrayConstant->getElementAsConstant(I), Elements);
+    }
+  } else {
+    assert(
+        false &&
+        "Expected a ConstantArray or ConstantDataArray for array initializer!");
+  }
+}
+
+static Constant *transformInitializer(Constant *Init, Type *OrigType,
+                                      ArrayType *FlattenedType,
+                                      LLVMContext &Ctx) {
+  // Handle ConstantAggregateZero (zero-initialized constants)
+  if (isa<ConstantAggregateZero>(Init))
+    return ConstantAggregateZero::get(FlattenedType);
+
+  // Handle UndefValue (undefined constants)
+  if (isa<UndefValue>(Init))
+    return UndefValue::get(FlattenedType);
+
+  if (!isa<ArrayType>(OrigType))
+    return Init;
+
+  SmallVector<Constant *> FlattenedElements;
+  collectElements(Init, FlattenedElements);
+  assert(FlattenedType->getNumElements() == FlattenedElements.size() &&
+         "The number of collected elements should match the FlattenedType");
+  return ConstantArray::get(FlattenedType, FlattenedElements);
+}
+
+static void
+flattenGlobalArrays(Module &M,
+                    DenseMap<GlobalVariable *, GlobalVariable *> &GlobalMap) {
+  LLVMContext &Ctx = M.getContext();
+  for (GlobalVariable &G : M.globals()) {
+    Type *OrigType = G.getValueType();
+    if (!DXILFlattenArraysVisitor::isMultiDimensionalArray(OrigType))
+      continue;
+
+    ArrayType *ArrType = cast<ArrayType>(OrigType);
+    unsigned TotalElements =
+        DXILFlattenArraysVisitor::getTotalElements(ArrType);
+    ArrayType *FattenedArrayType = ArrayType::get(
+        DXILFlattenArraysVisitor::getBaseElementType(ArrType), TotalElements);
+
+    // Create a new global variable with the updated type
+    // Note: Initializer is set via transformInitializer
+    GlobalVariable *NewGlobal =
+        new GlobalVariable(M, FattenedArrayType, G.isConstant(), G.getLinkage(),
+                           /*Initializer=*/nullptr, G.getName() + ".1dim", &G,
+                           G.getThreadLocalMode(), G.getAddressSpace(),
+                           G.isExternallyInitialized());
+
+    // Copy relevant attributes
+    NewGlobal->setUnnamedAddr(G.getUnnamedAddr());
+    if (G.getAlignment() > 0) {
+      NewGlobal->setAlignment(G.getAlign());
+    }
+
+    if (G.hasInitializer()) {
+      Constant *Init = G.getInitializer();
+      Constant *NewInit =
+          transformInitializer(Init, OrigType, FattenedArrayType, Ctx);
+      NewGlobal->setInitializer(NewInit);
+    }
+    GlobalMap[&G] = NewGlobal;
+  }
+}
+
+static bool flattenArrays(Module &M) {
+  bool MadeChange = false;
+  DXILFlattenArraysVisitor Impl;
+  DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
+  flattenGlobalArrays(M, GlobalMap);
+  for (auto &F : make_early_inc_range(M.functions())) {
+    if (F.isIntrinsic())
+      continue;
+    MadeChange |= Impl.visit(F);
+  }
+  for (auto &[Old, New] : GlobalMap) {
+    Old->replaceAllUsesWith(New);
+    Old->eraseFromParent();
+    MadeChange = true;
+  }
+  return MadeChange;
+}
+
+PreservedAnalyses DXILFlattenArrays::run(Module &M, ModuleAnalysisManager &) {
+  bool MadeChanges = flattenArrays(M);
+  if (!MadeChanges)
+    return PreservedAnalyses::all();
+  PreservedAnalyses PA;
+  PA.preserve<DXILResourceAnalysis>();
+  return PA;
+}
+
+bool DXILFlattenArraysLegacy::runOnModule(Module &M) {
+  return flattenArrays(M);
+}
+
+void DXILFlattenArraysLegacy::getAnalysisUsage(AnalysisUsage &AU) const {
+  AU.addPreserved<DXILResourceWrapperPass>();
+}
+
+char DXILFlattenArraysLegacy::ID = 0;
+
+INITIALIZE_PASS_BEGIN(DXILFlattenArraysLegacy, DEBUG_TYPE,
+                      "DXIL Array Flattener", false, false)
+INITIALIZE_PASS_END(DXILFlattenArraysLegacy, DEBUG_TYPE, "DXIL Array Flattener",
+                    false, false)
+
+ModulePass *llvm::createDXILFlattenArraysLegacyPass() {
+  return new DXILFlattenArraysLegacy();
+}
diff --git a/llvm/lib/Target/DirectX/DXILFlattenArrays.h b/llvm/lib/Target/DirectX/DXILFlattenArrays.h
new file mode 100644
index 00000000000000..409f8d198782c9
--- /dev/null
+++ b/llvm/lib/Target/DirectX/DXILFlattenArrays.h
@@ -0,0 +1,25 @@
+//===- DXILFlattenArrays.h - Perform flattening of DXIL Arrays -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===---------------------------------------------------------------------===//
+
+#ifndef LLVM_TARGET_DIRECTX_DXILFLATTENARRAYS_H
+#define LLVM_TARGET_DIRECTX_DXILFLATTENARRAYS_H
+
+#include "DXILResource.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/Pass.h"
+
+namespace llvm {
+
+/// A pass that transforms multidimensional arrays into one-dimensional arrays.
+class DXILFlattenArrays : public PassInfoMixin<DXILFlattenArrays> {
+public:
+  PreservedAnalyses run(Module &M, ModuleAnalysisManager &);
+};
+} // namespace llvm
+
+#endif // LLVM_TARGET_DIRECTX_DXILFLATTENARRAYS_H
diff --git a/llvm/lib/Target/DirectX/DirectX.h b/llvm/lib/Target/DirectX/DirectX.h
index 3221779be2f311..3454f16ecd5955 100644
--- a/llvm/lib/Target/DirectX/DirectX.h
+++ b/llvm/lib/Target/DirectX/DirectX.h
@@ -40,6 +40,12 @@ void initializeDXILDataScalarizationLegacyPass(PassRegistry &);
 /// Pass to scalarize llvm global data into a DXIL legal form
 ModulePass *createDXILDataScalarizationLegacyPass();
 
+/// Initializer for DXIL Array Flatten Pass
+void initializeDXILFlattenArraysLegacyPass(PassRegistry &);
+
+/// Pass to flatten arrays into a one dimensional DXIL legal form
+ModulePass *createDXILFlattenArraysLegacyPass();
+
 /// Initializer for DXILOpLowering
 void initializeDXILOpLoweringLegacyPass(PassRegistry &);
 
diff --git a/llvm/lib/Target/DirectX/DirectXPassRegistry.def b/llvm/lib/Target/DirectX/DirectXPassRegistry.def
index ae729a1082b867..a0f864ed39375f 100644
--- a/llvm/lib/Target/DirectX/DirectXPassRegistry.def
+++ b/llvm/lib/Target/DirectX/DirectXPassRegistry.def
@@ -24,6 +24,7 @@ MODULE_ANALYSIS("dxil-resource-md", DXILResourceMDAnalysis())
 #define MODULE_PASS(NAME, CREATE_PASS)
 #endif
 MODULE_PASS("dxil-data-scalarization", DXILDataScalarization())
+MODULE_PASS("dxil-flatten-arrays", DXILFlattenArrays())
 MODULE_PASS("dxil-intrinsic-expansion", DXILIntrinsicExpansion(...
[truncated]

bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChainBase(
GEPData &GEPInfo, GetElementPtrInst &GEP) {
IRBuilder<> Builder(&GEP);
Value *FlatIndex;
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in DXC all the other indices were 0. We never seemed to use any index but the last one, so I dropped the other indices. Alternatively, I could preserve the other indices and just update the last index.

Copy link
Contributor

@coopp coopp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks good to me, but of course an expert will need to take a look.

Comment on lines 12 to 14
#include "DXILResource.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Pass.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you only need llvm/IR/PassManager.h in the header here. Pass.h is for the legacy pass manager, and DXILResource.h is only used in the cpp file

Comment on lines 5 to 6
; CHECK: alloca [9 x i32], align 4
; CHECK-NOT: alloca [3 x [3 x i32]], align 4
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This CHECK-NOT: isn't necessarily doing what you want here. If the unflattened alloca was left over before the replacement instead of after it wouldn't catch that. Might be more reliable to use CHECK-NEXT: and check the whole function:

; CHECK-LABEL: alloca_2d_test
; CHECK-NEXT: alloca [9 x i32], align 4
; CHECK-NEXT: ret void

Comment on lines 10 to 14
; CHECK @staticArray
; CHECK-NOT: @staticArray.scalarized
; CHECK: @groushared3dArrayofVectors.scalarized = local_unnamed_addr addrspace(3) global [3 x [3 x [3 x [4 x i32]]]] zeroinitializer, align 16
; CHECK-NOT: @groushared3dArrayofVectors
; CHECK-NOT: @staticArray.scalarized.1dim
; CHECK-NOT: @staticArray.1dim
; DATACHECK: @groushared3dArrayofVectors.scalarized = local_unnamed_addr addrspace(3) global [3 x [3 x [3 x [4 x i32]]]] zeroinitializer, align 16
; CHECK: @groushared3dArrayofVectors.scalarized.1dim = local_unnamed_addr addrspace(3) global [108 x i32] zeroinitializer, align 16
; DATACHECK-NOT: @groushared3dArrayofVectors
; CHECK-NOT: @groushared3dArrayofVectors.scalarized
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's a similar problem here where the CHECK-NOT:s aren't really going to be reliable. This one might be a bit harder to deal with though since we don't have good markers for the start and end of the globals here. This might be a moot point though - see my next comment.

Comment on lines 1 to 2
; RUN: opt -S -passes='dxil-data-scalarization,function(scalarizer<load-store>),dxil-op-lower' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
; RUN: opt -S -passes='dxil-data-scalarization,function(scalarizer<load-store>),dxil-op-lower' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s -check-prefix DATACHECK
; RUN: opt -S -passes='dxil-data-scalarization,dxil-flatten-arrays,function(scalarizer<load-store>),dxil-op-lower' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure that we really need to add the flatten-arrays versions to these scalar tests - this is specifically testing that the scalarizer does the right thing for nested arrays of vectors.

We're already testing that flatten-arrays will do the right thing and flatten these types of arrays in the dedicated test, so all we're getting here by running a second version of the test is a couple of extra test cases with arrays that happen to be flat. We could simply add those types of arrays to the test directly if we're worried about coverage, and then we don't need to have the complexity of multiple run lines that do rather different things.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem is two fold I want this to be a scalarizer test case, but we are using the llc run line. llc was important to find bugs were my passes could have created invalid llvmir for later passes. Dropping a flatten-arrays pass also means dropping the llc run line. So I saw three paths forward.

  1. Drop llc and make this a scalarizer only test. Downside if later passes break things the CI won't catch it.
  2. Drop the scalarizer specifics from the test. Downside is we don't have scalarizer only tests
  3. Do what I'm currently doing to maintain both llc and data scalarization specifics in one test.

Copy link
Contributor

@bogner bogner Nov 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So I think what we should do is add a test file for scalarization that checks that it does the right thing for non-flattened arrays. This will only have an opt run line.

Then we have two options for the existing scalarization tests that have llc run lines:

  1. Update them to just have flattened arrays, since that's what we'll actually see in practice
  2. Update the opt lines to include dxil-flatten-arrays

I think I have a preference for (1) since it keeps this to testing just one thing. That makes it easier to debug if/when the tests fail, and also avoids the combinatorial explosion of having tests for arbitrary sets of passes that work together.

llvm/lib/Target/DirectX/DXILFlattenArrays.cpp Outdated Show resolved Hide resolved
llvm/lib/Target/DirectX/DXILFlattenArrays.cpp Outdated Show resolved Hide resolved
llvm/lib/Target/DirectX/DXILFlattenArrays.cpp Outdated Show resolved Hide resolved
llvm/lib/Target/DirectX/DXILFlattenArrays.cpp Outdated Show resolved Hide resolved
llvm/lib/Target/DirectX/DXILFlattenArrays.cpp Outdated Show resolved Hide resolved
llvm/lib/Target/DirectX/DXILFlattenArrays.cpp Outdated Show resolved Hide resolved
Copy link

github-actions bot commented Nov 4, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@farzonl farzonl force-pushed the flatten-arrays branch 2 times, most recently from 0852e43 to 4fffccc Compare November 13, 2024 05:23
@farzonl farzonl merged commit 5ac624c into llvm:main Nov 13, 2024
6 of 8 checks passed
@farzonl farzonl mentioned this pull request Nov 13, 2024
7 tasks
@llvm-ci
Copy link
Collaborator

llvm-ci commented Nov 13, 2024

LLVM Buildbot has detected a new failure on builder clang-m68k-linux-cross running on suse-gary-m68k-cross while building llvm at step 5 "ninja check 1".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/27/builds/1992

Here is the relevant piece of the build log for the reference
Step 5 (ninja check 1) failure: stage 1 checked (failure)
...
  463 | class Sema final : public SemaBase {
      |       ^~~~
/var/lib/buildbot/workers/suse-gary-m68k-cross/clang-m68k-linux-cross/llvm/clang/include/clang/Sema/Sema.h:463:7: warning: ‘clang::Sema’ declared with greater visibility than the type of its field ‘clang::Sema::TentativeDefinitions’ [-Wattributes]
/var/lib/buildbot/workers/suse-gary-m68k-cross/clang-m68k-linux-cross/llvm/clang/include/clang/Sema/Sema.h:463:7: warning: ‘clang::Sema’ declared with greater visibility than the type of its field ‘clang::Sema::ExtVectorDecls’ [-Wattributes]
/var/lib/buildbot/workers/suse-gary-m68k-cross/clang-m68k-linux-cross/llvm/clang/include/clang/Sema/Sema.h:463:7: warning: ‘clang::Sema’ declared with greater visibility than the type of its field ‘clang::Sema::DelegatingCtorDecls’ [-Wattributes]
[295/1137] Building CXX object tools/clang/unittests/Tooling/CMakeFiles/ToolingTests.dir/CommentHandlerTest.cpp.o
[296/1137] Building CXX object tools/clang/unittests/Tooling/CMakeFiles/ToolingTests.dir/ASTSelectionTest.cpp.o
[297/1137] Building CXX object tools/clang/unittests/AST/CMakeFiles/ASTTests.dir/StructuralEquivalenceTest.cpp.o
[298/1137] Building CXX object tools/clang/unittests/Tooling/CMakeFiles/ToolingTests.dir/ExecutionTest.cpp.o
[299/1137] Building CXX object tools/clang/unittests/AST/CMakeFiles/ASTTests.dir/ASTImporterTest.cpp.o
FAILED: tools/clang/unittests/AST/CMakeFiles/ASTTests.dir/ASTImporterTest.cpp.o 
/usr/bin/c++ -DGTEST_HAS_RTTI=0 -DLLVM_BUILD_STATIC -D_DEBUG -D_GLIBCXX_ASSERTIONS -D_GNU_SOURCE -D__STDC_CONSTANT_MACROS -D__STDC_FORMAT_MACROS -D__STDC_LIMIT_MACROS -I/var/lib/buildbot/workers/suse-gary-m68k-cross/clang-m68k-linux-cross/stage1/tools/clang/unittests/AST -I/var/lib/buildbot/workers/suse-gary-m68k-cross/clang-m68k-linux-cross/llvm/clang/unittests/AST -I/var/lib/buildbot/workers/suse-gary-m68k-cross/clang-m68k-linux-cross/llvm/clang/include -I/var/lib/buildbot/workers/suse-gary-m68k-cross/clang-m68k-linux-cross/stage1/tools/clang/include -I/var/lib/buildbot/workers/suse-gary-m68k-cross/clang-m68k-linux-cross/stage1/include -I/var/lib/buildbot/workers/suse-gary-m68k-cross/clang-m68k-linux-cross/llvm/llvm/include -I/var/lib/buildbot/workers/suse-gary-m68k-cross/clang-m68k-linux-cross/llvm/third-party/unittest/googletest/include -I/var/lib/buildbot/workers/suse-gary-m68k-cross/clang-m68k-linux-cross/llvm/third-party/unittest/googlemock/include -fPIC -fno-semantic-interposition -fvisibility-inlines-hidden -Werror=date-time -fno-lifetime-dse -Wall -Wextra -Wno-unused-parameter -Wwrite-strings -Wcast-qual -Wno-missing-field-initializers -pedantic -Wno-long-long -Wimplicit-fallthrough -Wno-maybe-uninitialized -Wno-nonnull -Wno-class-memaccess -Wno-redundant-move -Wno-pessimizing-move -Wno-noexcept-type -Wdelete-non-virtual-dtor -Wsuggest-override -Wno-comment -Wno-misleading-indentation -Wctad-maybe-unsupported -fdiagnostics-color -ffunction-sections -fdata-sections -fno-common -Woverloaded-virtual -fno-strict-aliasing -O3 -DNDEBUG -std=c++17  -Wno-variadic-macros -fno-exceptions -funwind-tables -fno-rtti -UNDEBUG -Wno-suggest-override -MD -MT tools/clang/unittests/AST/CMakeFiles/ASTTests.dir/ASTImporterTest.cpp.o -MF tools/clang/unittests/AST/CMakeFiles/ASTTests.dir/ASTImporterTest.cpp.o.d -o tools/clang/unittests/AST/CMakeFiles/ASTTests.dir/ASTImporterTest.cpp.o -c /var/lib/buildbot/workers/suse-gary-m68k-cross/clang-m68k-linux-cross/llvm/clang/unittests/AST/ASTImporterTest.cpp
c++: fatal error: Killed signal terminated program cc1plus
compilation terminated.
[300/1137] Building CXX object tools/clang/unittests/Tooling/CMakeFiles/ToolingTests.dir/QualTypeNamesTest.cpp.o
[301/1137] Building CXX object tools/clang/unittests/Tooling/CMakeFiles/ToolingTests.dir/LookupTest.cpp.o
[302/1137] Building CXX object tools/clang/unittests/Tooling/CMakeFiles/ToolingTests.dir/RecursiveASTVisitorTests/CXXBoolLiteralExpr.cpp.o
[303/1137] Building CXX object tools/clang/unittests/Tooling/CMakeFiles/ToolingTests.dir/RecursiveASTVisitorTests/CXXOperatorCallExprTraverser.cpp.o
[304/1137] Building CXX object tools/clang/unittests/Tooling/CMakeFiles/ToolingTests.dir/RecursiveASTVisitorTests/DeductionGuide.cpp.o
[305/1137] Building CXX object tools/clang/unittests/Tooling/CMakeFiles/ToolingTests.dir/RecursiveASTVisitorTests/ConstructExpr.cpp.o
[306/1137] Building CXX object tools/clang/unittests/Tooling/CMakeFiles/ToolingTests.dir/RangeSelectorTest.cpp.o
[307/1137] Building CXX object tools/clang/unittests/Tooling/CMakeFiles/ToolingTests.dir/RecursiveASTVisitorTests/CXXMemberCall.cpp.o
[308/1137] Building CXX object tools/clang/unittests/Tooling/CMakeFiles/ToolingTests.dir/RecursiveASTVisitorTests/ImplicitCtor.cpp.o
[309/1137] Building CXX object tools/clang/unittests/Tooling/CMakeFiles/ToolingTests.dir/RecursiveASTVisitorTests/InitListExprPostOrder.cpp.o
[310/1137] Building CXX object tools/clang/unittests/Tooling/CMakeFiles/ToolingTests.dir/LexicallyOrderedRecursiveASTVisitorTest.cpp.o
[311/1137] Building CXX object tools/clang/unittests/Tooling/CMakeFiles/ToolingTests.dir/RecursiveASTVisitorTests/Concept.cpp.o
[312/1137] Building CXX object tools/clang/unittests/Tooling/CMakeFiles/ToolingTests.dir/RecursiveASTVisitorTests/Attr.cpp.o
[313/1137] Building CXX object tools/clang/unittests/Tooling/CMakeFiles/ToolingTests.dir/RecursiveASTVisitorTests/ImplicitCtorInitializer.cpp.o
[314/1137] Building CXX object tools/clang/unittests/Tooling/CMakeFiles/ToolingTests.dir/RecursiveASTVisitorTests/DeclRefExpr.cpp.o
[315/1137] Building CXX object tools/clang/unittests/Tooling/CMakeFiles/ToolingTests.dir/RecursiveASTVisitorTests/CallbacksCompoundAssignOperator.cpp.o
[316/1137] Building CXX object tools/clang/unittests/Tooling/CMakeFiles/ToolingTests.dir/RecursiveASTVisitorTests/Class.cpp.o
[317/1137] Building CXX object tools/clang/unittests/Tooling/CMakeFiles/ToolingTests.dir/RecursiveASTVisitorTests/LambdaDefaultCapture.cpp.o
[318/1137] Building CXX object tools/clang/unittests/Tooling/CMakeFiles/ToolingTests.dir/RecursiveASTVisitorTests/BitfieldInitializer.cpp.o
[319/1137] Building CXX object tools/clang/unittests/Tooling/CMakeFiles/ToolingTests.dir/RecursiveASTVisitorTests/InitListExprPreOrderNoQueue.cpp.o
[320/1137] Building CXX object tools/clang/unittests/ASTMatchers/CMakeFiles/ASTMatchersTests.dir/ASTMatchersNarrowingTest.cpp.o
[321/1137] Building CXX object tools/clang/unittests/Tooling/CMakeFiles/ToolingTests.dir/RecursiveASTVisitorTests/CXXMethodDecl.cpp.o
[322/1137] Building CXX object tools/clang/unittests/Tooling/CMakeFiles/ToolingTests.dir/RecursiveASTVisitorTests/CallbacksBinaryOperator.cpp.o
[323/1137] Building CXX object tools/clang/unittests/Tooling/CMakeFiles/ToolingTests.dir/RecursiveASTVisitorTests/InitListExprPostOrderNoQueue.cpp.o
[324/1137] Building CXX object tools/clang/unittests/Tooling/CMakeFiles/ToolingTests.dir/RecursiveASTVisitorTests/InitListExprPreOrder.cpp.o
[325/1137] Building CXX object tools/clang/unittests/Tooling/CMakeFiles/ToolingTests.dir/RecursiveASTVisitorTests/CallbacksCallExpr.cpp.o
[326/1137] Building CXX object tools/clang/unittests/Tooling/CMakeFiles/ToolingTests.dir/RecursiveASTVisitorTests/IntegerLiteral.cpp.o
[327/1137] Building CXX object tools/clang/unittests/Tooling/CMakeFiles/ToolingTests.dir/RecursiveASTVisitorTests/LambdaExpr.cpp.o
[328/1137] Building CXX object tools/clang/unittests/Tooling/CMakeFiles/ToolingTests.dir/RecursiveASTVisitorTests/CallbacksLeaf.cpp.o
[329/1137] Building CXX object tools/clang/unittests/Tooling/CMakeFiles/ToolingTests.dir/RecursiveASTVisitorTests/CallbacksUnaryOperator.cpp.o
[330/1137] Building CXX object tools/clang/unittests/ASTMatchers/CMakeFiles/ASTMatchersTests.dir/ASTMatchersTraversalTest.cpp.o
[331/1137] Building CXX object tools/clang/unittests/Tooling/CMakeFiles/ToolingTests.dir/RecursiveASTVisitorTests/LambdaTemplateParams.cpp.o
[332/1137] Building CXX object tools/clang/unittests/Tooling/CMakeFiles/ToolingTests.dir/RecursiveASTVisitorTests/MemberPointerTypeLoc.cpp.o
ninja: build stopped: subcommand failed.

@llvm-ci
Copy link
Collaborator

llvm-ci commented Nov 14, 2024

LLVM Buildbot has detected a new failure on builder lld-x86_64-win running on as-worker-93 while building llvm at step 7 "test-build-unified-tree-check-all".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/146/builds/1595

Here is the relevant piece of the build log for the reference
Step 7 (test-build-unified-tree-check-all) failure: test (failure)
******************** TEST 'LLVM-Unit :: Support/./SupportTests.exe/37/87' FAILED ********************
Script(shard):
--
GTEST_OUTPUT=json:C:\a\lld-x86_64-win\build\unittests\Support\.\SupportTests.exe-LLVM-Unit-8824-37-87.json GTEST_SHUFFLE=0 GTEST_TOTAL_SHARDS=87 GTEST_SHARD_INDEX=37 C:\a\lld-x86_64-win\build\unittests\Support\.\SupportTests.exe
--

Script:
--
C:\a\lld-x86_64-win\build\unittests\Support\.\SupportTests.exe --gtest_filter=ProgramEnvTest.CreateProcessLongPath
--
C:\a\lld-x86_64-win\llvm-project\llvm\unittests\Support\ProgramTest.cpp(160): error: Expected equality of these values:
  0
  RC
    Which is: -2

C:\a\lld-x86_64-win\llvm-project\llvm\unittests\Support\ProgramTest.cpp(163): error: fs::remove(Twine(LongPath)): did not return errc::success.
error number: 13
error message: permission denied



C:\a\lld-x86_64-win\llvm-project\llvm\unittests\Support\ProgramTest.cpp:160
Expected equality of these values:
  0
  RC
    Which is: -2

C:\a\lld-x86_64-win\llvm-project\llvm\unittests\Support\ProgramTest.cpp:163
fs::remove(Twine(LongPath)): did not return errc::success.
error number: 13
error message: permission denied




********************


Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
Status: No status
Development

Successfully merging this pull request may close these issues.

[DirectX backend] flatten multi-dimensional array into a one-dimensional array.
5 participants