From cd7438e75cd00d791d6b77aade7773f74102a243 Mon Sep 17 00:00:00 2001 From: Xiang Li Date: Tue, 13 Feb 2024 14:06:13 -0500 Subject: [PATCH 1/6] Fix crash in DXIL.dll caused by illegal DXIL intrinsic. Replace assert on illegal DXIL op with return illegal value. Check the illegal cases in validation. Fixes #6168 --- docs/DXIL.rst | 2 + lib/DXIL/DxilCounters.cpp | 4 +- lib/DXIL/DxilOperations.cpp | 60 ++++++---- lib/DXIL/DxilShaderFlags.cpp | 10 +- lib/HLSL/DxilValidation.cpp | 25 ++++ .../test/LitDXILValidation/illegalDXILOp.ll | 111 ++++++++++++++++++ utils/hct/hctdb.py | 8 ++ utils/hct/hctdb_instrhelp.py | 4 +- 8 files changed, 187 insertions(+), 37 deletions(-) create mode 100644 tools/clang/test/LitDXILValidation/illegalDXILOp.ll diff --git a/docs/DXIL.rst b/docs/DXIL.rst index 4da81c71b5..0964eca41b 100644 --- a/docs/DXIL.rst +++ b/docs/DXIL.rst @@ -3079,6 +3079,8 @@ INSTR.EVALINTERPOLATIONMODE Interpolation mode on %0 used with eva INSTR.EXTRACTVALUE ExtractValue should only be used on dxil struct types and cmpxchg. INSTR.FAILTORESLOVETGSMPOINTER TGSM pointers must originate from an unambiguous TGSM global variable. INSTR.HANDLENOTFROMCREATEHANDLE Resource handle should returned by createHandle. +INSTR.ILLEGALDXILOPCODE DXILOpCode must be [0..%0]. %1 specified. +INSTR.ILLEGALDXILOPFUNCTION '%0' is not a DXILOpFuncition for DXILOpcode '%1'. INSTR.IMMBIASFORSAMPLEB bias amount for sample_b must be in the range [%0,%1], but %2 was specified as an immediate. INSTR.INBOUNDSACCESS Access to out-of-bounds memory is disallowed. INSTR.MINPRECISIONNOTPRECISE Instructions marked precise may not refer to minprecision values. diff --git a/lib/DXIL/DxilCounters.cpp b/lib/DXIL/DxilCounters.cpp index 5aa359b633..b1e4e48162 100644 --- a/lib/DXIL/DxilCounters.cpp +++ b/lib/DXIL/DxilCounters.cpp @@ -336,9 +336,7 @@ void CountInstructions(llvm::Module &M, DxilCounters &counters) { } } else if (CallInst *CI = dyn_cast(I)) { if (hlsl::OP::IsDxilOpFuncCallInst(CI)) { - unsigned opcode = - (unsigned)llvm::cast(I->getOperand(0)) - ->getZExtValue(); + unsigned opcode = static_cast(hlsl::OP::getOpCode(CI)); CountDxilOp(opcode, counters); } } else if (isa(I) || isa(I)) { diff --git a/lib/DXIL/DxilOperations.cpp b/lib/DXIL/DxilOperations.cpp index c67cfcab5d..93bbefde9c 100644 --- a/lib/DXIL/DxilOperations.cpp +++ b/lib/DXIL/DxilOperations.cpp @@ -2737,8 +2737,8 @@ llvm::Attribute::AttrKind OP::GetMemAccessAttr(OpCode opCode) { } bool OP::IsOverloadLegal(OpCode opCode, Type *pType) { - DXASSERT(0 <= (unsigned)opCode && opCode < OpCode::NumOpCodes, - "otherwise caller passed OOB index"); + if (opCode >= OpCode::NumOpCodes) + return false; unsigned TypeSlot = GetTypeSlot(pType); return TypeSlot != UINT_MAX && m_OpCodeProps[(unsigned)opCode].bAllowOverload[TypeSlot]; @@ -2814,8 +2814,10 @@ bool OP::IsDxilOpFuncCallInst(const llvm::Instruction *I, OpCode opcode) { } OP::OpCode OP::getOpCode(const llvm::Instruction *I) { - return (OP::OpCode)llvm::cast(I->getOperand(0)) - ->getZExtValue(); + auto *OpConst = llvm::dyn_cast(I->getOperand(0)); + if (!OpConst) + return OpCode::NumOpCodes; + return (OP::OpCode)OpConst->getZExtValue(); } OP::OpCode OP::GetDxilOpFuncCallInst(const llvm::Instruction *I) { @@ -3383,9 +3385,7 @@ void OP::RefreshCache() { CallInst *CI = cast(*F.user_begin()); OpCode OpCode = OP::GetDxilOpFuncCallInst(CI); Type *pOverloadType = OP::GetOverloadType(OpCode, &F); - Function *OpFunc = GetOpFunc(OpCode, pOverloadType); - (void)(OpFunc); - DXASSERT_NOMSG(OpFunc == &F); + GetOpFunc(OpCode, pOverloadType); } } } @@ -3404,11 +3404,13 @@ void OP::FixOverloadNames() { CallInst *CI = cast(*F.user_begin()); DXIL::OpCode opCode = OP::GetDxilOpFuncCallInst(CI); llvm::Type *Ty = OP::GetOverloadType(opCode, &F); - if (isa(Ty) || isa(Ty)) { - std::string funcName; - if (OP::ConstructOverloadName(Ty, opCode, funcName) - .compare(F.getName()) != 0) { - F.setName(funcName); + if (OP::IsOverloadLegal(opCode, Ty)) { + if (isa(Ty) || isa(Ty)) { + std::string funcName; + if (OP::ConstructOverloadName(Ty, opCode, funcName) + .compare(F.getName()) != 0) { + F.setName(funcName); + } } } } @@ -3421,12 +3423,11 @@ void OP::UpdateCache(OpCodeClass opClass, Type *Ty, llvm::Function *F) { } Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) { - DXASSERT(0 <= (unsigned)opCode && opCode < OpCode::NumOpCodes, - "otherwise caller passed OOB OpCode"); - assert(0 <= (unsigned)opCode && opCode < OpCode::NumOpCodes); - DXASSERT(IsOverloadLegal(opCode, pOverloadType), - "otherwise the caller requested illegal operation overload (eg HLSL " - "function with unsupported types for mapped intrinsic function)"); + if (opCode >= OpCode::NumOpCodes) + return nullptr; + if (!IsOverloadLegal(opCode, pOverloadType)) + return nullptr; + OpCodeClass opClass = m_OpCodeProps[(unsigned)opCode].opCodeClass; Function *&F = m_OpCodeClassCache[(unsigned)opClass].pOverloads[pOverloadType]; @@ -5369,6 +5370,8 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) { // and return values to ensure that ResRetType is constructed in the // RefreshCache case. if (Function *existF = m_pModule->getFunction(funcName)) { + if (existF->getFunctionType() != pFT) + return nullptr; DXASSERT(existF->getFunctionType() == pFT, "existing function must have the expected function type"); F = existF; @@ -5489,7 +5492,8 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) { case OpCode::CallShader: case OpCode::Pack4x8: case OpCode::WaveMatrix_Fill: - DXASSERT_NOMSG(FT->getNumParams() > 2); + if (FT->getNumParams() <= 2) + return Ty; return FT->getParamType(2); case OpCode::MinPrecXRegStore: case OpCode::StoreOutput: @@ -5499,7 +5503,8 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) { case OpCode::StoreVertexOutput: case OpCode::StorePrimitiveOutput: case OpCode::DispatchMesh: - DXASSERT_NOMSG(FT->getNumParams() > 4); + if (FT->getNumParams() <= 4) + return Ty; return FT->getParamType(4); case OpCode::IsNaN: case OpCode::IsInf: @@ -5517,22 +5522,27 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) { case OpCode::WaveActiveAllEqual: case OpCode::CreateHandleForLib: case OpCode::WaveMatch: - DXASSERT_NOMSG(FT->getNumParams() > 1); + if (FT->getNumParams() <= 1) + return Ty; return FT->getParamType(1); case OpCode::TextureStore: case OpCode::TextureStoreSample: - DXASSERT_NOMSG(FT->getNumParams() > 5); + if (FT->getNumParams() <= 5) + return Ty; return FT->getParamType(5); case OpCode::TraceRay: - DXASSERT_NOMSG(FT->getNumParams() > 15); + if (FT->getNumParams() <= 15) + return Ty; return FT->getParamType(15); case OpCode::ReportHit: case OpCode::WaveMatrix_ScalarOp: - DXASSERT_NOMSG(FT->getNumParams() > 3); + if (FT->getNumParams() <= 3) + return Ty; return FT->getParamType(3); case OpCode::WaveMatrix_LoadGroupShared: case OpCode::WaveMatrix_StoreGroupShared: - DXASSERT_NOMSG(FT->getNumParams() > 2); + if (FT->getNumParams() <= 2) + return Ty; return FT->getParamType(2)->getPointerElementType(); case OpCode::CreateHandle: case OpCode::BufferUpdateCounter: diff --git a/lib/DXIL/DxilShaderFlags.cpp b/lib/DXIL/DxilShaderFlags.cpp index 48e7289f79..46c5fa75cb 100644 --- a/lib/DXIL/DxilShaderFlags.cpp +++ b/lib/DXIL/DxilShaderFlags.cpp @@ -572,13 +572,9 @@ ShaderFlags ShaderFlags::CollectShaderFlags(const Function *F, if (const CallInst *CI = dyn_cast(&I)) { if (!OP::IsDxilOpFunc(CI->getCalledFunction())) continue; - Value *opcodeArg = CI->getArgOperand(DXIL::OperandIndex::kOpcodeIdx); - ConstantInt *opcodeConst = dyn_cast(opcodeArg); - DXASSERT(opcodeConst, "DXIL opcode arg must be immediate"); - unsigned opcode = opcodeConst->getLimitedValue(); - DXASSERT(opcode < static_cast(DXIL::OpCode::NumOpCodes), - "invalid DXIL opcode"); - DXIL::OpCode dxilOp = static_cast(opcode); + DXIL::OpCode dxilOp = hlsl::OP::getOpCode(CI); + if (dxilOp >= DXIL::OpCode::NumOpCodes) + continue; if (hlsl::OP::IsDxilOpWave(dxilOp)) hasWaveOps = true; switch (dxilOp) { diff --git a/lib/HLSL/DxilValidation.cpp b/lib/HLSL/DxilValidation.cpp index 8fe34fb281..36187402cd 100644 --- a/lib/HLSL/DxilValidation.cpp +++ b/lib/HLSL/DxilValidation.cpp @@ -3188,6 +3188,8 @@ static void ValidateFunctionBody(Function *F, ValidationContext &ValCtx) { CallInst *setMeshOutputCounts = nullptr; CallInst *getMeshPayload = nullptr; CallInst *dispatchMesh = nullptr; + hlsl::OP *hlslOP = ValCtx.DxilMod.GetOP(); + for (auto b = F->begin(), bend = F->end(); b != bend; ++b) { for (auto i = b->begin(), iend = b->end(); i != iend; ++i) { llvm::Instruction &I = *i; @@ -3239,6 +3241,29 @@ static void ValidateFunctionBody(Function *F, ValidationContext &ValCtx) { unsigned opcode = OpcodeConst->getLimitedValue(); DXIL::OpCode dxilOpcode = (DXIL::OpCode)opcode; + if (dxilOpcode >= DXIL::OpCode::NumOpCodes) { + ValCtx.EmitInstrFormatError( + &I, ValidationRule::InstrIllegalDXILOpCode, + {std::to_string((unsigned) DXIL::OpCode::NumOpCodes), + std::to_string(opcode)}); + continue; + } + + bool IllegalOpFunc = true; + for (auto &it : hlslOP->GetOpFuncList(dxilOpcode)) { + if (it.second == FCalled) { + IllegalOpFunc = false; + break; + } + } + + if (IllegalOpFunc) { + ValCtx.EmitInstrFormatError( + &I, ValidationRule::InstrIllegalDXILOpFunction, + {FCalled->getName(), OP::GetOpCodeName(dxilOpcode)}); + continue; + } + if (OP::IsDxilOpGradient(dxilOpcode)) { gradientOps.push_back(CI); } diff --git a/tools/clang/test/LitDXILValidation/illegalDXILOp.ll b/tools/clang/test/LitDXILValidation/illegalDXILOp.ll new file mode 100644 index 0000000000..e55eb2b1f1 --- /dev/null +++ b/tools/clang/test/LitDXILValidation/illegalDXILOp.ll @@ -0,0 +1,111 @@ +; RUN: not %dxv %s 2>&1 | FileCheck %s + +target datalayout = "e-m:e-p:32:32-i1:32-i8:32-i16:32-i32:32-i64:64-f16:32-f32:32-f64:64-n8:16:32:64" +target triple = "dxil-ms-dx" + +%dx.types.Handle = type { i8* } +%dx.types.ResBind = type { i32, i32, i32, i8 } +%dx.types.ResourceProperties = type { i32, i32 } +%"class.Texture2D" = type { float, %"class.Texture2D::mips_type" } +%"class.Texture2D::mips_type" = type { i32 } +%struct.SamplerComparisonState = type { i32 } + + +define void @main() { + %1 = call %dx.types.Handle @dx.op.createHandleFromBinding(i32 217, %dx.types.ResBind zeroinitializer, i32 0, i1 false) ; CreateHandleFromBinding(bind,index,nonUniformIndex) + %2 = call %dx.types.Handle @dx.op.createHandleFromBinding(i32 217, %dx.types.ResBind { i32 0, i32 0, i32 0, i8 3 }, i32 0, i1 false) ; CreateHandleFromBinding(bind,index,nonUniformIndex) + + +; CHECK: error: 'dx.op.loadInput.f32' is not a DXILOpFuncition for DXILOpcode 'LoadInput'. +; CHECK: note: at '%3 = call float @dx.op.loadInput.f32(i32 4, i32 0, i32 0, i8 0)' in block '#0' of function 'main'. + + %3 = call float @dx.op.loadInput.f32(i32 4, i32 0, i32 0, i8 0) ; LoadInput(inputSigId,rowIndex,colIndex,gsVertexAxis) + +; CHECK: error: 'dx.op.loadInput.f32' is not a DXILOpFuncition for DXILOpcode 'LoadInput'. +; CHECK: note: at '%4 = call float @dx.op.loadInput.f32(i32 4, i32 0, i32 0, i8 1)' in block '#0' of function 'main'. + + %4 = call float @dx.op.loadInput.f32(i32 4, i32 0, i32 0, i8 1) ; LoadInput(inputSigId,rowIndex,colIndex,gsVertexAxis) + + +; CHECK: error: 'dx.op.annotateHandle' is not a DXILOpFuncition for DXILOpcode 'MinPrecXRegStore'. +; CHECK: note: at '%5 = call %dx.types.Handle @dx.op.annotateHandle(i32 3, %dx.types.Handle %1, %dx.types.ResourceProperties { i32 2, i32 265 })' in block '#0' of function 'main'. + + %5 = call %dx.types.Handle @dx.op.annotateHandle(i32 3, %dx.types.Handle %1, %dx.types.ResourceProperties { i32 2, i32 265 }) ; AnnotateHandle(res,props) resource: Texture2D + +; CHECK: error: 'dx.op.annotateHandle2' is not a DXILOpFuncition for DXILOpcode 'AnnotateHandle'. +; CHECK: note: at '%6 = call %dx.types.Handle @dx.op.annotateHandle2(i32 216, %dx.types.Handle %2, %dx.types.ResourceProperties { i32 32782, i32 0 })' in block '#0' of function 'main'. + + %6 = call %dx.types.Handle @dx.op.annotateHandle2(i32 216, %dx.types.Handle %2, %dx.types.ResourceProperties { i32 32782, i32 0 }) ; AnnotateHandle(res,props) resource: SamplerComparisonState + +; CHECK: error: DXILOpCode must be [0..258]. 1999981 specified. +; CHECK: note: at '%7 = call float @dx.op.calculateLOD.f32(i32 1999981, %dx.types.Handle %5, %dx.types.Handle %6, float %3, float %4, float undef, i1 true)' in block '#0' of function 'main'. + + %7 = call float @dx.op.calculateLOD.f32(i32 1999981, %dx.types.Handle %5, %dx.types.Handle %6, float %3, float %4, float undef, i1 true) ; CalculateLOD(handle,sampler,coord0,coord1,coord2,clamped) + + %I = call i32 @dx.op.loadInput.i32(i32 4, i32 0, i32 0, i8 0, i32 undef) ; LoadInput(inputSigId,rowIndex,colIndex,gsVertexAxis) + +; CHECK: error: Opcode of DXIL operation must be an immediate constant. +; CHECK: note: at 'call void @dx.op.storeOutput.f32(i32 %I, i32 0, i32 0, i8 0, float %7)' in block '#0' of function 'main'. + call void @dx.op.storeOutput.f32(i32 %I, i32 0, i32 0, i8 0, float %7) ; StoreOutput(outputSigId,rowIndex,colIndex,value) + ret void +} + + +; CHECK: error: DXIL intrinsic overload must be valid. +; CHECK: note: at '%4 = call float @dx.op.loadInput.f32(i32 4, i32 0, i32 0, i8 1)' in block '#0' of function 'main'. +; CHECK: error: DXIL intrinsic overload must be valid. +; CHECK: note: at '%3 = call float @dx.op.loadInput.f32(i32 4, i32 0, i32 0, i8 0)' in block '#0' of function 'main'. +; CHECK: error: DXIL intrinsic overload must be valid. +; CHECK: note: at '%5 = call %dx.types.Handle @dx.op.annotateHandle(i32 3, %dx.types.Handle %1, %dx.types.ResourceProperties { i32 2, i32 265 })' in block '#0' of function 'main'. + +; Function Attrs: nounwind readnone +declare float @dx.op.loadInput.f32(i32, i32, i32, i8) #0 + +; Function Attrs: nounwind +declare void @dx.op.storeOutput.f32(i32, i32, i32, i8, float) #1 + +; Function Attrs: nounwind readonly +declare float @dx.op.calculateLOD.f32(i32, %dx.types.Handle, %dx.types.Handle, float, float, float, i1) #2 + +; Function Attrs: nounwind readnone +declare %dx.types.Handle @dx.op.annotateHandle(i32, %dx.types.Handle, %dx.types.ResourceProperties) #0 + +declare %dx.types.Handle @dx.op.annotateHandle2(i32, %dx.types.Handle, %dx.types.ResourceProperties) #0 + +; Function Attrs: nounwind readnone +declare %dx.types.Handle @dx.op.createHandleFromBinding(i32, %dx.types.ResBind, i32, i1) #0 + +declare i32 @dx.op.loadInput.i32(i32, i32, i32, i8, i32) #0 + + +attributes #0 = { nounwind readnone } +attributes #1 = { nounwind } +attributes #2 = { nounwind readonly } + +!llvm.ident = !{!0} +!dx.version = !{!1} +!dx.valver = !{!1} +!dx.shaderModel = !{!2} +!dx.resources = !{!3} +!dx.viewIdState = !{!9} +!dx.entryPoints = !{!10} + +!0 = !{!"dxc(private) 1.7.0.4396 (test_time, 849f8b884-dirty)"} +!1 = !{i32 1, i32 7} +!2 = !{!"ps", i32 6, i32 7} +!3 = !{!4, null, null, !7} +!4 = !{!5} +!5 = !{i32 0, %"class.Texture2D"* undef, !"", i32 0, i32 0, i32 1, i32 2, i32 0, !6} +!6 = !{i32 0, i32 9} +!7 = !{!8} +!8 = !{i32 0, %struct.SamplerComparisonState* undef, !"", i32 0, i32 0, i32 1, i32 1, null} +!9 = !{[4 x i32] [i32 2, i32 1, i32 1, i32 1]} +!10 = !{void ()* @main, !"main", !11, !3, null} +!11 = !{!12, !16, null} +!12 = !{!13} +!13 = !{i32 0, !"A", i8 9, i8 0, !14, i8 2, i32 1, i8 2, i32 0, i8 0, !15} +!14 = !{i32 0} +!15 = !{i32 3, i32 3} +!16 = !{!17} +!17 = !{i32 0, !"SV_Target", i8 9, i8 16, !14, i8 0, i32 1, i8 1, i32 0, i8 0, !18} +!18 = !{i32 3, i32 1} diff --git a/utils/hct/hctdb.py b/utils/hct/hctdb.py index 0f6d7cb664..03752ea4a0 100644 --- a/utils/hct/hctdb.py +++ b/utils/hct/hctdb.py @@ -7312,6 +7312,14 @@ def build_valrules(self): "Instr.ImmBiasForSampleB", "bias amount for sample_b must be in the range [%0,%1], but %2 was specified as an immediate.", ) + self.add_valrule( + "Instr.IllegalDXILOpCode", + "DXILOpCode must be [0..%0]. %1 specified." + ) + self.add_valrule( + "Instr.IllegalDXILOpFunction", + "'%0' is not a DXILOpFuncition for DXILOpcode '%1'." + ) # If streams have not been declared, you must use cut instead of cut_stream in GS - is there an equivalent rule here? # Need to clean up all error messages and actually implement. diff --git a/utils/hct/hctdb_instrhelp.py b/utils/hct/hctdb_instrhelp.py index 7d20c25e9b..12a78a13c1 100644 --- a/utils/hct/hctdb_instrhelp.py +++ b/utils/hct/hctdb_instrhelp.py @@ -722,7 +722,7 @@ def print_opfunc_oload_type(self): for opcode in opcodes: line = line + "case OpCode::{name}".format(name=opcode + ":\n") - line = line + " DXASSERT_NOMSG(FT->getNumParams() > " + str(index) + ");\n" + line = line + " if (FT->getNumParams() <= " + str(index) + ") return Ty;\n" line = line + " return FT->getParamType(" + str(index) + ");" print(line) @@ -732,7 +732,7 @@ def print_opfunc_oload_type(self): for opcode in opcodes: line = line + "case OpCode::{name}".format(name=opcode + ":\n") - line = line + " DXASSERT_NOMSG(FT->getNumParams() > " + str(index) + ");\n" + line = line + " if (FT->getNumParams() <= " + str(index) + ") return Ty;\n" line = ( line + " return FT->getParamType(" From 30dd78d26a85d2f0c629b561aa7491649a4258f8 Mon Sep 17 00:00:00 2001 From: Xiang Li Date: Tue, 13 Feb 2024 14:12:53 -0500 Subject: [PATCH 2/6] Format fix. --- lib/HLSL/DxilValidation.cpp | 2 +- utils/hct/hctdb.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/lib/HLSL/DxilValidation.cpp b/lib/HLSL/DxilValidation.cpp index 36187402cd..8ca66e86d5 100644 --- a/lib/HLSL/DxilValidation.cpp +++ b/lib/HLSL/DxilValidation.cpp @@ -3244,7 +3244,7 @@ static void ValidateFunctionBody(Function *F, ValidationContext &ValCtx) { if (dxilOpcode >= DXIL::OpCode::NumOpCodes) { ValCtx.EmitInstrFormatError( &I, ValidationRule::InstrIllegalDXILOpCode, - {std::to_string((unsigned) DXIL::OpCode::NumOpCodes), + {std::to_string((unsigned)DXIL::OpCode::NumOpCodes), std::to_string(opcode)}); continue; } diff --git a/utils/hct/hctdb.py b/utils/hct/hctdb.py index 03752ea4a0..8704a19338 100644 --- a/utils/hct/hctdb.py +++ b/utils/hct/hctdb.py @@ -7313,12 +7313,11 @@ def build_valrules(self): "bias amount for sample_b must be in the range [%0,%1], but %2 was specified as an immediate.", ) self.add_valrule( - "Instr.IllegalDXILOpCode", - "DXILOpCode must be [0..%0]. %1 specified." + "Instr.IllegalDXILOpCode", "DXILOpCode must be [0..%0]. %1 specified." ) self.add_valrule( "Instr.IllegalDXILOpFunction", - "'%0' is not a DXILOpFuncition for DXILOpcode '%1'." + "'%0' is not a DXILOpFuncition for DXILOpcode '%1'.", ) # If streams have not been declared, you must use cut instead of cut_stream in GS - is there an equivalent rule here? From dbdf2368b265485d04198a46be58407763b95e4c Mon Sep 17 00:00:00 2001 From: Xiang Li Date: Thu, 15 Feb 2024 11:13:17 -0500 Subject: [PATCH 3/6] Cleanup code. --- lib/DXIL/DxilOperations.cpp | 56 ++++++++++++++++-------------------- utils/hct/hctdb_instrhelp.py | 4 +-- 2 files changed, 26 insertions(+), 34 deletions(-) diff --git a/lib/DXIL/DxilOperations.cpp b/lib/DXIL/DxilOperations.cpp index 93bbefde9c..8a0a7f0e63 100644 --- a/lib/DXIL/DxilOperations.cpp +++ b/lib/DXIL/DxilOperations.cpp @@ -2705,8 +2705,6 @@ llvm::StringRef OP::ConstructOverloadName(Type *Ty, DXIL::OpCode opCode, } const char *OP::GetOpCodeName(OpCode opCode) { - DXASSERT(0 <= (unsigned)opCode && opCode < OpCode::NumOpCodes, - "otherwise caller passed OOB index"); return m_OpCodeProps[(unsigned)opCode].pOpCodeName; } @@ -2719,25 +2717,21 @@ const char *OP::GetAtomicOpName(DXIL::AtomicBinOpCode OpCode) { } OP::OpCodeClass OP::GetOpCodeClass(OpCode opCode) { - DXASSERT(0 <= (unsigned)opCode && opCode < OpCode::NumOpCodes, - "otherwise caller passed OOB index"); return m_OpCodeProps[(unsigned)opCode].opCodeClass; } const char *OP::GetOpCodeClassName(OpCode opCode) { - DXASSERT(0 <= (unsigned)opCode && opCode < OpCode::NumOpCodes, - "otherwise caller passed OOB index"); return m_OpCodeProps[(unsigned)opCode].pOpCodeClassName; } llvm::Attribute::AttrKind OP::GetMemAccessAttr(OpCode opCode) { - DXASSERT(0 <= (unsigned)opCode && opCode < OpCode::NumOpCodes, - "otherwise caller passed OOB index"); return m_OpCodeProps[(unsigned)opCode].FuncAttr; } bool OP::IsOverloadLegal(OpCode opCode, Type *pType) { - if (opCode >= OpCode::NumOpCodes) + if (!pType) + return false; + if (opCode == OpCode::NumOpCodes) return false; unsigned TypeSlot = GetTypeSlot(pType); return TypeSlot != UINT_MAX && @@ -2817,7 +2811,10 @@ OP::OpCode OP::getOpCode(const llvm::Instruction *I) { auto *OpConst = llvm::dyn_cast(I->getOperand(0)); if (!OpConst) return OpCode::NumOpCodes; - return (OP::OpCode)OpConst->getZExtValue(); + uint64_t OpCodeVal = OpConst->getZExtValue(); + if (OpCodeVal >= static_cast(OP::OpCode::NumOpCodes)) + return OP::OpCode::NumOpCodes; + return static_cast(OpCodeVal); } OP::OpCode OP::GetDxilOpFuncCallInst(const llvm::Instruction *I) { @@ -3404,15 +3401,15 @@ void OP::FixOverloadNames() { CallInst *CI = cast(*F.user_begin()); DXIL::OpCode opCode = OP::GetDxilOpFuncCallInst(CI); llvm::Type *Ty = OP::GetOverloadType(opCode, &F); - if (OP::IsOverloadLegal(opCode, Ty)) { - if (isa(Ty) || isa(Ty)) { - std::string funcName; - if (OP::ConstructOverloadName(Ty, opCode, funcName) - .compare(F.getName()) != 0) { - F.setName(funcName); - } - } - } + if (!OP::IsOverloadLegal(opCode, Ty)) + continue; + if (!isa(Ty) && !isa(Ty)) + continue; + + std::string funcName; + if (OP::ConstructOverloadName(Ty, opCode, funcName) + .compare(F.getName()) != 0) + F.setName(funcName); } } } @@ -3423,7 +3420,7 @@ void OP::UpdateCache(OpCodeClass opClass, Type *Ty, llvm::Function *F) { } Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) { - if (opCode >= OpCode::NumOpCodes) + if (opCode == OpCode::NumOpCodes) return nullptr; if (!IsOverloadLegal(opCode, pOverloadType)) return nullptr; @@ -5372,8 +5369,6 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) { if (Function *existF = m_pModule->getFunction(funcName)) { if (existF->getFunctionType() != pFT) return nullptr; - DXASSERT(existF->getFunctionType() == pFT, - "existing function must have the expected function type"); F = existF; UpdateCache(opClass, pOverloadType, F); return F; @@ -5392,9 +5387,6 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) { const SmallMapVector & OP::GetOpFuncList(OpCode opCode) const { - DXASSERT(0 <= (unsigned)opCode && opCode < OpCode::NumOpCodes, - "otherwise caller passed OOB OpCode"); - assert(0 <= (unsigned)opCode && opCode < OpCode::NumOpCodes); return m_OpCodeClassCache[(unsigned)m_OpCodeProps[(unsigned)opCode] .opCodeClass] .pOverloads; @@ -5493,7 +5485,7 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) { case OpCode::Pack4x8: case OpCode::WaveMatrix_Fill: if (FT->getNumParams() <= 2) - return Ty; + return nullptr; return FT->getParamType(2); case OpCode::MinPrecXRegStore: case OpCode::StoreOutput: @@ -5504,7 +5496,7 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) { case OpCode::StorePrimitiveOutput: case OpCode::DispatchMesh: if (FT->getNumParams() <= 4) - return Ty; + return nullptr; return FT->getParamType(4); case OpCode::IsNaN: case OpCode::IsInf: @@ -5523,26 +5515,26 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) { case OpCode::CreateHandleForLib: case OpCode::WaveMatch: if (FT->getNumParams() <= 1) - return Ty; + return nullptr; return FT->getParamType(1); case OpCode::TextureStore: case OpCode::TextureStoreSample: if (FT->getNumParams() <= 5) - return Ty; + return nullptr; return FT->getParamType(5); case OpCode::TraceRay: if (FT->getNumParams() <= 15) - return Ty; + return nullptr; return FT->getParamType(15); case OpCode::ReportHit: case OpCode::WaveMatrix_ScalarOp: if (FT->getNumParams() <= 3) - return Ty; + return nullptr; return FT->getParamType(3); case OpCode::WaveMatrix_LoadGroupShared: case OpCode::WaveMatrix_StoreGroupShared: if (FT->getNumParams() <= 2) - return Ty; + return nullptr; return FT->getParamType(2)->getPointerElementType(); case OpCode::CreateHandle: case OpCode::BufferUpdateCounter: diff --git a/utils/hct/hctdb_instrhelp.py b/utils/hct/hctdb_instrhelp.py index 12a78a13c1..b87569401f 100644 --- a/utils/hct/hctdb_instrhelp.py +++ b/utils/hct/hctdb_instrhelp.py @@ -722,7 +722,7 @@ def print_opfunc_oload_type(self): for opcode in opcodes: line = line + "case OpCode::{name}".format(name=opcode + ":\n") - line = line + " if (FT->getNumParams() <= " + str(index) + ") return Ty;\n" + line = line + " if (FT->getNumParams() <= " + str(index) + ") return nullptr;\n" line = line + " return FT->getParamType(" + str(index) + ");" print(line) @@ -732,7 +732,7 @@ def print_opfunc_oload_type(self): for opcode in opcodes: line = line + "case OpCode::{name}".format(name=opcode + ":\n") - line = line + " if (FT->getNumParams() <= " + str(index) + ") return Ty;\n" + line = line + " if (FT->getNumParams() <= " + str(index) + ") return nullptr;\n" line = ( line + " return FT->getParamType(" From d5f0a1844ca6b9d664394797014361ac254bb01e Mon Sep 17 00:00:00 2001 From: Xiang Li Date: Thu, 15 Feb 2024 13:34:25 -0500 Subject: [PATCH 4/6] Code cleanup and format. --- lib/DXIL/DxilShaderFlags.cpp | 2 +- lib/HLSL/DxilValidation.cpp | 5 ++--- utils/hct/hctdb_instrhelp.py | 14 ++++++++++++-- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/lib/DXIL/DxilShaderFlags.cpp b/lib/DXIL/DxilShaderFlags.cpp index 46c5fa75cb..efa57c2564 100644 --- a/lib/DXIL/DxilShaderFlags.cpp +++ b/lib/DXIL/DxilShaderFlags.cpp @@ -573,7 +573,7 @@ ShaderFlags ShaderFlags::CollectShaderFlags(const Function *F, if (!OP::IsDxilOpFunc(CI->getCalledFunction())) continue; DXIL::OpCode dxilOp = hlsl::OP::getOpCode(CI); - if (dxilOp >= DXIL::OpCode::NumOpCodes) + if (dxilOp == DXIL::OpCode::NumOpCodes) continue; if (hlsl::OP::IsDxilOpWave(dxilOp)) hasWaveOps = true; diff --git a/lib/HLSL/DxilValidation.cpp b/lib/HLSL/DxilValidation.cpp index 8ca66e86d5..38aba47151 100644 --- a/lib/HLSL/DxilValidation.cpp +++ b/lib/HLSL/DxilValidation.cpp @@ -3239,15 +3239,14 @@ static void ValidateFunctionBody(Function *F, ValidationContext &ValCtx) { } unsigned opcode = OpcodeConst->getLimitedValue(); - DXIL::OpCode dxilOpcode = (DXIL::OpCode)opcode; - - if (dxilOpcode >= DXIL::OpCode::NumOpCodes) { + if (opcode >= static_cast(DXIL::OpCode::NumOpCodes)) { ValCtx.EmitInstrFormatError( &I, ValidationRule::InstrIllegalDXILOpCode, {std::to_string((unsigned)DXIL::OpCode::NumOpCodes), std::to_string(opcode)}); continue; } + DXIL::OpCode dxilOpcode = (DXIL::OpCode)opcode; bool IllegalOpFunc = true; for (auto &it : hlslOP->GetOpFuncList(dxilOpcode)) { diff --git a/utils/hct/hctdb_instrhelp.py b/utils/hct/hctdb_instrhelp.py index b87569401f..888d70f305 100644 --- a/utils/hct/hctdb_instrhelp.py +++ b/utils/hct/hctdb_instrhelp.py @@ -722,7 +722,12 @@ def print_opfunc_oload_type(self): for opcode in opcodes: line = line + "case OpCode::{name}".format(name=opcode + ":\n") - line = line + " if (FT->getNumParams() <= " + str(index) + ") return nullptr;\n" + line = ( + line + + " if (FT->getNumParams() <= " + + str(index) + + ") return nullptr;\n" + ) line = line + " return FT->getParamType(" + str(index) + ");" print(line) @@ -732,7 +737,12 @@ def print_opfunc_oload_type(self): for opcode in opcodes: line = line + "case OpCode::{name}".format(name=opcode + ":\n") - line = line + " if (FT->getNumParams() <= " + str(index) + ") return nullptr;\n" + line = ( + line + + " if (FT->getNumParams() <= " + + str(index) + + ") return nullptr;\n" + ) line = ( line + " return FT->getParamType(" From 1cd8042ee5511abe5764009a5b88fbe7e71b0e7c Mon Sep 17 00:00:00 2001 From: Xiang Li Date: Fri, 16 Feb 2024 17:01:54 -0500 Subject: [PATCH 5/6] Limit to sm6.8 --- tools/clang/test/LitDXILValidation/illegalDXILOp.ll | 1 + 1 file changed, 1 insertion(+) diff --git a/tools/clang/test/LitDXILValidation/illegalDXILOp.ll b/tools/clang/test/LitDXILValidation/illegalDXILOp.ll index e55eb2b1f1..e4263cfbe0 100644 --- a/tools/clang/test/LitDXILValidation/illegalDXILOp.ll +++ b/tools/clang/test/LitDXILValidation/illegalDXILOp.ll @@ -1,3 +1,4 @@ +; REQUIRES: dxil-1-8 ; RUN: not %dxv %s 2>&1 | FileCheck %s target datalayout = "e-m:e-p:32:32-i1:32-i8:32-i16:32-i32:32-i64:64-f16:32-f32:32-f64:64-n8:16:32:64" From ad56e996b40d005c981619576d3fd7d1b0666578 Mon Sep 17 00:00:00 2001 From: Xiang Li Date: Tue, 20 Feb 2024 12:24:28 -0500 Subject: [PATCH 6/6] Add test for use dxil opcode not support in current shader model. --- .../test/LitDXILValidation/illegalDXILOp.ll | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/tools/clang/test/LitDXILValidation/illegalDXILOp.ll b/tools/clang/test/LitDXILValidation/illegalDXILOp.ll index e4263cfbe0..bb1309fddf 100644 --- a/tools/clang/test/LitDXILValidation/illegalDXILOp.ll +++ b/tools/clang/test/LitDXILValidation/illegalDXILOp.ll @@ -7,6 +7,7 @@ target triple = "dxil-ms-dx" %dx.types.Handle = type { i8* } %dx.types.ResBind = type { i32, i32, i32, i8 } %dx.types.ResourceProperties = type { i32, i32 } +%dx.types.ResRet.f32 = type { float, float, float, float, i32 } %"class.Texture2D" = type { float, %"class.Texture2D::mips_type" } %"class.Texture2D::mips_type" = type { i32 } %struct.SamplerComparisonState = type { i32 } @@ -48,16 +49,21 @@ define void @main() { ; CHECK: error: Opcode of DXIL operation must be an immediate constant. ; CHECK: note: at 'call void @dx.op.storeOutput.f32(i32 %I, i32 0, i32 0, i8 0, float %7)' in block '#0' of function 'main'. call void @dx.op.storeOutput.f32(i32 %I, i32 0, i32 0, i8 0, float %7) ; StoreOutput(outputSigId,rowIndex,colIndex,value) + + +; CHECK-DAG: error: Opcode SampleCmpBias not valid in shader model ps_6_7. + %CmpBias = call %dx.types.ResRet.f32 @dx.op.sampleCmpBias.f32(i32 255, %dx.types.Handle %5, %dx.types.Handle %6, float %3, float %4, float undef, float undef, i32 0, i32 0, i32 undef, float 5.000000e-01, float 5.000000e-01, float undef) ; SampleCmpBias(srv,sampler,coord0,coord1,coord2,coord3,offset0,offset1,offset2,compareValue,bias,clamp) + ret void } -; CHECK: error: DXIL intrinsic overload must be valid. -; CHECK: note: at '%4 = call float @dx.op.loadInput.f32(i32 4, i32 0, i32 0, i8 1)' in block '#0' of function 'main'. -; CHECK: error: DXIL intrinsic overload must be valid. -; CHECK: note: at '%3 = call float @dx.op.loadInput.f32(i32 4, i32 0, i32 0, i8 0)' in block '#0' of function 'main'. -; CHECK: error: DXIL intrinsic overload must be valid. -; CHECK: note: at '%5 = call %dx.types.Handle @dx.op.annotateHandle(i32 3, %dx.types.Handle %1, %dx.types.ResourceProperties { i32 2, i32 265 })' in block '#0' of function 'main'. +; CHECK-DAG: error: DXIL intrinsic overload must be valid. +; CHECK-DAG: note: at '%4 = call float @dx.op.loadInput.f32(i32 4, i32 0, i32 0, i8 1)' in block '#0' of function 'main'. +; CHECK-DAG: error: DXIL intrinsic overload must be valid. +; CHECK-DAG: note: at '%3 = call float @dx.op.loadInput.f32(i32 4, i32 0, i32 0, i8 0)' in block '#0' of function 'main'. +; CHECK-DAG: error: DXIL intrinsic overload must be valid. +; CHECK-DAG: note: at '%5 = call %dx.types.Handle @dx.op.annotateHandle(i32 3, %dx.types.Handle %1, %dx.types.ResourceProperties { i32 2, i32 265 })' in block '#0' of function 'main'. ; Function Attrs: nounwind readnone declare float @dx.op.loadInput.f32(i32, i32, i32, i8) #0 @@ -78,6 +84,8 @@ declare %dx.types.Handle @dx.op.createHandleFromBinding(i32, %dx.types.ResBind, declare i32 @dx.op.loadInput.i32(i32, i32, i32, i8, i32) #0 +; Function Attrs: nounwind readonly +declare %dx.types.ResRet.f32 @dx.op.sampleCmpBias.f32(i32, %dx.types.Handle, %dx.types.Handle, float, float, float, float, i32, i32, i32, float, float, float) #2 attributes #0 = { nounwind readnone } attributes #1 = { nounwind }