diff --git a/lib/DxilPIXPasses/DxilPIXAddTidToAmplificationShaderPayload.cpp b/lib/DxilPIXPasses/DxilPIXAddTidToAmplificationShaderPayload.cpp index d0022398f4..e756e98781 100644 --- a/lib/DxilPIXPasses/DxilPIXAddTidToAmplificationShaderPayload.cpp +++ b/lib/DxilPIXPasses/DxilPIXAddTidToAmplificationShaderPayload.cpp @@ -45,7 +45,6 @@ void DxilPIXAddTidToAmplificationShaderPayload::applyOptions(PassOptions O) { } void AddValueToExpandedPayload(OP *HlslOP, llvm::IRBuilder<> &B, - ExpandedStruct &expanded, AllocaInst *NewStructAlloca, unsigned int expandedValueIndex, Value *value) { Constant *Zero32Arg = HlslOP->GetU32Const(0); @@ -53,135 +52,147 @@ void AddValueToExpandedPayload(OP *HlslOP, llvm::IRBuilder<> &B, IndexToAppendedValue.push_back(Zero32Arg); IndexToAppendedValue.push_back(HlslOP->GetU32Const(expandedValueIndex)); auto *PointerToEmbeddedNewValue = B.CreateInBoundsGEP( - expanded.ExpandedPayloadStructType, NewStructAlloca, IndexToAppendedValue, + NewStructAlloca, IndexToAppendedValue, /* indexes NewStructAlloca with {0, expandedValueIndex}; this patch drops the explicit struct-type argument */ "PointerToEmbeddedNewValue" + std::to_string(expandedValueIndex)); B.CreateStore(value, PointerToEmbeddedNewValue); } -bool DxilPIXAddTidToAmplificationShaderPayload::runOnModule(Module &M) { +void CopyAggregate(IRBuilder<> &B, Type *Ty, Value *Source, Value *Dest, + ArrayRef GEPIndices) { /* Recursively copies the aggregate of type Ty from Source to Dest. Struct and array types are walked member by member, extending GEPIndices with one constant index per nesting level; any other type is treated as a leaf and copied with a GEP and load from Source followed by a GEP and store to Dest, using the same index list on both sides. */ + if (StructType *ST = dyn_cast(Ty)) { + SmallVector StructIndices; + StructIndices.append(GEPIndices.begin(), GEPIndices.end()); + StructIndices.push_back(nullptr); /* placeholder slot, overwritten with the member index on every loop iteration */ + for (unsigned j = 0; j < ST->getNumElements(); ++j) { + StructIndices.back() = B.getInt32(j); + CopyAggregate(B, ST->getElementType(j), Source, Dest, StructIndices); + } + } else if (ArrayType *AT = dyn_cast(Ty)) { + SmallVector StructIndices; + StructIndices.append(GEPIndices.begin(), GEPIndices.end()); + StructIndices.push_back(nullptr); /* placeholder slot, overwritten with the element index on every loop iteration */ + for (unsigned j = 0; j < AT->getNumElements(); ++j) { + StructIndices.back() = B.getInt32(j); + CopyAggregate(B, AT->getArrayElementType(), Source, Dest, StructIndices); + } + } else { 
+ auto *SourceGEP = B.CreateGEP(Source, GEPIndices, "CopyStructSourceGEP"); + Value *Val = B.CreateLoad(SourceGEP, "CopyStructLoad"); + auto *DestGEP = B.CreateGEP(Dest, GEPIndices, "CopyStructDestGEP"); + B.CreateStore(Val, DestGEP, "CopyStructStore"); /* NOTE(review): the FileCheck tests added in this patch expect these stores to be emitted as volatile. If the third CreateStore parameter is a bool isVolatile flag rather than a value name, the string literal is what makes them volatile - confirm this is intended. */ + } +} +bool DxilPIXAddTidToAmplificationShaderPayload::runOnModule(Module &M) { /* Scans the entry function for a DispatchMesh call. For the first one found it allocates an expanded payload struct, copies the original payload into it, appends a flat thread id plus the Y and Z thread-group counts taken from the call, issues a new DispatchMesh with the expanded payload and removes the old call. Returns true after rewriting a call and false if none was found. */ DxilModule &DM = M.GetOrCreateDxilModule(); LLVMContext &Ctx = M.getContext(); OP *HlslOP = DM.GetOP(); - - Type *OriginalPayloadStructPointerType = nullptr; - Type *OriginalPayloadStructType = nullptr; - ExpandedStruct expanded; llvm::Function *entryFunction = PIXPassHelpers::GetEntryFunction(DM); for (inst_iterator I = inst_begin(entryFunction), E = inst_end(entryFunction); I != E; ++I) { - if (auto *Instr = llvm::cast(&*I)) { - if (hlsl::OP::IsDxilOpFuncCallInst(Instr, - hlsl::OP::OpCode::DispatchMesh)) { - DxilInst_DispatchMesh DispatchMesh(Instr); - OriginalPayloadStructPointerType = - DispatchMesh.get_payload()->getType(); - OriginalPayloadStructType = - OriginalPayloadStructPointerType->getPointerElementType(); - expanded = ExpandStructType(Ctx, OriginalPayloadStructType); - } - } - } - - AllocaInst *OldStructAlloca = nullptr; - AllocaInst *NewStructAlloca = nullptr; - std::vector allocasOfPayloadType; - for (inst_iterator I = inst_begin(entryFunction), E = inst_end(entryFunction); - I != E; ++I) { - auto *Inst = &*I; - if (llvm::isa(Inst)) { - auto *Alloca = llvm::cast(Inst); - if (Alloca->getType() == OriginalPayloadStructPointerType) { - allocasOfPayloadType.push_back(Alloca); - } + if (hlsl::OP::IsDxilOpFuncCallInst(&*I, hlsl::OP::OpCode::DispatchMesh)) { + DxilInst_DispatchMesh DispatchMesh(&*I); + Type *OriginalPayloadStructPointerType = + DispatchMesh.get_payload()->getType(); + Type *OriginalPayloadStructType = + OriginalPayloadStructPointerType->getPointerElementType(); + ExpandedStruct expanded = + ExpandStructType(Ctx, OriginalPayloadStructType); + + llvm::IRBuilder<> B(&*I); /* builder is positioned at the original DispatchMesh call */ + + auto *NewStructAlloca = 
+ B.CreateAlloca(expanded.ExpandedPayloadStructType, + HlslOP->GetU32Const(1), "NewPayload"); + NewStructAlloca->setAlignment(4); + auto PayloadType = + llvm::dyn_cast(DispatchMesh.get_payload()->getType()); /* NOTE(review): the dyn_cast result is dereferenced just below without a null check; presumably the payload operand is always a pointer - confirm, or state the assumption with cast. */ + SmallVector GEPIndices; + GEPIndices.push_back(B.getInt32(0)); + CopyAggregate(B, PayloadType->getPointerElementType(), + DispatchMesh.get_payload(), NewStructAlloca, GEPIndices); /* copies every leaf of the original payload to the same index path inside the expanded struct */ + + Constant *Zero32Arg = HlslOP->GetU32Const(0); + Constant *One32Arg = HlslOP->GetU32Const(1); + Constant *Two32Arg = HlslOP->GetU32Const(2); + + auto GroupIdFunc = + HlslOP->GetOpFunc(DXIL::OpCode::GroupId, Type::getInt32Ty(Ctx)); + Constant *GroupIdOpcode = + HlslOP->GetU32Const((unsigned)DXIL::OpCode::GroupId); + auto *GroupIdX = + B.CreateCall(GroupIdFunc, {GroupIdOpcode, Zero32Arg}, "GroupIdX"); + auto *GroupIdY = + B.CreateCall(GroupIdFunc, {GroupIdOpcode, One32Arg}, "GroupIdY"); + auto *GroupIdZ = + B.CreateCall(GroupIdFunc, {GroupIdOpcode, Two32Arg}, "GroupIdZ"); + + // FlatGroupID = z + y*numZ + x*numY*numZ + // Where x,y,z are the group ID components, and numZ and numY are the + // corresponding AS group-count arguments to the DispatchMesh Direct3D API + auto *GroupYxNumZ = B.CreateMul( + GroupIdY, HlslOP->GetU32Const(m_DispatchArgumentZ), "GroupYxNumZ"); + auto *FlatGroupNumZY = + B.CreateAdd(GroupIdZ, GroupYxNumZ, "FlatGroupNumZY"); + auto *GroupXxNumYZ = B.CreateMul( + GroupIdX, + HlslOP->GetU32Const(m_DispatchArgumentY * m_DispatchArgumentZ), + "GroupXxNumYZ"); + auto *FlatGroupID = + B.CreateAdd(GroupXxNumYZ, FlatGroupNumZY, "FlatGroupID"); + + // The ultimate goal is a single unique thread ID for this AS thread. + // So take the flat group number, multiply it by the number of + // threads per group... 
+ auto *FlatGroupIDWithSpaceForThreadInGroupId = B.CreateMul( + FlatGroupID, + HlslOP->GetU32Const(DM.GetNumThreads(0) * DM.GetNumThreads(1) * + DM.GetNumThreads(2)), + "FlatGroupIDWithSpaceForThreadInGroupId"); + + auto *FlattenedThreadIdInGroupFunc = HlslOP->GetOpFunc( + DXIL::OpCode::FlattenedThreadIdInGroup, Type::getInt32Ty(Ctx)); + Constant *FlattenedThreadIdInGroupOpcode = + HlslOP->GetU32Const((unsigned)DXIL::OpCode::FlattenedThreadIdInGroup); + auto FlatThreadIdInGroup = B.CreateCall(FlattenedThreadIdInGroupFunc, + {FlattenedThreadIdInGroupOpcode}, + "FlattenedThreadIdInGroup"); + + // ...and add the flat thread id: + auto *FlatId = B.CreateAdd(FlatGroupIDWithSpaceForThreadInGroupId, + FlatThreadIdInGroup, "FlatId"); + + AddValueToExpandedPayload( + HlslOP, B, NewStructAlloca, + expanded.ExpandedPayloadStructType->getStructNumElements() - 3, + FlatId); + AddValueToExpandedPayload( + HlslOP, B, NewStructAlloca, + expanded.ExpandedPayloadStructType->getStructNumElements() - 2, + DispatchMesh.get_threadGroupCountY()); + AddValueToExpandedPayload( + HlslOP, B, NewStructAlloca, + expanded.ExpandedPayloadStructType->getStructNumElements() - 1, + DispatchMesh.get_threadGroupCountZ()); /* the last three members of the expanded struct now hold FlatId followed by the Y and Z thread-group counts of the original call */ + + auto DispatchMeshFn = HlslOP->GetOpFunc( + DXIL::OpCode::DispatchMesh, expanded.ExpandedPayloadStructPtrType); + Constant *DispatchMeshOpcode = + HlslOP->GetU32Const((unsigned)DXIL::OpCode::DispatchMesh); + B.CreateCall(DispatchMeshFn, + {DispatchMeshOpcode, DispatchMesh.get_threadGroupCountX(), + DispatchMesh.get_threadGroupCountY(), + DispatchMesh.get_threadGroupCountZ(), NewStructAlloca}); + I->removeFromParent(); + delete &*I; /* NOTE(review): the old call is unlinked and then deleted through the iterator; I is not advanced afterwards because the function returns right below. eraseFromParent would presumably be the single-call equivalent - confirm. */ + // Validation requires exactly one DispatchMesh in an AS, so we can exit + // after the first one: + DM.ReEmitDxilResources(); + return true; } } - for (auto &Alloca : allocasOfPayloadType) { - OldStructAlloca = Alloca; - llvm::IRBuilder<> B(Alloca->getContext()); - NewStructAlloca = B.CreateAlloca(expanded.ExpandedPayloadStructType, - 
HlslOP->GetU32Const(1), "NewPayload"); - NewStructAlloca->setAlignment(Alloca->getAlignment()); - NewStructAlloca->insertAfter(Alloca); - - ReplaceAllUsesOfInstructionWithNewValueAndDeleteInstruction( - Alloca, NewStructAlloca, expanded.ExpandedPayloadStructType); - } - - auto F = HlslOP->GetOpFunc(DXIL::OpCode::DispatchMesh, - expanded.ExpandedPayloadStructPtrType); - for (auto FI = F->user_begin(); FI != F->user_end();) { - auto *FunctionUser = *FI++; - auto *UserInstruction = llvm::cast(FunctionUser); - DxilInst_DispatchMesh DispatchMesh(UserInstruction); - - llvm::IRBuilder<> B(UserInstruction); - - Constant *Zero32Arg = HlslOP->GetU32Const(0); - Constant *One32Arg = HlslOP->GetU32Const(1); - Constant *Two32Arg = HlslOP->GetU32Const(2); - - auto GroupIdFunc = - HlslOP->GetOpFunc(DXIL::OpCode::GroupId, Type::getInt32Ty(Ctx)); - Constant *GroupIdOpcode = - HlslOP->GetU32Const((unsigned)DXIL::OpCode::GroupId); - auto *GroupIdX = - B.CreateCall(GroupIdFunc, {GroupIdOpcode, Zero32Arg}, "GroupIdX"); - auto *GroupIdY = - B.CreateCall(GroupIdFunc, {GroupIdOpcode, One32Arg}, "GroupIdY"); - auto *GroupIdZ = - B.CreateCall(GroupIdFunc, {GroupIdOpcode, Two32Arg}, "GroupIdZ"); - - // FlatGroupID = z + y*numZ + x*numY*numZ - // Where x,y,z are the group ID components, and numZ and numY are the - // corresponding AS group-count arguments to the DispatchMesh Direct3D API - auto *GroupYxNumZ = B.CreateMul( - GroupIdY, HlslOP->GetU32Const(m_DispatchArgumentZ), "GroupYxNumZ"); - auto *FlatGroupNumZY = B.CreateAdd(GroupIdZ, GroupYxNumZ, "FlatGroupNumZY"); - auto *GroupXxNumYZ = B.CreateMul( - GroupIdX, - HlslOP->GetU32Const(m_DispatchArgumentY * m_DispatchArgumentZ), - "GroupXxNumYZ"); - auto *FlatGroupID = - B.CreateAdd(GroupXxNumYZ, FlatGroupNumZY, "FlatGroFlatGroupIDupNum"); - - // The ultimate goal is a single unique thread ID for this AS thread. - // So take the flat group number, multiply it by the number of - // threads per group... 
- auto *FlatGroupIDWithSpaceForThreadInGroupId = B.CreateMul( - FlatGroupID, - HlslOP->GetU32Const(DM.GetNumThreads(0) * DM.GetNumThreads(1) * - DM.GetNumThreads(2)), - "FlatGroupIDWithSpaceForThreadInGroupId"); - - auto *FlattenedThreadIdInGroupFunc = HlslOP->GetOpFunc( - DXIL::OpCode::FlattenedThreadIdInGroup, Type::getInt32Ty(Ctx)); - Constant *FlattenedThreadIdInGroupOpcode = - HlslOP->GetU32Const((unsigned)DXIL::OpCode::FlattenedThreadIdInGroup); - auto FlatThreadIdInGroup = B.CreateCall(FlattenedThreadIdInGroupFunc, - {FlattenedThreadIdInGroupOpcode}, - "FlattenedThreadIdInGroup"); - - // ...and add the flat thread id: - auto *FlatId = B.CreateAdd(FlatGroupIDWithSpaceForThreadInGroupId, - FlatThreadIdInGroup, "FlatId"); - - AddValueToExpandedPayload(HlslOP, B, expanded, NewStructAlloca, - OriginalPayloadStructType->getStructNumElements(), - FlatId); - AddValueToExpandedPayload( - HlslOP, B, expanded, NewStructAlloca, - OriginalPayloadStructType->getStructNumElements() + 1, - DispatchMesh.get_threadGroupCountY()); - AddValueToExpandedPayload( - HlslOP, B, expanded, NewStructAlloca, - OriginalPayloadStructType->getStructNumElements() + 2, - DispatchMesh.get_threadGroupCountZ()); - } - - DM.ReEmitDxilResources(); - return true; + return false; } char DxilPIXAddTidToAmplificationShaderPayload::ID = 0; diff --git a/tools/clang/test/HLSLFileCheck/pix/DebugAsGroupSharedComplexPayload.hlsl b/tools/clang/test/HLSLFileCheck/pix/DebugAsGroupSharedComplexPayload.hlsl new file mode 100644 index 0000000000..28eff71474 --- /dev/null +++ b/tools/clang/test/HLSLFileCheck/pix/DebugAsGroupSharedComplexPayload.hlsl @@ -0,0 +1,88 @@ +// RUN: %dxc -enable-16bit-types -Od -Emain -Tas_6_6 %s | %opt -S -hlsl-dxil-PIX-add-tid-to-as-payload,dispatchArgY=3,dispatchArgZ=7 | %FileCheck %s + +// Check that the payload was piece-wise copied into a local copy from group-shared: +// There are 28 elements: + +// CHECK: [[LOAD0:%.*]] = load [[TYPE0:.*]], [[TYPE0]] addrspace(3)* getelementptr 
inbounds +// CHECK:store volatile [[TYPE0]] [[LOAD0]] +// CHECK: [[LOAD1:%.*]] = load [[TYPE1:.*]], [[TYPE1]] addrspace(3)* getelementptr inbounds +// CHECK:store volatile [[TYPE1]] [[LOAD1]] +// CHECK: [[LOAD2:%.*]] = load [[TYPE2:.*]], [[TYPE2]] addrspace(3)* getelementptr inbounds +// CHECK:store volatile [[TYPE2]] [[LOAD2]] +// CHECK: [[LOAD3:%.*]] = load [[TYPE3:.*]], [[TYPE3]] addrspace(3)* getelementptr inbounds +// CHECK:store volatile [[TYPE3]] [[LOAD3]] +// CHECK: [[LOAD4:%.*]] = load [[TYPE4:.*]], [[TYPE4]] addrspace(3)* getelementptr inbounds +// CHECK:store volatile [[TYPE4]] [[LOAD4]] +// CHECK: [[LOAD5:%.*]] = load [[TYPE5:.*]], [[TYPE5]] addrspace(3)* getelementptr inbounds +// CHECK:store volatile [[TYPE5]] [[LOAD5]] +// CHECK: [[LOAD6:%.*]] = load [[TYPE6:.*]], [[TYPE6]] addrspace(3)* getelementptr inbounds +// CHECK:store volatile [[TYPE6]] [[LOAD6]] +// CHECK: [[LOAD7:%.*]] = load [[TYPE7:.*]], [[TYPE7]] addrspace(3)* getelementptr inbounds +// CHECK:store volatile [[TYPE7]] [[LOAD7]] +// CHECK: [[LOAD8:%.*]] = load [[TYPE8:.*]], [[TYPE8]] addrspace(3)* getelementptr inbounds +// CHECK:store volatile [[TYPE8]] [[LOAD8]] +// CHECK: [[LOAD9:%.*]] = load [[TYPE9:.*]], [[TYPE9]] addrspace(3)* getelementptr inbounds +// CHECK:store volatile [[TYPE9]] [[LOAD9]] + +// CHECK: [[LOAD10:%.*]] = load [[TYPE10:.*]], [[TYPE10]] addrspace(3)* getelementptr inbounds +// CHECK:store volatile [[TYPE10]] [[LOAD10]] +// CHECK: [[LOAD11:%.*]] = load [[TYPE11:.*]], [[TYPE11]] addrspace(3)* getelementptr inbounds +// CHECK:store volatile [[TYPE11]] [[LOAD11]] +// CHECK: [[LOAD12:%.*]] = load [[TYPE12:.*]], [[TYPE12]] addrspace(3)* getelementptr inbounds +// CHECK:store volatile [[TYPE12]] [[LOAD12]] +// CHECK: [[LOAD13:%.*]] = load [[TYPE13:.*]], [[TYPE13]] addrspace(3)* getelementptr inbounds +// CHECK:store volatile [[TYPE13]] [[LOAD13]] +// CHECK: [[LOAD14:%.*]] = load [[TYPE14:.*]], [[TYPE14]] addrspace(3)* getelementptr inbounds +// CHECK:store volatile 
[[TYPE14]] [[LOAD14]] +// CHECK: [[LOAD15:%.*]] = load [[TYPE15:.*]], [[TYPE15]] addrspace(3)* getelementptr inbounds +// CHECK:store volatile [[TYPE15]] [[LOAD15]] +// CHECK: [[LOAD16:%.*]] = load [[TYPE16:.*]], [[TYPE16]] addrspace(3)* getelementptr inbounds +// CHECK:store volatile [[TYPE16]] [[LOAD16]] +// CHECK: [[LOAD17:%.*]] = load [[TYPE17:.*]], [[TYPE17]] addrspace(3)* getelementptr inbounds +// CHECK:store volatile [[TYPE17]] [[LOAD17]] +// CHECK: [[LOAD18:%.*]] = load [[TYPE18:.*]], [[TYPE18]] addrspace(3)* getelementptr inbounds +// CHECK:store volatile [[TYPE18]] [[LOAD18]] +// CHECK: [[LOAD19:%.*]] = load [[TYPE19:.*]], [[TYPE19]] addrspace(3)* getelementptr inbounds +// CHECK:store volatile [[TYPE19]] [[LOAD19]] + +// CHECK: [[LOAD20:%.*]] = load [[TYPE20:.*]], [[TYPE20]] addrspace(3)* getelementptr inbounds +// CHECK:store volatile [[TYPE20]] [[LOAD20]] +// CHECK: [[LOAD21:%.*]] = load [[TYPE21:.*]], [[TYPE21]] addrspace(3)* getelementptr inbounds +// CHECK:store volatile [[TYPE21]] [[LOAD21]] +// CHECK: [[LOAD22:%.*]] = load [[TYPE22:.*]], [[TYPE22]] addrspace(3)* getelementptr inbounds +// CHECK:store volatile [[TYPE22]] [[LOAD22]] +// CHECK: [[LOAD23:%.*]] = load [[TYPE23:.*]], [[TYPE23]] addrspace(3)* getelementptr inbounds +// CHECK:store volatile [[TYPE23]] [[LOAD23]] +// CHECK: [[LOAD24:%.*]] = load [[TYPE24:.*]], [[TYPE24]] addrspace(3)* getelementptr inbounds +// CHECK:store volatile [[TYPE24]] [[LOAD24]] +// CHECK: [[LOAD25:%.*]] = load [[TYPE25:.*]], [[TYPE25]] addrspace(3)* getelementptr inbounds +// CHECK:store volatile [[TYPE25]] [[LOAD25]] +// CHECK: [[LOAD26:%.*]] = load [[TYPE26:.*]], [[TYPE26]] addrspace(3)* getelementptr inbounds +// CHECK:store volatile [[TYPE26]] [[LOAD26]] +// CHECK: [[LOAD27:%.*]] = load [[TYPE27:.*]], [[TYPE27]] addrspace(3)* getelementptr inbounds +// CHECK:store volatile [[TYPE27]] [[LOAD27]] + +// And no more: +// CHECK-NOT: [[LOAD28:%.*]] = load [[TYPE28:.*]], [[TYPE28]] addrspace(3)* getelementptr 
inbounds + +struct Contained { + uint j; + float af[3]; +}; + +struct Bigger { + half h; + Contained a[2]; +}; + +struct MyPayload { + uint i; + Bigger big[3]; +}; + +groupshared MyPayload payload; /* scalar leaves: 1 + 3 * (1 + 2 * (1 + 3)) = 28, matching the 28 load and store pairs expected at the top of this test */ + +[numthreads(1, 1, 1)] void main(uint gid + : SV_GroupID) { + DispatchMesh(1, 1, 1, payload); +} diff --git a/tools/clang/test/HLSLFileCheck/pix/DebugAsGroupSharedPayload.hlsl b/tools/clang/test/HLSLFileCheck/pix/DebugAsGroupSharedPayload.hlsl new file mode 100644 index 0000000000..7de78a895b --- /dev/null +++ b/tools/clang/test/HLSLFileCheck/pix/DebugAsGroupSharedPayload.hlsl @@ -0,0 +1,21 @@ +// RUN: %dxc -Od -Emain -Tas_6_6 %s | %opt -S -hlsl-dxil-PIX-add-tid-to-as-payload,dispatchArgY=3,dispatchArgZ=7 | %FileCheck %s + +// Check that the payload was piece-wise copied into a local copy +// CHECK: [[LOADGEP:%.*]] = getelementptr %struct.MyPayload +// CHECK: [[LOAD:%.*]] = load i32, i32* [[LOADGEP]] +// CHECK: store volatile i32 [[LOAD]] + +struct MyPayload +{ + uint i; +}; + +groupshared MyPayload payload; + +[numthreads(1, 1, 1)] +void main(uint gid : SV_GroupID) +{ + MyPayload copy; + copy = payload; /* in this variant the payload handed to DispatchMesh is a function-local copy, not the groupshared variable itself */ + DispatchMesh(1, 1, 1, copy); +} diff --git a/tools/clang/test/HLSLFileCheck/pix/DebugAsGroupSharedTrickyTypesPayload.hlsl b/tools/clang/test/HLSLFileCheck/pix/DebugAsGroupSharedTrickyTypesPayload.hlsl new file mode 100644 index 0000000000..6f3e70da00 --- /dev/null +++ b/tools/clang/test/HLSLFileCheck/pix/DebugAsGroupSharedTrickyTypesPayload.hlsl @@ -0,0 +1,28 @@ +// RUN: %dxc -enable-16bit-types -Od -Emain -Tas_6_6 %s | %opt -S -hlsl-dxil-PIX-add-tid-to-as-payload,dispatchArgY=3,dispatchArgZ=7 | %FileCheck %s + +// Check that the payload was piece-wise copied into a local copy from group-shared: +// There are only 2 elements (the bitfield should take up 1 uint slot) + +// CHECK: [[LOAD0:%.*]] = load [[TYPE0:.*]], [[TYPE0]] addrspace(3)* getelementptr inbounds +// CHECK:store volatile [[TYPE0]] [[LOAD0]] +// CHECK: [[LOAD1:%.*]] = load [[TYPE1:.*]], [[TYPE1]] 
addrspace(3)* getelementptr inbounds +// CHECK:store volatile [[TYPE1]] [[LOAD1]] + +// And no more: +// CHECK-NOT: [[LOAD2:%.*]] = load {{.*}}, {{.*}} addrspace(3)* getelementptr inbounds + +struct MyPayload { + uint i; + void Init() { i = 27; } +struct { + int bf0 : 7; + int bf1 : 11; +} bitfields; /* the two bit-fields above are the members expected to share one uint slot, per the note at the top of this test */ +}; + +groupshared MyPayload payload; + +[numthreads(1, 1, 1)] void main(uint gid + : SV_GroupID) { + DispatchMesh(1, 1, 1, payload); +} diff --git a/tools/clang/unittests/HLSL/PixTest.cpp b/tools/clang/unittests/HLSL/PixTest.cpp index 55dad73127..58289767b7 100644 --- a/tools/clang/unittests/HLSL/PixTest.cpp +++ b/tools/clang/unittests/HLSL/PixTest.cpp @@ -102,7 +102,8 @@ class PixTest : public ::testing::Test { TEST_METHOD(CompileDebugDisasmPDB) TEST_METHOD(AddToASPayload) - + TEST_METHOD(AddToASGroupSharedPayload) + TEST_METHOD(AddToASGroupSharedPayload_MeshletCullSample) TEST_METHOD(SignatureModification_Empty) TEST_METHOD(SignatureModification_VertexIdAlready) TEST_METHOD(SignatureModification_SomethingElseFirst) @@ -565,7 +566,7 @@ PixTest::RunDxilPIXAddTidToAmplificationShaderPayloadPass(IDxcBlob *blob) { TEST_F(PixTest, AddToASPayload) { - const char *dynamicResourceDecriptorHeapAccess = R"( + const char *hlsl = R"( struct MyPayload { float f1; @@ -603,12 +604,10 @@ void MSMain( )"; - auto as = Compile(m_dllSupport, dynamicResourceDecriptorHeapAccess, L"as_6_6", - {}, L"ASMain"); + auto as = Compile(m_dllSupport, hlsl, L"as_6_6", {}, L"ASMain"); RunDxilPIXAddTidToAmplificationShaderPayloadPass(as); - auto ms = Compile(m_dllSupport, dynamicResourceDecriptorHeapAccess, L"ms_6_6", - {}, L"MSMain"); + auto ms = Compile(m_dllSupport, hlsl, L"ms_6_6", {}, L"MSMain"); RunDxilPIXMeshShaderOutputPass(ms); } unsigned FindOrAddVSInSignatureElementForInstanceOrVertexID( @@ -704,6 +703,63 @@ TEST_F(PixTest, SignatureModification_SomethingElseFirst) { VERIFY_ARE_EQUAL(sig.GetElement(2).GetStartRow(), 2); } +TEST_F(PixTest, AddToASGroupSharedPayload) { /* Compiles the shader below for as_6_6 at -Od with entry point main and hands the result to RunDxilPIXAddTidToAmplificationShaderPayloadPass. The DispatchMesh payload is a groupshared struct that nests structs and arrays. NOTE(review): no explicit VERIFY appears in this test body - presumably the helper reports pass failures; confirm. */ + + const char *hlsl = 
R"( +struct Contained +{ + uint j; + float af[3]; +}; + +struct Bigger +{ + half h; + void Init() { h = 1.f; } + Contained a[2]; +}; + +struct MyPayload +{ + uint i; + Bigger big[3]; +}; + +groupshared MyPayload payload; + +[numthreads(1, 1, 1)] +void main(uint gid : SV_GroupID) +{ + DispatchMesh(1, 1, 1, payload); +} + + )"; + + auto as = Compile(m_dllSupport, hlsl, L"as_6_6", {L"-Od"}, L"main"); + RunDxilPIXAddTidToAmplificationShaderPayloadPass(as); +} + +TEST_F(PixTest, AddToASGroupSharedPayload_MeshletCullSample) { /* Compiles a shader whose groupshared DispatchMesh payload is a single uint[32] array (as_6_6, -Od, entry point main) and hands the result to RunDxilPIXAddTidToAmplificationShaderPayloadPass. NOTE(review): no explicit VERIFY appears in this test body either - confirm the helper checks the outcome. */ + + const char *hlsl = R"( +struct MyPayload +{ + uint i[32]; +}; + +groupshared MyPayload payload; + +[numthreads(1, 1, 1)] +void main(uint gid : SV_GroupID) +{ + DispatchMesh(1, 1, 1, payload); +} + + )"; + + auto as = Compile(m_dllSupport, hlsl, L"as_6_6", {L"-Od"}, L"main"); + RunDxilPIXAddTidToAmplificationShaderPayloadPass(as); +} static llvm::DIType *PeelTypedefs(llvm::DIType *diTy) { using namespace llvm; const llvm::DITypeIdentifierMap EmptyMap;