diff --git a/llvm/include/llvm/Analysis/DXILResource.h b/llvm/include/llvm/Analysis/DXILResource.h
index 6b577c02f0545..016e45e78c398 100644
--- a/llvm/include/llvm/Analysis/DXILResource.h
+++ b/llvm/include/llvm/Analysis/DXILResource.h
@@ -264,6 +264,8 @@ class ResourceInfo {
 class DXILResourceMap {
   SmallVector<dxil::ResourceInfo> Resources;
   DenseMap<CallInst *, unsigned> CallMap;
+  // Mapping from Resource use to Resource Handle
+  DenseMap<CallInst *, CallInst *> ResUseToHandleMap;
   unsigned FirstUAV = 0;
   unsigned FirstCBuffer = 0;
   unsigned FirstSampler = 0;
@@ -335,6 +337,46 @@ class DXILResourceMap {
   }
 
   void print(raw_ostream &OS) const;
+
+  /// Re-key both maps after the handle-creating call \p origCallInst has
+  /// been lowered to \p newCallInst. Must run while \p origCallInst still
+  /// has its users, i.e. before replaceAllUsesWith()/eraseFromParent().
+  void updateResourceMap(CallInst *origCallInst, CallInst *newCallInst);
+
+  /// Transfer the handle recorded for \p origResUse to its lowered
+  /// replacement \p newResUse and drop the entry for \p origResUse.
+  void updateResUseMap(CallInst *origResUse, CallInst *newResUse) {
+    assert((origResUse != nullptr) && (newResUse != nullptr) &&
+           (origResUse != newResUse) && "Wrong Inputs");
+
+    updateResUseMapCommon(origResUse, newResUse, /*keepOrigResUseInMap=*/false);
+  }
+
+  /// Return the resource handle recorded for \p resUse. A missing entry is
+  /// a caller bug: it asserts when assertions are enabled and yields nullptr
+  /// otherwise, instead of dereferencing the end iterator.
+  CallInst *findResHandleByUse(CallInst *resUse) {
+    auto Pos = ResUseToHandleMap.find(resUse);
+    assert((Pos != ResUseToHandleMap.end()) &&
+           "Can't find the resource handle");
+
+    return Pos != ResUseToHandleMap.end() ? Pos->second : nullptr;
+  }
+
+private:
+  /// Record \p newResUse under the handle of \p origResUse; erase the old
+  /// entry unless \p keepOrigResUseInMap is set.
+  void updateResUseMapCommon(CallInst *origResUse, CallInst *newResUse,
+                             bool keepOrigResUseInMap) {
+    // Without a known handle there is nothing to transfer; storing nullptr
+    // here would later be looked up in CallMap by print().
+    CallInst *Handle = findResHandleByUse(origResUse);
+    if (!Handle)
+      return;
+    ResUseToHandleMap.try_emplace(newResUse, Handle);
+    if (!keepOrigResUseInMap)
+      ResUseToHandleMap.erase(origResUse);
+  }
 };
 
 class DXILResourceAnalysis : public AnalysisInfoMixin<DXILResourceAnalysis> {
diff --git a/llvm/lib/Analysis/DXILResource.cpp b/llvm/lib/Analysis/DXILResource.cpp
index 2802480481690..601d2648ae028 100644
--- a/llvm/lib/Analysis/DXILResource.cpp
+++ b/llvm/lib/Analysis/DXILResource.cpp
@@ -719,6 +719,12 @@ DXILResourceMap::DXILResourceMap(
     if (Resources.empty() || RI != Resources.back())
       Resources.push_back(RI);
     CallMap[CI] = Resources.size() - 1;
+
+    // Build ResUseToHandleMap. Only call users are tracked: dyn_cast yields
+    // nullptr for any other user, and a null key would crash print().
+    for (User *U : CI->users())
+      if (auto *UseCI = dyn_cast<CallInst>(U))
+        ResUseToHandleMap[UseCI] = CI;
   }
 
   unsigned Size = Resources.size();
@@ -744,6 +750,52 @@ DXILResourceMap::DXILResourceMap(
   }
 }
 
+// Parameter origCallInst: original Resource Handle
+// Parameter newCallInst: new Resource Handle
+//
+// This function is needed when origCallInst is lowered to newCallInst.
+//
+// origCallInst and its uses are replaced by newCallInst and new def
+// instructions during lowering, so the [origCallInst, resource info] entry in
+// CallMap and the [origCallInst's use, origCallInst] entries in
+// ResUseToHandleMap have to be updated to match.
+//
+// What this function does:
+//   1. Add a [newCallInst, resource info] entry to CallMap
+//   2. Remove the [origCallInst, resource info] entry from CallMap
+//   3. Remap [origCallInst's use, origCallInst] entries to
+//      [origCallInst's use, newCallInst] entries in ResUseToHandleMap
+//
+// Removing the entries related to origCallInst is necessary because
+// origCallInst no longer exists after lowering; leaving them behind would
+// crash DXILResourceMap::print.
+//
+// NOTE:
+// Invoke this function before origCallInst->replaceAllUsesWith() and
+// origCallInst->eraseFromParent(), since it needs to visit origCallInst and
+// its uses.
+//
+void DXILResourceMap::updateResourceMap(CallInst *origCallInst,
+                                        CallInst *newCallInst) {
+  assert((origCallInst != nullptr) && (newCallInst != nullptr) &&
+         (origCallInst != newCallInst));
+
+  // Move the binding index across by value. `CallMap[origCallInst]` would
+  // default-insert a missing key, and the reference it returns dangles if
+  // try_emplace has to grow the table before constructing the new value.
+  if (auto Pos = CallMap.find(origCallInst); Pos != CallMap.end()) {
+    const auto ResIndex = Pos->second;
+    CallMap.erase(Pos);
+    CallMap.try_emplace(newCallInst, ResIndex);
+  }
+
+  // Point every call that uses the old handle at the new one. Non-call users
+  // are skipped: dyn_cast returns nullptr for them.
+  for (User *U : origCallInst->users())
+    if (auto *UseCI = dyn_cast<CallInst>(U))
+      ResUseToHandleMap[UseCI] = newCallInst;
+}
+
 void DXILResourceMap::print(raw_ostream &OS) const {
   for (unsigned I = 0, E = Resources.size(); I != E; ++I) {
     OS << "Binding " << I << ":\n";
@@ -756,6 +808,20 @@ void DXILResourceMap::print(raw_ostream &OS) const {
     CI->print(OS);
     OS << "\n";
   }
+
+  for (const auto &[ResUse, ResHandle] : ResUseToHandleMap) {
+    // A use whose handle is not (or no longer) in CallMap has no binding
+    // index to report; find() returns end() for it.
+    auto HandlePos = CallMap.find(ResHandle);
+    if (HandlePos == CallMap.end())
+      continue;
+
+    OS << "\n";
+    OS << "Resource " << HandlePos->second;
+    OS << " is used by ";
+    ResUse->print(OS);
+    OS << "\n";
+  }
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index c62ba8c21d679..dcf4fd7e187aa 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -203,6 +203,8 @@ class OpLowerer {
 
       Value *Cast = createTmpHandleCast(*OpCall, CI->getType());
 
+      DRM.updateResourceMap(CI, *OpCall);
+
       CI->replaceAllUsesWith(Cast);
       CI->eraseFromParent();
       return Error::success();
@@ -247,6 +249,8 @@ class OpLowerer {
 
       Value *Cast = createTmpHandleCast(*OpAnnotate, CI->getType());
 
+      DRM.updateResourceMap(CI, *OpBind);
+
       CI->replaceAllUsesWith(Cast);
       CI->eraseFromParent();
 
@@ -411,6 +415,9 @@ class OpLowerer {
           OpCode::BufferLoad, Args, CI->getName(), NewRetTy);
       if (Error E = OpCall.takeError())
         return E;
+
+      DRM.updateResUseMap(CI, *OpCall);
+
       if (Error E = replaceResRetUses(CI, *OpCall, HasCheckBit))
         return E;
 
@@ -455,6 +462,8 @@ class OpLowerer {
       if (Error E = OpCall.takeError())
         return E;
 
+      DRM.updateResUseMap(CI, *OpCall);
+
       CI->eraseFromParent();
       return Error::success();
     });
diff --git a/llvm/test/Analysis/DXILResource/resource-map.ll b/llvm/test/Analysis/DXILResource/resource-map.ll
new file mode 100644
index 0000000000000..65255d4c942e5
--- /dev/null
+++ b/llvm/test/Analysis/DXILResource/resource-map.ll
@@ -0,0 +1,36 @@
+; RUN: opt -S -disable-output -passes="print<dxil-resource>" < %s 2>&1 | FileCheck %s
+
+define void @test_typedbuffer() {
+  ; RWBuffer<float4> Buf : register(u5, space3)
+  %uav1 = call target("dx.TypedBuffer", <4 x float>, 1, 0, 0)
+              @llvm.dx.handle.fromBinding.tdx.TypedBuffer_f32_1_0(
+                  i32 3, i32 5, i32 1, i32 0, i1 false)
+  ; CHECK: Binding [[UAV1:[0-9]+]]:
+  ; CHECK: Symbol: ptr undef
+  ; CHECK: Name: ""
+  ; CHECK: Binding:
+  ; CHECK: Record ID: 0
+  ; CHECK: Space: 3
+  ; CHECK: Lower Bound: 5
+  ; CHECK: Size: 1
+  ; CHECK: Class: UAV
+  ; CHECK: Kind: TypedBuffer
+  ; CHECK: Globally Coherent: 0
+  ; CHECK: HasCounter: 0
+  ; CHECK: IsROV: 0
+  ; CHECK: Element Type: f32
+  ; CHECK: Element Count: 4
+
+  ; CHECK: Call bound to [[UAV1]]: %uav1 = call target("dx.TypedBuffer", <4 x float>, 1, 0, 0) @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f32_1_0_0t(i32 3, i32 5, i32 1, i32 0, i1 false)
+  ; CHECK-DAG: Resource [[UAV1]] is used by %data0 = call <4 x float> @llvm.dx.typedBufferLoad.v4f32.tdx.TypedBuffer_v4f32_1_0_0t(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %uav1, i32 0)
+  ; CHECK-DAG: Resource [[UAV1]] is used by call void @llvm.dx.typedBufferStore.tdx.TypedBuffer_v4f32_1_0_0t.v4f32(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %uav1, i32 2, <4 x float> %data0)
+
+  %data0 = call <4 x float> @llvm.dx.typedBufferLoad(
+      target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %uav1, i32 0)
+  call void @llvm.dx.typedBufferStore(
+      target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %uav1,
+      i32 2, <4 x float> %data0)
+
+  ret void
+}
+
diff --git a/llvm/test/CodeGen/DirectX/DXILResource/dxil-resource-map.ll b/llvm/test/CodeGen/DirectX/DXILResource/dxil-resource-map.ll
new file mode 100644
index 0000000000000..5f196e3ba7e82
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/DXILResource/dxil-resource-map.ll
@@ -0,0 +1,48 @@
+; RUN: opt -S -disable-output -passes="print<dxil-resource>,dxil-op-lower,print<dxil-resource>" -mtriple=dxil-pc-shadermodel6.6-compute < %s 2>&1 | FileCheck %s -check-prefixes=CHECK,CHECK_SM66
+; RUN: opt -S -disable-output -passes="print<dxil-resource>,dxil-op-lower,print<dxil-resource>" -mtriple=dxil-pc-shadermodel6.2-compute < %s 2>&1 | FileCheck %s -check-prefixes=CHECK,CHECK_SM62
+
+define void @test_typedbuffer() {
+  ; RWBuffer<float4> Buf : register(u5, space3)
+  %uav1 = call target("dx.TypedBuffer", <4 x float>, 1, 0, 0)
+              @llvm.dx.handle.fromBinding.tdx.TypedBuffer_f32_1_0(
+                  i32 3, i32 5, i32 1, i32 0, i1 false)
+  ; CHECK: Binding [[UAV1:[0-9]+]]:
+  ; CHECK: Symbol: ptr undef
+  ; CHECK: Name: ""
+  ; CHECK: Binding:
+  ; CHECK: Record ID: 0
+  ; CHECK: Space: 3
+  ; CHECK: Lower Bound: 5
+  ; CHECK: Size: 1
+  ; CHECK: Class: UAV
+  ; CHECK: Kind: TypedBuffer
+  ; CHECK: Globally Coherent: 0
+  ; CHECK: HasCounter: 0
+  ; CHECK: IsROV: 0
+  ; CHECK: Element Type: f32
+  ; CHECK: Element Count: 4
+
+  ; CHECK: Call bound to [[UAV1]]: %uav1 = call target("dx.TypedBuffer", <4 x float>, 1, 0, 0) @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f32_1_0_0t(i32 3, i32 5, i32 1, i32 0, i1 false)
+  ; CHECK-DAG: Resource [[UAV1]] is used by %data0 = call <4 x float> @llvm.dx.typedBufferLoad.v4f32.tdx.TypedBuffer_v4f32_1_0_0t(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %uav1, i32 0)
+  ; CHECK-DAG: Resource [[UAV1]] is used by call void @llvm.dx.typedBufferStore.tdx.TypedBuffer_v4f32_1_0_0t.v4f32(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %uav1, i32 2, <4 x float> %data0)
+
+  %data0 = call <4 x float> @llvm.dx.typedBufferLoad(
+      target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %uav1, i32 0)
+  call void @llvm.dx.typedBufferStore(
+      target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %uav1,
+      i32 2, <4 x float> %data0)
+
+  ;
+  ;;; After dxil-op-lower, the DXILResourceMap info should be updated.
+  ;
+  ; CHECK_SM66: Call bound to [[UAV1]]: %uav11 = call %dx.types.Handle @dx.op.createHandleFromBinding(i32 218, %dx.types.ResBind { i32 5, i32 5, i32 3, i8 1 }, i32 0, i1 false)
+  ; CHECK_SM66-DAG: Resource [[UAV1]] is used by %data02 = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle %uav1_annot, i32 0, i32 undef)
+  ; CHECK_SM66-DAG: Resource [[UAV1]] is used by call void @dx.op.bufferStore.f32(i32 69, %dx.types.Handle %uav1_annot, i32 2, i32 undef, float %9, float %10, float %11, float %12, i8 15)
+  ;
+  ; CHECK_SM62: Call bound to [[UAV1]]: %uav11 = call %dx.types.Handle @dx.op.createHandle(i32 57, i8 1, i32 0, i32 0, i1 false)
+  ; CHECK_SM62-DAG: Resource [[UAV1]] is used by %data02 = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle %uav11, i32 0, i32 undef)
+  ; CHECK_SM62-DAG: Resource [[UAV1]] is used by call void @dx.op.bufferStore.f32(i32 69, %dx.types.Handle %uav11, i32 2, i32 undef, float %9, float %10, float %11, float %12, i8 15)
+
+  ret void
+}
+