Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix crash in DXIL.dll caused by illegal DXIL intrinsic. #6302

Merged
merged 6 commits into from
Feb 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/DXIL.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3079,6 +3079,8 @@ INSTR.EVALINTERPOLATIONMODE Interpolation mode on %0 used with eva
INSTR.EXTRACTVALUE ExtractValue should only be used on dxil struct types and cmpxchg.
INSTR.FAILTORESLOVETGSMPOINTER TGSM pointers must originate from an unambiguous TGSM global variable.
INSTR.HANDLENOTFROMCREATEHANDLE Resource handle should returned by createHandle.
INSTR.ILLEGALDXILOPCODE DXILOpCode must be [0..%0]. %1 specified.
INSTR.ILLEGALDXILOPFUNCTION '%0' is not a DXILOpFuncition for DXILOpcode '%1'.
INSTR.IMMBIASFORSAMPLEB bias amount for sample_b must be in the range [%0,%1], but %2 was specified as an immediate.
INSTR.INBOUNDSACCESS Access to out-of-bounds memory is disallowed.
INSTR.MINPRECISIONNOTPRECISE Instructions marked precise may not refer to minprecision values.
Expand Down
4 changes: 1 addition & 3 deletions lib/DXIL/DxilCounters.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -336,9 +336,7 @@ void CountInstructions(llvm::Module &M, DxilCounters &counters) {
}
} else if (CallInst *CI = dyn_cast<CallInst>(I)) {
if (hlsl::OP::IsDxilOpFuncCallInst(CI)) {
unsigned opcode =
(unsigned)llvm::cast<llvm::ConstantInt>(I->getOperand(0))
->getZExtValue();
unsigned opcode = static_cast<unsigned>(hlsl::OP::getOpCode(CI));
CountDxilOp(opcode, counters);
}
} else if (isa<LoadInst>(I) || isa<StoreInst>(I)) {
Expand Down
82 changes: 42 additions & 40 deletions lib/DXIL/DxilOperations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2705,8 +2705,6 @@ llvm::StringRef OP::ConstructOverloadName(Type *Ty, DXIL::OpCode opCode,
}

const char *OP::GetOpCodeName(OpCode opCode) {
DXASSERT(0 <= (unsigned)opCode && opCode < OpCode::NumOpCodes,
"otherwise caller passed OOB index");
return m_OpCodeProps[(unsigned)opCode].pOpCodeName;
}

Expand All @@ -2719,26 +2717,22 @@ const char *OP::GetAtomicOpName(DXIL::AtomicBinOpCode OpCode) {
}

OP::OpCodeClass OP::GetOpCodeClass(OpCode opCode) {
DXASSERT(0 <= (unsigned)opCode && opCode < OpCode::NumOpCodes,
"otherwise caller passed OOB index");
return m_OpCodeProps[(unsigned)opCode].opCodeClass;
}

const char *OP::GetOpCodeClassName(OpCode opCode) {
DXASSERT(0 <= (unsigned)opCode && opCode < OpCode::NumOpCodes,
"otherwise caller passed OOB index");
return m_OpCodeProps[(unsigned)opCode].pOpCodeClassName;
}

llvm::Attribute::AttrKind OP::GetMemAccessAttr(OpCode opCode) {
DXASSERT(0 <= (unsigned)opCode && opCode < OpCode::NumOpCodes,
"otherwise caller passed OOB index");
pow2clk marked this conversation as resolved.
Show resolved Hide resolved
return m_OpCodeProps[(unsigned)opCode].FuncAttr;
}

bool OP::IsOverloadLegal(OpCode opCode, Type *pType) {
DXASSERT(0 <= (unsigned)opCode && opCode < OpCode::NumOpCodes,
"otherwise caller passed OOB index");
if (!pType)
return false;
if (opCode == OpCode::NumOpCodes)
return false;
unsigned TypeSlot = GetTypeSlot(pType);
return TypeSlot != UINT_MAX &&
m_OpCodeProps[(unsigned)opCode].bAllowOverload[TypeSlot];
Expand Down Expand Up @@ -2814,8 +2808,13 @@ bool OP::IsDxilOpFuncCallInst(const llvm::Instruction *I, OpCode opcode) {
}

OP::OpCode OP::getOpCode(const llvm::Instruction *I) {
return (OP::OpCode)llvm::cast<llvm::ConstantInt>(I->getOperand(0))
->getZExtValue();
auto *OpConst = llvm::dyn_cast<llvm::ConstantInt>(I->getOperand(0));
if (!OpConst)
return OpCode::NumOpCodes;
uint64_t OpCodeVal = OpConst->getZExtValue();
if (OpCodeVal >= static_cast<uint64_t>(OP::OpCode::NumOpCodes))
return OP::OpCode::NumOpCodes;
return static_cast<OP::OpCode>(OpCodeVal);
}

OP::OpCode OP::GetDxilOpFuncCallInst(const llvm::Instruction *I) {
Expand Down Expand Up @@ -3383,9 +3382,7 @@ void OP::RefreshCache() {
CallInst *CI = cast<CallInst>(*F.user_begin());
OpCode OpCode = OP::GetDxilOpFuncCallInst(CI);
Type *pOverloadType = OP::GetOverloadType(OpCode, &F);
Function *OpFunc = GetOpFunc(OpCode, pOverloadType);
(void)(OpFunc);
DXASSERT_NOMSG(OpFunc == &F);
GetOpFunc(OpCode, pOverloadType);
}
}
}
Expand All @@ -3404,13 +3401,15 @@ void OP::FixOverloadNames() {
CallInst *CI = cast<CallInst>(*F.user_begin());
DXIL::OpCode opCode = OP::GetDxilOpFuncCallInst(CI);
llvm::Type *Ty = OP::GetOverloadType(opCode, &F);
if (isa<StructType>(Ty) || isa<PointerType>(Ty)) {
std::string funcName;
if (OP::ConstructOverloadName(Ty, opCode, funcName)
.compare(F.getName()) != 0) {
F.setName(funcName);
}
}
if (!OP::IsOverloadLegal(opCode, Ty))
continue;
if (!isa<StructType>(Ty) && !isa<PointerType>(Ty))
continue;

std::string funcName;
if (OP::ConstructOverloadName(Ty, opCode, funcName)
.compare(F.getName()) != 0)
F.setName(funcName);
}
}
}
Expand All @@ -3421,12 +3420,11 @@ void OP::UpdateCache(OpCodeClass opClass, Type *Ty, llvm::Function *F) {
}

Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
DXASSERT(0 <= (unsigned)opCode && opCode < OpCode::NumOpCodes,
"otherwise caller passed OOB OpCode");
assert(0 <= (unsigned)opCode && opCode < OpCode::NumOpCodes);
llvm-beanz marked this conversation as resolved.
Show resolved Hide resolved
DXASSERT(IsOverloadLegal(opCode, pOverloadType),
"otherwise the caller requested illegal operation overload (eg HLSL "
"function with unsupported types for mapped intrinsic function)");
if (opCode == OpCode::NumOpCodes)
return nullptr;
if (!IsOverloadLegal(opCode, pOverloadType))
return nullptr;

