Skip to content

Commit

Permalink
[DirectX] Introduce the DXILResourceAccess pass (#116726)
Browse files Browse the repository at this point in the history
This pass transforms resource access via `llvm.dx.resource.getpointer`
into buffer loads and stores.

Fixes #114848.
  • Loading branch information
bogner authored Dec 18, 2024
1 parent 21de514 commit 0fca76d
Show file tree
Hide file tree
Showing 10 changed files with 389 additions and 2 deletions.
1 change: 1 addition & 0 deletions llvm/lib/Target/DirectX/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ add_llvm_target(DirectXCodeGen
DXILPrettyPrinter.cpp
DXILResource.cpp
DXILResourceAnalysis.cpp
DXILResourceAccess.cpp
DXILShaderFlags.cpp
DXILTranslateMetadata.cpp

Expand Down
11 changes: 11 additions & 0 deletions llvm/lib/Target/DirectX/DXILOpLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,14 @@ class OpLowerer {
});
}

[[nodiscard]] bool lowerGetPointer(Function &F) {
// These should have already been handled in DXILResourceAccess, so we can
// just clean up the dead prototype.
assert(F.user_empty() && "getpointer operations should have been removed");
F.eraseFromParent();
return false;
}

[[nodiscard]] bool lowerTypedBufferStore(Function &F) {
IRBuilder<> &IRB = OpBuilder.getIRB();
Type *Int8Ty = IRB.getInt8Ty();
Expand Down Expand Up @@ -707,6 +715,9 @@ class OpLowerer {
case Intrinsic::dx_handle_fromBinding:
HasErrors |= lowerHandleFromBinding(F);
break;
case Intrinsic::dx_resource_getpointer:
HasErrors |= lowerGetPointer(F);
break;
case Intrinsic::dx_typedBufferLoad:
HasErrors |= lowerTypedBufferLoad(F, /*HasCheckBit=*/false);
break;
Expand Down
192 changes: 192 additions & 0 deletions llvm/lib/Target/DirectX/DXILResourceAccess.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
//===- DXILResourceAccess.cpp - Resource access via load/store ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "DXILResourceAccess.h"
#include "DirectX.h"
#include "llvm/Analysis/DXILResource.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsDirectX.h"
#include "llvm/InitializePasses.h"

#define DEBUG_TYPE "dxil-resource-access"

using namespace llvm;

static void replaceTypedBufferAccess(IntrinsicInst *II,
dxil::ResourceTypeInfo &RTI) {
const DataLayout &DL = II->getDataLayout();

auto *HandleType = cast<TargetExtType>(II->getOperand(0)->getType());
assert(HandleType->getName() == "dx.TypedBuffer" &&
"Unexpected typed buffer type");
Type *ContainedType = HandleType->getTypeParameter(0);

// We need the size of an element in bytes so that we can calculate the offset
// in elements given a total offset in bytes later.
Type *ScalarType = ContainedType->getScalarType();
uint64_t ScalarSize = DL.getTypeSizeInBits(ScalarType) / 8;

// Process users keeping track of indexing accumulated from GEPs.
struct AccessAndIndex {
User *Access;
Value *Index;
};
SmallVector<AccessAndIndex> Worklist;
for (User *U : II->users())
Worklist.push_back({U, nullptr});

SmallVector<Instruction *> DeadInsts;
while (!Worklist.empty()) {
AccessAndIndex Current = Worklist.back();
Worklist.pop_back();

if (auto *GEP = dyn_cast<GetElementPtrInst>(Current.Access)) {
IRBuilder<> Builder(GEP);

Value *Index;
APInt ConstantOffset(DL.getIndexTypeSizeInBits(GEP->getType()), 0);
if (GEP->accumulateConstantOffset(DL, ConstantOffset)) {
APInt Scaled = ConstantOffset.udiv(ScalarSize);
Index = ConstantInt::get(Builder.getInt32Ty(), Scaled);
} else {
auto IndexIt = GEP->idx_begin();
assert(cast<ConstantInt>(IndexIt)->getZExtValue() == 0 &&
"GEP is not indexing through pointer");
++IndexIt;
Index = *IndexIt;
assert(++IndexIt == GEP->idx_end() && "Too many indices in GEP");
}

for (User *U : GEP->users())
Worklist.push_back({U, Index});
DeadInsts.push_back(GEP);

} else if (auto *SI = dyn_cast<StoreInst>(Current.Access)) {
assert(SI->getValueOperand() != II && "Pointer escaped!");
IRBuilder<> Builder(SI);

Value *V = SI->getValueOperand();
if (V->getType() == ContainedType) {
// V is already the right type.
} else if (V->getType() == ScalarType) {
// We're storing a scalar, so we need to load the current value and only
// replace the relevant part.
auto *Load = Builder.CreateIntrinsic(
ContainedType, Intrinsic::dx_typedBufferLoad,
{II->getOperand(0), II->getOperand(1)});
// If we have an offset from seeing a GEP earlier, use it.
Value *IndexOp = Current.Index
? Current.Index
: ConstantInt::get(Builder.getInt32Ty(), 0);
V = Builder.CreateInsertElement(Load, V, IndexOp);
} else {
llvm_unreachable("Store to typed resource has invalid type");
}

auto *Inst = Builder.CreateIntrinsic(
Builder.getVoidTy(), Intrinsic::dx_typedBufferStore,
{II->getOperand(0), II->getOperand(1), V});
SI->replaceAllUsesWith(Inst);
DeadInsts.push_back(SI);

} else if (auto *LI = dyn_cast<LoadInst>(Current.Access)) {
IRBuilder<> Builder(LI);
Value *V =
Builder.CreateIntrinsic(ContainedType, Intrinsic::dx_typedBufferLoad,
{II->getOperand(0), II->getOperand(1)});
if (Current.Index)
V = Builder.CreateExtractElement(V, Current.Index);

LI->replaceAllUsesWith(V);
DeadInsts.push_back(LI);

} else
llvm_unreachable("Unhandled instruction - pointer escaped?");
}

// Traverse the now-dead instructions in RPO and remove them.
for (Instruction *Dead : llvm::reverse(DeadInsts))
Dead->eraseFromParent();
II->eraseFromParent();
}

static bool transformResourcePointers(Function &F, DXILResourceTypeMap &DRTM) {
bool Changed = false;
SmallVector<std::pair<IntrinsicInst *, dxil::ResourceTypeInfo>> Resources;
for (BasicBlock &BB : F)
for (Instruction &I : BB)
if (auto *II = dyn_cast<IntrinsicInst>(&I))
if (II->getIntrinsicID() == Intrinsic::dx_resource_getpointer) {
auto *HandleTy = cast<TargetExtType>(II->getArgOperand(0)->getType());
Resources.emplace_back(II, DRTM[HandleTy]);
}

for (auto &[II, RI] : Resources) {
if (RI.isTyped()) {
Changed = true;
replaceTypedBufferAccess(II, RI);
}

// TODO: handle other resource types. We should probably have an
// `unreachable` here once we've added support for all of them.
}

return Changed;
}

PreservedAnalyses DXILResourceAccess::run(Function &F,
FunctionAnalysisManager &FAM) {
auto &MAMProxy = FAM.getResult<ModuleAnalysisManagerFunctionProxy>(F);
DXILResourceTypeMap *DRTM =
MAMProxy.getCachedResult<DXILResourceTypeAnalysis>(*F.getParent());
assert(DRTM && "DXILResourceTypeAnalysis must be available");

bool MadeChanges = transformResourcePointers(F, *DRTM);
if (!MadeChanges)
return PreservedAnalyses::all();

PreservedAnalyses PA;
PA.preserve<DXILResourceTypeAnalysis>();
PA.preserve<DominatorTreeAnalysis>();
return PA;
}

namespace {
class DXILResourceAccessLegacy : public FunctionPass {
public:
bool runOnFunction(Function &F) override {
DXILResourceTypeMap &DRTM =
getAnalysis<DXILResourceTypeWrapperPass>().getResourceTypeMap();

return transformResourcePointers(F, DRTM);
}
StringRef getPassName() const override { return "DXIL Resource Access"; }
DXILResourceAccessLegacy() : FunctionPass(ID) {}

static char ID; // Pass identification.
void getAnalysisUsage(llvm::AnalysisUsage &AU) const override {
AU.addRequired<DXILResourceTypeWrapperPass>();
AU.addPreserved<DominatorTreeWrapperPass>();
}
};
char DXILResourceAccessLegacy::ID = 0;
} // end anonymous namespace

INITIALIZE_PASS_BEGIN(DXILResourceAccessLegacy, DEBUG_TYPE,
"DXIL Resource Access", false, false)
INITIALIZE_PASS_DEPENDENCY(DXILResourceTypeWrapperPass)
INITIALIZE_PASS_END(DXILResourceAccessLegacy, DEBUG_TYPE,
"DXIL Resource Access", false, false)

FunctionPass *llvm::createDXILResourceAccessLegacyPass() {
return new DXILResourceAccessLegacy();
}
28 changes: 28 additions & 0 deletions llvm/lib/Target/DirectX/DXILResourceAccess.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
//===- DXILResourceAccess.h - Resource access via load/store ----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// \file Pass for replacing pointers to DXIL resources with load and store
// operations.
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_LIB_TARGET_DIRECTX_DXILRESOURCEACCESS_H
#define LLVM_LIB_TARGET_DIRECTX_DXILRESOURCEACCESS_H

#include "llvm/IR/PassManager.h"

namespace llvm {

class DXILResourceAccess : public PassInfoMixin<DXILResourceAccess> {
public:
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
};

} // namespace llvm

#endif // LLVM_LIB_TARGET_DIRECTX_DXILRESOURCEACCESS_H
7 changes: 7 additions & 0 deletions llvm/lib/Target/DirectX/DirectX.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#define LLVM_LIB_TARGET_DIRECTX_DIRECTX_H

namespace llvm {
class FunctionPass;
class ModulePass;
class PassRegistry;
class raw_ostream;
Expand Down Expand Up @@ -52,6 +53,12 @@ void initializeDXILOpLoweringLegacyPass(PassRegistry &);
/// Pass to lowering LLVM intrinsic call to DXIL op function call.
ModulePass *createDXILOpLoweringLegacyPass();

/// Initializer for DXILResourceAccess
void initializeDXILResourceAccessLegacyPass(PassRegistry &);

/// Pass to update resource accesses to use load/store directly.
FunctionPass *createDXILResourceAccessLegacyPass();

/// Initializer for DXILTranslateMetadata.
void initializeDXILTranslateMetadataLegacyPass(PassRegistry &);

Expand Down
6 changes: 6 additions & 0 deletions llvm/lib/Target/DirectX/DirectXPassRegistry.def
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,9 @@ MODULE_PASS("dxil-translate-metadata", DXILTranslateMetadata())
// TODO: rename to print<foo> after NPM switch
MODULE_PASS("print-dx-shader-flags", dxil::ShaderFlagsAnalysisPrinter(dbgs()))
#undef MODULE_PASS

#ifndef FUNCTION_PASS
#define FUNCTION_PASS(NAME, CREATE_PASS)
#endif
FUNCTION_PASS("dxil-resource-access", DXILResourceAccess())
#undef FUNCTION_PASS
5 changes: 4 additions & 1 deletion llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "DXILIntrinsicExpansion.h"
#include "DXILOpLowering.h"
#include "DXILPrettyPrinter.h"
#include "DXILResourceAccess.h"
#include "DXILResourceAnalysis.h"
#include "DXILShaderFlags.h"
#include "DXILTranslateMetadata.h"
Expand Down Expand Up @@ -56,6 +57,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeDirectXTarget() {
initializeWriteDXILPassPass(*PR);
initializeDXContainerGlobalsPass(*PR);
initializeDXILOpLoweringLegacyPass(*PR);
initializeDXILResourceAccessLegacyPass(*PR);
initializeDXILTranslateMetadataLegacyPass(*PR);
initializeDXILResourceMDWrapperPass(*PR);
initializeShaderFlagsAnalysisWrapperPass(*PR);
Expand Down Expand Up @@ -92,9 +94,10 @@ class DirectXPassConfig : public TargetPassConfig {
addPass(createDXILFinalizeLinkageLegacyPass());
addPass(createDXILIntrinsicExpansionLegacyPass());
addPass(createDXILDataScalarizationLegacyPass());
addPass(createDXILFlattenArraysLegacyPass());
addPass(createDXILResourceAccessLegacyPass());
ScalarizerPassOptions DxilScalarOptions;
DxilScalarOptions.ScalarizeLoadStore = true;
addPass(createDXILFlattenArraysLegacyPass());
addPass(createScalarizerPass(DxilScalarOptions));
addPass(createDXILOpLoweringLegacyPass());
addPass(createDXILTranslateMetadataLegacyPass());
Expand Down
35 changes: 35 additions & 0 deletions llvm/test/CodeGen/DirectX/ResourceAccess/load_typedbuffer.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
; RUN: opt -S -dxil-resource-access %s | FileCheck %s

target triple = "dxil-pc-shadermodel6.6-compute"

declare void @use_float4(<4 x float>)
declare void @use_float(float)

; CHECK-LABEL: define void @load_float4
define void @load_float4(i32 %index, i32 %elemindex) {
%buffer = call target("dx.TypedBuffer", <4 x float>, 1, 0, 0)
@llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f32_1_0_0(
i32 0, i32 0, i32 1, i32 0, i1 false)

; CHECK-NOT: @llvm.dx.resource.getpointer
%ptr = call ptr @llvm.dx.resource.getpointer(
target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index)

; CHECK: %[[VALUE:.*]] = call <4 x float> @llvm.dx.typedBufferLoad.v4f32.tdx.TypedBuffer_v4f32_1_0_0t(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index)
%vec_data = load <4 x float>, ptr %ptr
call void @use_float4(<4 x float> %vec_data)

; CHECK: %[[VALUE:.*]] = call <4 x float> @llvm.dx.typedBufferLoad.v4f32.tdx.TypedBuffer_v4f32_1_0_0t(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index)
; CHECK: extractelement <4 x float> %[[VALUE]], i32 1
%y_ptr = getelementptr inbounds <4 x float>, ptr %ptr, i32 0, i32 1
%y_data = load float, ptr %y_ptr
call void @use_float(float %y_data)

; CHECK: %[[VALUE:.*]] = call <4 x float> @llvm.dx.typedBufferLoad.v4f32.tdx.TypedBuffer_v4f32_1_0_0t(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index)
; CHECK: extractelement <4 x float> %[[VALUE]], i32 %elemindex
%dynamic = getelementptr inbounds <4 x float>, ptr %ptr, i32 0, i32 %elemindex
%dyndata = load float, ptr %dynamic
call void @use_float(float %dyndata)

ret void
}
Loading

0 comments on commit 0fca76d

Please sign in to comment.