Skip to content

Commit

Permalink
[DirectX] Introduce the DXILResourceAccess pass
Browse files Browse the repository at this point in the history
This pass transforms resource access via `llvm.dx.resource.getpointer`
into buffer loads and stores.

Fixes llvm#114848.
  • Loading branch information
bogner committed Nov 25, 2024
1 parent 289d416 commit 0af8bca
Show file tree
Hide file tree
Showing 13 changed files with 397 additions and 3 deletions.
3 changes: 3 additions & 0 deletions llvm/include/llvm/Analysis/DXILResource.h
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,9 @@ class DXILResourceMap {
DXILResourceMap(
SmallVectorImpl<std::pair<CallInst *, dxil::ResourceInfo>> &&CIToRI);

bool invalidate(Module &M, const PreservedAnalyses &PA,
ModuleAnalysisManager::Invalidator &Inv);

iterator begin() { return Resources.begin(); }
const_iterator begin() const { return Resources.begin(); }
iterator end() { return Resources.end(); }
Expand Down
3 changes: 3 additions & 0 deletions llvm/include/llvm/IR/IntrinsicsDirectX.td
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ def int_dx_handle_fromBinding
[llvm_i32_ty, llvm_i32_ty, llvm_i32_ty, llvm_i32_ty, llvm_i1_ty],
[IntrNoMem]>;

// Produces a pointer (overloaded pointer type) from a resource handle
// (overloaded type) and an i32 index. Declared IntrNoMem: the intrinsic itself
// does not access memory.
def int_dx_resource_getpointer
: DefaultAttrsIntrinsic<[llvm_anyptr_ty], [llvm_any_ty, llvm_i32_ty],
[IntrNoMem]>;
// Loads a value (overloaded result type) from a resource handle (overloaded
// type) at an i32 index. Declared IntrReadMem: the intrinsic only reads memory.
def int_dx_typedBufferLoad
: DefaultAttrsIntrinsic<[llvm_any_ty], [llvm_any_ty, llvm_i32_ty],
[IntrReadMem]>;
Expand Down
6 changes: 6 additions & 0 deletions llvm/lib/Analysis/DXILResource.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -744,6 +744,12 @@ DXILResourceMap::DXILResourceMap(
}
}

/// Invalidation hook: decides whether this resource map has to be recomputed
/// after a pass that preserved the analyses described by \p PA.
///
/// \returns false when either DXILResourceAnalysis itself or the set of all
/// module analyses is reported as preserved, true otherwise.
bool DXILResourceMap::invalidate(Module &M, const PreservedAnalyses &PA,
                                 ModuleAnalysisManager::Invalidator &Inv) {
  auto Checker = PA.getChecker<DXILResourceAnalysis>();

  // Explicitly preserved: nothing to invalidate.
  if (Checker.preserved())
    return false;

  // Otherwise the result is only still valid if every module analysis was
  // preserved as a set.
  return !Checker.preservedSet<AllAnalysesOn<Module>>();
}