OpCodeClass opClass = m_OpCodeProps[(unsigned)opCode].opCodeClass;
Function *&F =
m_OpCodeClassCache[(unsigned)opClass].pOverloads[pOverloadType];
Expand Down Expand Up @@ -5369,8 +5367,8 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
// and return values to ensure that ResRetType is constructed in the
// RefreshCache case.
if (Function *existF = m_pModule->getFunction(funcName)) {
DXASSERT(existF->getFunctionType() == pFT,
"existing function must have the expected function type");
if (existF->getFunctionType() != pFT)
return nullptr;
F = existF;
UpdateCache(opClass, pOverloadType, F);
return F;
Expand All @@ -5389,9 +5387,6 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {

const SmallMapVector<llvm::Type *, llvm::Function *, 8> &
OP::GetOpFuncList(OpCode opCode) const {
DXASSERT(0 <= (unsigned)opCode && opCode < OpCode::NumOpCodes,
"otherwise caller passed OOB OpCode");
assert(0 <= (unsigned)opCode && opCode < OpCode::NumOpCodes);
pow2clk marked this conversation as resolved.
Show resolved Hide resolved
return m_OpCodeClassCache[(unsigned)m_OpCodeProps[(unsigned)opCode]
.opCodeClass]
.pOverloads;
Expand Down Expand Up @@ -5489,7 +5484,8 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
case OpCode::CallShader:
case OpCode::Pack4x8:
case OpCode::WaveMatrix_Fill:
DXASSERT_NOMSG(FT->getNumParams() > 2);
if (FT->getNumParams() <= 2)
return nullptr;
return FT->getParamType(2);
case OpCode::MinPrecXRegStore:
case OpCode::StoreOutput:
Expand All @@ -5499,7 +5495,8 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
case OpCode::StoreVertexOutput:
case OpCode::StorePrimitiveOutput:
case OpCode::DispatchMesh:
DXASSERT_NOMSG(FT->getNumParams() > 4);
if (FT->getNumParams() <= 4)
return nullptr;
return FT->getParamType(4);
case OpCode::IsNaN:
case OpCode::IsInf:
Expand All @@ -5517,22 +5514,27 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
case OpCode::WaveActiveAllEqual:
case OpCode::CreateHandleForLib:
case OpCode::WaveMatch:
DXASSERT_NOMSG(FT->getNumParams() > 1);
if (FT->getNumParams() <= 1)
return nullptr;
return FT->getParamType(1);
case OpCode::TextureStore:
case OpCode::TextureStoreSample:
DXASSERT_NOMSG(FT->getNumParams() > 5);
if (FT->getNumParams() <= 5)
return nullptr;
return FT->getParamType(5);
case OpCode::TraceRay:
DXASSERT_NOMSG(FT->getNumParams() > 15);
if (FT->getNumParams() <= 15)
return nullptr;
return FT->getParamType(15);
case OpCode::ReportHit:
case OpCode::WaveMatrix_ScalarOp:
DXASSERT_NOMSG(FT->getNumParams() > 3);
if (FT->getNumParams() <= 3)
return nullptr;
return FT->getParamType(3);
case OpCode::WaveMatrix_LoadGroupShared:
case OpCode::WaveMatrix_StoreGroupShared:
DXASSERT_NOMSG(FT->getNumParams() > 2);
if (FT->getNumParams() <= 2)
return nullptr;
return FT->getParamType(2)->getPointerElementType();
case OpCode::CreateHandle:
case OpCode::BufferUpdateCounter:
Expand Down
10 changes: 3 additions & 7 deletions lib/DXIL/DxilShaderFlags.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -572,13 +572,9 @@ ShaderFlags ShaderFlags::CollectShaderFlags(const Function *F,
if (const CallInst *CI = dyn_cast<CallInst>(&I)) {
if (!OP::IsDxilOpFunc(CI->getCalledFunction()))
continue;
Value *opcodeArg = CI->getArgOperand(DXIL::OperandIndex::kOpcodeIdx);
ConstantInt *opcodeConst = dyn_cast<ConstantInt>(opcodeArg);
DXASSERT(opcodeConst, "DXIL opcode arg must be immediate");
unsigned opcode = opcodeConst->getLimitedValue();
DXASSERT(opcode < static_cast<unsigned>(DXIL::OpCode::NumOpCodes),
"invalid DXIL opcode");
DXIL::OpCode dxilOp = static_cast<DXIL::OpCode>(opcode);
DXIL::OpCode dxilOp = hlsl::OP::getOpCode(CI);
if (dxilOp == DXIL::OpCode::NumOpCodes)
continue;
if (hlsl::OP::IsDxilOpWave(dxilOp))
hasWaveOps = true;
switch (dxilOp) {
Expand Down
24 changes: 24 additions & 0 deletions lib/HLSL/DxilValidation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3188,6 +3188,8 @@ static void ValidateFunctionBody(Function *F, ValidationContext &ValCtx) {
CallInst *setMeshOutputCounts = nullptr;
CallInst *getMeshPayload = nullptr;
CallInst *dispatchMesh = nullptr;
hlsl::OP *hlslOP = ValCtx.DxilMod.GetOP();

for (auto b = F->begin(), bend = F->end(); b != bend; ++b) {
for (auto i = b->begin(), iend = b->end(); i != iend; ++i) {
llvm::Instruction &I = *i;
Expand Down Expand Up @@ -3237,8 +3239,30 @@ static void ValidateFunctionBody(Function *F, ValidationContext &ValCtx) {
}

unsigned opcode = OpcodeConst->getLimitedValue();
if (opcode >= static_cast<unsigned>(DXIL::OpCode::NumOpCodes)) {
pow2clk marked this conversation as resolved.
Show resolved Hide resolved
ValCtx.EmitInstrFormatError(
&I, ValidationRule::InstrIllegalDXILOpCode,
{std::to_string((unsigned)DXIL::OpCode::NumOpCodes),
std::to_string(opcode)});
continue;
}
DXIL::OpCode dxilOpcode = (DXIL::OpCode)opcode;

bool IllegalOpFunc = true;
for (auto &it : hlslOP->GetOpFuncList(dxilOpcode)) {
if (it.second == FCalled) {
IllegalOpFunc = false;
break;
}
}

if (IllegalOpFunc) {
ValCtx.EmitInstrFormatError(
&I, ValidationRule::InstrIllegalDXILOpFunction,
{FCalled->getName(), OP::GetOpCodeName(dxilOpcode)});
continue;
}

if (OP::IsDxilOpGradient(dxilOpcode)) {
gradientOps.push_back(CI);
}
Expand Down
120 changes: 120 additions & 0 deletions tools/clang/test/LitDXILValidation/illegalDXILOp.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
; REQUIRES: dxil-1-8
; RUN: not %dxv %s 2>&1 | FileCheck %s

target datalayout = "e-m:e-p:32:32-i1:32-i8:32-i16:32-i32:32-i64:64-f16:32-f32:32-f64:64-n8:16:32:64"
target triple = "dxil-ms-dx"

%dx.types.Handle = type { i8* }
%dx.types.ResBind = type { i32, i32, i32, i8 }
%dx.types.ResourceProperties = type { i32, i32 }
%dx.types.ResRet.f32 = type { float, float, float, float, i32 }
%"class.Texture2D<float>" = type { float, %"class.Texture2D<float>::mips_type" }
%"class.Texture2D<float>::mips_type" = type { i32 }
%struct.SamplerComparisonState = type { i32 }


define void @main() {
%1 = call %dx.types.Handle @dx.op.createHandleFromBinding(i32 217, %dx.types.ResBind zeroinitializer, i32 0, i1 false) ; CreateHandleFromBinding(bind,index,nonUniformIndex)
%2 = call %dx.types.Handle @dx.op.createHandleFromBinding(i32 217, %dx.types.ResBind { i32 0, i32 0, i32 0, i8 3 }, i32 0, i1 false) ; CreateHandleFromBinding(bind,index,nonUniformIndex)


; CHECK: error: 'dx.op.loadInput.f32' is not a DXILOpFuncition for DXILOpcode 'LoadInput'.
; CHECK: note: at '%3 = call float @dx.op.loadInput.f32(i32 4, i32 0, i32 0, i8 0)' in block '#0' of function 'main'.

%3 = call float @dx.op.loadInput.f32(i32 4, i32 0, i32 0, i8 0) ; LoadInput(inputSigId,rowIndex,colIndex,gsVertexAxis)

; CHECK: error: 'dx.op.loadInput.f32' is not a DXILOpFuncition for DXILOpcode 'LoadInput'.
; CHECK: note: at '%4 = call float @dx.op.loadInput.f32(i32 4, i32 0, i32 0, i8 1)' in block '#0' of function 'main'.

%4 = call float @dx.op.loadInput.f32(i32 4, i32 0, i32 0, i8 1) ; LoadInput(inputSigId,rowIndex,colIndex,gsVertexAxis)


; CHECK: error: 'dx.op.annotateHandle' is not a DXILOpFuncition for DXILOpcode 'MinPrecXRegStore'.
; CHECK: note: at '%5 = call %dx.types.Handle @dx.op.annotateHandle(i32 3, %dx.types.Handle %1, %dx.types.ResourceProperties { i32 2, i32 265 })' in block '#0' of function 'main'.

%5 = call %dx.types.Handle @dx.op.annotateHandle(i32 3, %dx.types.Handle %1, %dx.types.ResourceProperties { i32 2, i32 265 }) ; AnnotateHandle(res,props) resource: Texture2D<F32>

; CHECK: error: 'dx.op.annotateHandle2' is not a DXILOpFuncition for DXILOpcode 'AnnotateHandle'.
; CHECK: note: at '%6 = call %dx.types.Handle @dx.op.annotateHandle2(i32 216, %dx.types.Handle %2, %dx.types.ResourceProperties { i32 32782, i32 0 })' in block '#0' of function 'main'.

%6 = call %dx.types.Handle @dx.op.annotateHandle2(i32 216, %dx.types.Handle %2, %dx.types.ResourceProperties { i32 32782, i32 0 }) ; AnnotateHandle(res,props) resource: SamplerComparisonState

; CHECK: error: DXILOpCode must be [0..258]. 1999981 specified.
; CHECK: note: at '%7 = call float @dx.op.calculateLOD.f32(i32 1999981, %dx.types.Handle %5, %dx.types.Handle %6, float %3, float %4, float undef, i1 true)' in block '#0' of function 'main'.

%7 = call float @dx.op.calculateLOD.f32(i32 1999981, %dx.types.Handle %5, %dx.types.Handle %6, float %3, float %4, float undef, i1 true) ; CalculateLOD(handle,sampler,coord0,coord1,coord2,clamped)

%I = call i32 @dx.op.loadInput.i32(i32 4, i32 0, i32 0, i8 0, i32 undef) ; LoadInput(inputSigId,rowIndex,colIndex,gsVertexAxis)

; CHECK: error: Opcode of DXIL operation must be an immediate constant.
; CHECK: note: at 'call void @dx.op.storeOutput.f32(i32 %I, i32 0, i32 0, i8 0, float %7)' in block '#0' of function 'main'.
call void @dx.op.storeOutput.f32(i32 %I, i32 0, i32 0, i8 0, float %7) ; StoreOutput(outputSigId,rowIndex,colIndex,value)


; CHECK-DAG: error: Opcode SampleCmpBias not valid in shader model ps_6_7.
%CmpBias = call %dx.types.ResRet.f32 @dx.op.sampleCmpBias.f32(i32 255, %dx.types.Handle %5, %dx.types.Handle %6, float %3, float %4, float undef, float undef, i32 0, i32 0, i32 undef, float 5.000000e-01, float 5.000000e-01, float undef) ; SampleCmpBias(srv,sampler,coord0,coord1,coord2,coord3,offset0,offset1,offset2,compareValue,bias,clamp)

ret void
}


; CHECK-DAG: error: DXIL intrinsic overload must be valid.
; CHECK-DAG: note: at '%4 = call float @dx.op.loadInput.f32(i32 4, i32 0, i32 0, i8 1)' in block '#0' of function 'main'.
; CHECK-DAG: error: DXIL intrinsic overload must be valid.
; CHECK-DAG: note: at '%3 = call float @dx.op.loadInput.f32(i32 4, i32 0, i32 0, i8 0)' in block '#0' of function 'main'.
; CHECK-DAG: error: DXIL intrinsic overload must be valid.
; CHECK-DAG: note: at '%5 = call %dx.types.Handle @dx.op.annotateHandle(i32 3, %dx.types.Handle %1, %dx.types.ResourceProperties { i32 2, i32 265 })' in block '#0' of function 'main'.

; Function Attrs: nounwind readnone
declare float @dx.op.loadInput.f32(i32, i32, i32, i8) #0

; Function Attrs: nounwind
declare void @dx.op.storeOutput.f32(i32, i32, i32, i8, float) #1

; Function Attrs: nounwind readonly
declare float @dx.op.calculateLOD.f32(i32, %dx.types.Handle, %dx.types.Handle, float, float, float, i1) #2

; Function Attrs: nounwind readnone
declare %dx.types.Handle @dx.op.annotateHandle(i32, %dx.types.Handle, %dx.types.ResourceProperties) #0

declare %dx.types.Handle @dx.op.annotateHandle2(i32, %dx.types.Handle, %dx.types.ResourceProperties) #0

; Function Attrs: nounwind readnone
declare %dx.types.Handle @dx.op.createHandleFromBinding(i32, %dx.types.ResBind, i32, i1) #0

declare i32 @dx.op.loadInput.i32(i32, i32, i32, i8, i32) #0

; Function Attrs: nounwind readonly
declare %dx.types.ResRet.f32 @dx.op.sampleCmpBias.f32(i32, %dx.types.Handle, %dx.types.Handle, float, float, float, float, i32, i32, i32, float, float, float) #2

attributes #0 = { nounwind readnone }
attributes #1 = { nounwind }
attributes #2 = { nounwind readonly }

!llvm.ident = !{!0}
!dx.version = !{!1}
!dx.valver = !{!1}
!dx.shaderModel = !{!2}
!dx.resources = !{!3}
!dx.viewIdState = !{!9}
!dx.entryPoints = !{!10}

!0 = !{!"dxc(private) 1.7.0.4396 (test_time, 849f8b884-dirty)"}
!1 = !{i32 1, i32 7}
!2 = !{!"ps", i32 6, i32 7}
!3 = !{!4, null, null, !7}
!4 = !{!5}
!5 = !{i32 0, %"class.Texture2D<float>"* undef, !"", i32 0, i32 0, i32 1, i32 2, i32 0, !6}
!6 = !{i32 0, i32 9}
!7 = !{!8}
!8 = !{i32 0, %struct.SamplerComparisonState* undef, !"", i32 0, i32 0, i32 1, i32 1, null}
!9 = !{[4 x i32] [i32 2, i32 1, i32 1, i32 1]}
!10 = !{void ()* @main, !"main", !11, !3, null}
!11 = !{!12, !16, null}
!12 = !{!13}
!13 = !{i32 0, !"A", i8 9, i8 0, !14, i8 2, i32 1, i8 2, i32 0, i8 0, !15}
!14 = !{i32 0}
!15 = !{i32 3, i32 3}
!16 = !{!17}
!17 = !{i32 0, !"SV_Target", i8 9, i8 16, !14, i8 0, i32 1, i8 1, i32 0, i8 0, !18}
!18 = !{i32 3, i32 1}
7 changes: 7 additions & 0 deletions utils/hct/hctdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -7312,6 +7312,13 @@ def build_valrules(self):
"Instr.ImmBiasForSampleB",
"bias amount for sample_b must be in the range [%0,%1], but %2 was specified as an immediate.",
)
self.add_valrule(
"Instr.IllegalDXILOpCode", "DXILOpCode must be [0..%0]. %1 specified."
)
self.add_valrule(
"Instr.IllegalDXILOpFunction",
"'%0' is not a DXILOpFuncition for DXILOpcode '%1'.",
)
# If streams have not been declared, you must use cut instead of cut_stream in GS - is there an equivalent rule here?

# Need to clean up all error messages and actually implement.
Expand Down
Loading
Loading