void DXILResourceMap::print(raw_ostream &OS) const {
for (unsigned I = 0, E = Resources.size(); I != E; ++I) {
OS << "Binding " << I << ":\n";
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/DirectX/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ add_llvm_target(DirectXCodeGen
DXILPrettyPrinter.cpp
DXILResource.cpp
DXILResourceAnalysis.cpp
DXILResourceAccess.cpp
DXILShaderFlags.cpp
DXILTranslateMetadata.cpp

Expand Down
196 changes: 196 additions & 0 deletions llvm/lib/Target/DirectX/DXILResourceAccess.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
//===- DXILResourceAccess.cpp - Resource access via load/store ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "DXILResourceAccess.h"
#include "DirectX.h"
#include "llvm/Analysis/DXILResource.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsDirectX.h"
#include "llvm/InitializePasses.h"

#define DEBUG_TYPE "dxil-resource-access"

using namespace llvm;

/// Replace all accesses made through the pointer returned by the
/// `llvm.dx.resource.getpointer` call \p II on a "dx.TypedBuffer" handle with
/// `llvm.dx.typedBufferLoad` / `llvm.dx.typedBufferStore` intrinsic calls.
///
/// Users of the pointer are walked with a worklist. A GEP contributes an
/// element index that is handed on to its own users; loads and stores are
/// rewritten in terms of whole-element buffer loads and stores. Any other kind
/// of user is treated as the pointer escaping and hits `llvm_unreachable`.
/// All rewritten instructions, and \p II itself, are erased before returning.
///
/// \param II The `llvm.dx.resource.getpointer` call. Operand 0 is the resource
///           handle and operand 1 is the index forwarded to the buffer
///           intrinsics.
/// \param RI Resource info associated with the handle. Not used by this
///           function at present.
static void replaceTypedBufferAccess(IntrinsicInst *II,
                                     dxil::ResourceInfo &RI) {
  const DataLayout &DL = II->getDataLayout();

  auto *HandleType = cast<TargetExtType>(II->getOperand(0)->getType());
  assert(HandleType->getName() == "dx.TypedBuffer" &&
         "Unexpected typed buffer type");
  Type *ContainedType = HandleType->getTypeParameter(0);
  Type *ScalarType = ContainedType->getScalarType();
  // Size of one scalar component in bytes; used to turn constant GEP byte
  // offsets into component indices.
  uint64_t ScalarSize = DL.getTypeSizeInBits(ScalarType) / 8;

  // Process users keeping track of indexing accumulated from GEPs.
  struct AccessAndIndex {
    User *Access;
    Value *Index;
  };
  SmallVector<AccessAndIndex> Worklist;
  for (User *U : II->users())
    Worklist.push_back({U, nullptr});

  SmallVector<Instruction *> DeadInsts;
  while (!Worklist.empty()) {
    AccessAndIndex Current = Worklist.back();
    Worklist.pop_back();

    if (auto *GEP = dyn_cast<GetElementPtrInst>(Current.Access)) {
      IRBuilder<> Builder(GEP);

      Value *Index;
      APInt ConstantOffset(DL.getIndexTypeSizeInBits(GEP->getType()), 0);
      if (GEP->accumulateConstantOffset(DL, ConstantOffset)) {
        // All-constant GEP: convert the byte offset into a component index.
        APInt Scaled = ConstantOffset.udiv(ScalarSize);
        Index = ConstantInt::get(Builder.getInt32Ty(), Scaled);
      } else {
        // Dynamic GEP: expect the form `gep <ty>, ptr, 0, <index>`. Validate
        // the shape before reading the second index, and keep the asserts
        // free of side effects so release builds behave the same.
        assert(GEP->getNumIndices() == 2 &&
               "Expected exactly two indices in GEP");
        assert(cast<ConstantInt>(GEP->getOperand(1))->getZExtValue() == 0 &&
               "GEP is not indexing through pointer");
        Index = GEP->getOperand(2);
      }

      for (User *U : GEP->users())
        Worklist.push_back({U, Index});
      DeadInsts.push_back(GEP);

    } else if (auto *SI = dyn_cast<StoreInst>(Current.Access)) {
      assert(SI->getValueOperand() != II && "Pointer escaped!");
      IRBuilder<> Builder(SI);

      Value *V = SI->getValueOperand();
      if (V->getType() == ContainedType) {
        // V is already the right type.
      } else if (V->getType() == ScalarType) {
        // We're storing a scalar, so we need to load the current value and only
        // replace the relevant part.
        auto *Load = Builder.CreateIntrinsic(
            ContainedType, Intrinsic::dx_typedBufferLoad,
            {II->getOperand(0), II->getOperand(1)});
        // If we have an offset from seeing a GEP earlier, use it.
        Value *IndexOp = Current.Index
                             ? Current.Index
                             : ConstantInt::get(Builder.getInt32Ty(), 0);
        V = Builder.CreateInsertElement(Load, V, IndexOp);
      } else {
        llvm_unreachable("Store to typed resource has invalid type");
      }

      auto *Inst = Builder.CreateIntrinsic(
          Builder.getVoidTy(), Intrinsic::dx_typedBufferStore,
          {II->getOperand(0), II->getOperand(1), V});
      SI->replaceAllUsesWith(Inst);
      DeadInsts.push_back(SI);

    } else if (auto *LI = dyn_cast<LoadInst>(Current.Access)) {
      IRBuilder<> Builder(LI);
      Value *V =
          Builder.CreateIntrinsic(ContainedType, Intrinsic::dx_typedBufferLoad,
                                  {II->getOperand(0), II->getOperand(1)});
      // A preceding GEP selected a single component of the element.
      if (Current.Index)
        V = Builder.CreateExtractElement(V, Current.Index);

      LI->replaceAllUsesWith(V);
      DeadInsts.push_back(LI);

    } else
      llvm_unreachable("Unhandled instruction - pointer escaped?");
  }

  // A GEP is recorded before any of its users are visited, so erasing in
  // reverse order removes the users before the GEP they depend on.
  for (Instruction *Dead : llvm::reverse(DeadInsts))
    Dead->eraseFromParent();
  II->eraseFromParent();
}

/// Rewrite resource accesses in \p F that go through
/// `llvm.dx.resource.getpointer` on handles known to \p DRM.
///
/// The matching intrinsic calls are collected first and rewritten afterwards,
/// because the rewrite erases instructions and must not run while the function
/// body is being iterated.
///
/// \returns true if any instruction in \p F was rewritten, false otherwise.
static bool transformResourcePointers(Function &F, DXILResourceMap &DRM) {
  // TODO: Should we have a more efficient way to find resources used in a
  // particular function?
  SmallVector<std::pair<IntrinsicInst *, dxil::ResourceInfo &>> Resources;
  for (BasicBlock &BB : F)
    for (Instruction &I : BB) {
      auto *CI = dyn_cast<CallInst>(&I);
      if (!CI)
        continue;
      auto It = DRM.find(CI);
      if (It == DRM.end())
        continue;
      for (User *U : CI->users())
        if (auto *II = dyn_cast<IntrinsicInst>(U))
          if (II->getIntrinsicID() == Intrinsic::dx_resource_getpointer)
            Resources.emplace_back(II, *It);
    }

  // Report modifications to the caller so that the pass managers do not
  // assume the IR (and every analysis on it) is untouched.
  bool Changed = false;
  for (const auto &[II, RI] : Resources) {
    if (RI.isTyped()) {
      replaceTypedBufferAccess(II, RI);
      Changed = true;
    }

    // TODO: handle other resource types. We should probably have an
    // `unreachable` here once we've added support for all of them.
  }

  return Changed;
}

/// New pass manager entry point. Looks up the cached module-level
/// DXILResourceAnalysis result through the module analysis proxy and rewrites
/// the resource pointer accesses in \p F.
///
/// \returns all analyses preserved when nothing changed; otherwise only
/// DXILResourceAnalysis and DominatorTreeAnalysis are marked preserved.
PreservedAnalyses DXILResourceAccess::run(Function &F,
                                          FunctionAnalysisManager &FAM) {
  auto &ModuleProxy = FAM.getResult<ModuleAnalysisManagerFunctionProxy>(F);
  DXILResourceMap *ResourceMap =
      ModuleProxy.getCachedResult<DXILResourceAnalysis>(*F.getParent());
  assert(ResourceMap && "DXILResourceAnalysis must be available");

  if (transformResourcePointers(F, *ResourceMap)) {
    PreservedAnalyses Preserved;
    Preserved.preserve<DXILResourceAnalysis>();
    Preserved.preserve<DominatorTreeAnalysis>();
    return Preserved;
  }

  return PreservedAnalyses::all();
}

namespace {
class DXILResourceAccessLegacy : public FunctionPass {
public:
bool runOnFunction(Function &F) override {
DXILResourceMap &DRM =
getAnalysis<DXILResourceWrapperPass>().getResourceMap();

return transformResourcePointers(F, DRM);
}
StringRef getPassName() const override { return "DXIL Resource Access"; }
DXILResourceAccessLegacy() : FunctionPass(ID) {}

static char ID; // Pass identification.
void getAnalysisUsage(llvm::AnalysisUsage &AU) const override {
AU.addRequired<DXILResourceWrapperPass>();
AU.addPreserved<DXILResourceWrapperPass>();
AU.addPreserved<DominatorTreeWrapperPass>();
}
};
char DXILResourceAccessLegacy::ID = 0;
} // end anonymous namespace

// Register the legacy pass under the DEBUG_TYPE name ("dxil-resource-access")
// and record its dependency on DXILResourceWrapperPass.
INITIALIZE_PASS_BEGIN(DXILResourceAccessLegacy, DEBUG_TYPE,
"DXIL Resource Access", false, false)
INITIALIZE_PASS_DEPENDENCY(DXILResourceWrapperPass)
INITIALIZE_PASS_END(DXILResourceAccessLegacy, DEBUG_TYPE,
"DXIL Resource Access", false, false)

/// Factory for the legacy pass manager: returns a newly allocated
/// DXILResourceAccessLegacy pass.
FunctionPass *llvm::createDXILResourceAccessLegacyPass() {
  FunctionPass *Pass = new DXILResourceAccessLegacy();
  return Pass;
}
28 changes: 28 additions & 0 deletions llvm/lib/Target/DirectX/DXILResourceAccess.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
//===- DXILResourceAccess.h - Resource access via load/store ----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass for replacing pointers to DXIL resources with load and store
/// operations.
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_LIB_TARGET_DIRECTX_DXILRESOURCEACCESS_H
#define LLVM_LIB_TARGET_DIRECTX_DXILRESOURCEACCESS_H

#include "llvm/IR/PassManager.h"

namespace llvm {

/// New pass manager function pass that replaces pointers to DXIL resources
/// with load and store operations.
class DXILResourceAccess : public PassInfoMixin<DXILResourceAccess> {
public:
  /// Run the pass on \p F.
  ///
  /// \param F   Function to transform.
  /// \param FAM Function analysis manager used to look up analysis results.
  /// \returns the set of analyses that remain valid after the pass.
  PreservedAnalyses run(Function &F, FunctionAnalysisManager &FAM);
};

} // namespace llvm

#endif // LLVM_LIB_TARGET_DIRECTX_DXILRESOURCEACCESS_H
7 changes: 7 additions & 0 deletions llvm/lib/Target/DirectX/DirectX.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#define LLVM_LIB_TARGET_DIRECTX_DIRECTX_H

namespace llvm {
class FunctionPass;
class ModulePass;
class PassRegistry;
class raw_ostream;
Expand Down Expand Up @@ -52,6 +53,12 @@ void initializeDXILOpLoweringLegacyPass(PassRegistry &);
/// Pass to lower LLVM intrinsic calls to DXIL op function calls.
ModulePass *createDXILOpLoweringLegacyPass();

/// Initializer for DXILResourceAccess
void initializeDXILResourceAccessLegacyPass(PassRegistry &);

/// Pass to update resource accesses to use load/store directly.
FunctionPass *createDXILResourceAccessLegacyPass();

/// Initializer for DXILTranslateMetadata.
void initializeDXILTranslateMetadataLegacyPass(PassRegistry &);

Expand Down
6 changes: 6 additions & 0 deletions llvm/lib/Target/DirectX/DirectXPassRegistry.def
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,9 @@ MODULE_PASS("dxil-translate-metadata", DXILTranslateMetadata())
// TODO: rename to print<foo> after NPM switch
MODULE_PASS("print-dx-shader-flags", dxil::ShaderFlagsAnalysisPrinter(dbgs()))
#undef MODULE_PASS

// Function passes. FUNCTION_PASS falls back to an empty expansion when the
// including file has not defined it, and is undefined again afterwards.
#ifndef FUNCTION_PASS
#define FUNCTION_PASS(NAME, CREATE_PASS)
#endif
FUNCTION_PASS("dxil-resource-access", DXILResourceAccess())
#undef FUNCTION_PASS
5 changes: 4 additions & 1 deletion llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "DXILIntrinsicExpansion.h"
#include "DXILOpLowering.h"
#include "DXILPrettyPrinter.h"
#include "DXILResourceAccess.h"
#include "DXILResourceAnalysis.h"
#include "DXILShaderFlags.h"
#include "DXILTranslateMetadata.h"
Expand Down Expand Up @@ -56,6 +57,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeDirectXTarget() {
initializeWriteDXILPassPass(*PR);
initializeDXContainerGlobalsPass(*PR);
initializeDXILOpLoweringLegacyPass(*PR);
initializeDXILResourceAccessLegacyPass(*PR);
initializeDXILTranslateMetadataLegacyPass(*PR);
initializeDXILResourceMDWrapperPass(*PR);
initializeShaderFlagsAnalysisWrapperPass(*PR);
Expand Down Expand Up @@ -91,9 +93,10 @@ class DirectXPassConfig : public TargetPassConfig {
void addCodeGenPrepare() override {
addPass(createDXILIntrinsicExpansionLegacyPass());
addPass(createDXILDataScalarizationLegacyPass());
addPass(createDXILFlattenArraysLegacyPass());
addPass(createDXILResourceAccessLegacyPass());
ScalarizerPassOptions DxilScalarOptions;
DxilScalarOptions.ScalarizeLoadStore = true;
addPass(createDXILFlattenArraysLegacyPass());
addPass(createScalarizerPass(DxilScalarOptions));
addPass(createDXILOpLoweringLegacyPass());
addPass(createDXILFinalizeLinkageLegacyPass());
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Transforms/Scalar/Scalarizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/DXILResource.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/Argument.h"
Expand Down Expand Up @@ -351,6 +352,7 @@ void ScalarizerLegacyPass::getAnalysisUsage(AnalysisUsage &AU) const {
AU.addRequired<DominatorTreeWrapperPass>();
AU.addRequired<TargetTransformInfoWrapperPass>();
AU.addPreserved<DominatorTreeWrapperPass>();
AU.addPreserved<DXILResourceWrapperPass>();
}

char ScalarizerLegacyPass::ID = 0;
Expand Down Expand Up @@ -1348,5 +1350,6 @@ PreservedAnalyses ScalarizerPass::run(Function &F, FunctionAnalysisManager &AM)
bool Changed = Impl.visit(F);
PreservedAnalyses PA;
PA.preserve<DominatorTreeAnalysis>();
PA.preserve<DXILResourceAnalysis>();
return Changed ? PA : PreservedAnalyses::all();
}
35 changes: 35 additions & 0 deletions llvm/test/CodeGen/DirectX/ResourceAccess/load_typedbuffer.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
; RUN: opt -S -dxil-resource-access %s | FileCheck %s

target triple = "dxil-pc-shadermodel6.6-compute"

; External functions that consume the loaded values. @use_float takes a scalar
; float so that its declaration matches the calls made in this file.
declare void @use_float4(<4 x float>)
declare void @use_float(float)

; Loads through a pointer obtained from a <4 x float> typed buffer handle in
; three ways: the whole vector, one component at a constant GEP index, and one
; component at a dynamic GEP index.
; CHECK-LABEL: define void @load_float4
define void @load_float4(i32 %index, i32 %elemindex) {
%buffer = call target("dx.TypedBuffer", <4 x float>, 1, 0, 0)
@llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f32_1_0_0(
i32 0, i32 0, i32 1, i32 0, i1 false)

; The getpointer call itself must not survive the pass.
; CHECK-NOT: @llvm.dx.resource.getpointer
%ptr = call ptr @llvm.dx.resource.getpointer(
target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index)

; Whole-vector load becomes a single typed buffer load.
; CHECK: %[[VALUE:.*]] = call <4 x float> @llvm.dx.typedBufferLoad.v4f32.tdx.TypedBuffer_v4f32_1_0_0t(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index)
%vec_data = load <4 x float>, ptr %ptr
call void @use_float4(<4 x float> %vec_data)

; Constant component index: typed buffer load followed by an extractelement.
; CHECK: %[[VALUE:.*]] = call <4 x float> @llvm.dx.typedBufferLoad.v4f32.tdx.TypedBuffer_v4f32_1_0_0t(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index)
; CHECK: extractelement <4 x float> %[[VALUE]], i32 1
%y_ptr = getelementptr inbounds <4 x float>, ptr %ptr, i32 0, i32 1
%y_data = load float, ptr %y_ptr
call void @use_float(float %y_data)

; Dynamic component index: the extractelement uses the runtime index.
; CHECK: %[[VALUE:.*]] = call <4 x float> @llvm.dx.typedBufferLoad.v4f32.tdx.TypedBuffer_v4f32_1_0_0t(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index)
; CHECK: extractelement <4 x float> %[[VALUE]], i32 %elemindex
%dynamic = getelementptr inbounds <4 x float>, ptr %ptr, i32 0, i32 %elemindex
%dyndata = load float, ptr %dynamic
call void @use_float(float %dyndata)

ret void
}
Loading

0 comments on commit 0af8bca

Please sign in to comment.