Skip to content

Commit

Permalink
Fix crash in DXIL.dll caused by illegal DXIL intrinsic. (#6302)
Browse files Browse the repository at this point in the history
Replace assert on illegal DXIL op with return illegal value. Check the
illegal cases in validation.

Fixes #6168
  • Loading branch information
python3kgae committed Feb 20, 2024
1 parent 2314d06 commit 5dee81c
Show file tree
Hide file tree
Showing 8 changed files with 211 additions and 52 deletions.
2 changes: 2 additions & 0 deletions docs/DXIL.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3080,6 +3080,8 @@ INSTR.EVALINTERPOLATIONMODE Interpolation mode on %0 used with eva
INSTR.EXTRACTVALUE ExtractValue should only be used on dxil struct types and cmpxchg.
INSTR.FAILTORESLOVETGSMPOINTER TGSM pointers must originate from an unambiguous TGSM global variable.
INSTR.HANDLENOTFROMCREATEHANDLE Resource handle should returned by createHandle.
INSTR.ILLEGALDXILOPCODE DXILOpCode must be [0..%0]. %1 specified.
INSTR.ILLEGALDXILOPFUNCTION '%0' is not a DXILOpFuncition for DXILOpcode '%1'.
INSTR.IMMBIASFORSAMPLEB bias amount for sample_b must be in the range [%0,%1], but %2 was specified as an immediate.
INSTR.INBOUNDSACCESS Access to out-of-bounds memory is disallowed.
INSTR.MINPRECISIONNOTPRECISE Instructions marked precise may not refer to minprecision values.
Expand Down
4 changes: 1 addition & 3 deletions lib/DXIL/DxilCounters.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -336,9 +336,7 @@ void CountInstructions(llvm::Module &M, DxilCounters &counters) {
}
} else if (CallInst *CI = dyn_cast<CallInst>(I)) {
if (hlsl::OP::IsDxilOpFuncCallInst(CI)) {
unsigned opcode =
(unsigned)llvm::cast<llvm::ConstantInt>(I->getOperand(0))
->getZExtValue();
unsigned opcode = static_cast<unsigned>(hlsl::OP::getOpCode(CI));
CountDxilOp(opcode, counters);
}
} else if (isa<LoadInst>(I) || isa<StoreInst>(I)) {
Expand Down
82 changes: 42 additions & 40 deletions lib/DXIL/DxilOperations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2705,8 +2705,6 @@ llvm::StringRef OP::ConstructOverloadName(Type *Ty, DXIL::OpCode opCode,
}

const char *OP::GetOpCodeName(OpCode opCode) {
DXASSERT(0 <= (unsigned)opCode && opCode < OpCode::NumOpCodes,
"otherwise caller passed OOB index");
return m_OpCodeProps[(unsigned)opCode].pOpCodeName;
}

Expand All @@ -2719,26 +2717,22 @@ const char *OP::GetAtomicOpName(DXIL::AtomicBinOpCode OpCode) {
}

OP::OpCodeClass OP::GetOpCodeClass(OpCode opCode) {
DXASSERT(0 <= (unsigned)opCode && opCode < OpCode::NumOpCodes,
"otherwise caller passed OOB index");
return m_OpCodeProps[(unsigned)opCode].opCodeClass;
}

const char *OP::GetOpCodeClassName(OpCode opCode) {
DXASSERT(0 <= (unsigned)opCode && opCode < OpCode::NumOpCodes,
"otherwise caller passed OOB index");
return m_OpCodeProps[(unsigned)opCode].pOpCodeClassName;
}

llvm::Attribute::AttrKind OP::GetMemAccessAttr(OpCode opCode) {
DXASSERT(0 <= (unsigned)opCode && opCode < OpCode::NumOpCodes,
"otherwise caller passed OOB index");
return m_OpCodeProps[(unsigned)opCode].FuncAttr;
}

bool OP::IsOverloadLegal(OpCode opCode, Type *pType) {
DXASSERT(0 <= (unsigned)opCode && opCode < OpCode::NumOpCodes,
"otherwise caller passed OOB index");
if (!pType)
return false;
if (opCode == OpCode::NumOpCodes)
return false;
unsigned TypeSlot = GetTypeSlot(pType);
return TypeSlot != UINT_MAX &&
m_OpCodeProps[(unsigned)opCode].bAllowOverload[TypeSlot];
Expand Down Expand Up @@ -2814,8 +2808,13 @@ bool OP::IsDxilOpFuncCallInst(const llvm::Instruction *I, OpCode opcode) {
}

OP::OpCode OP::getOpCode(const llvm::Instruction *I) {
return (OP::OpCode)llvm::cast<llvm::ConstantInt>(I->getOperand(0))
->getZExtValue();
auto *OpConst = llvm::dyn_cast<llvm::ConstantInt>(I->getOperand(0));
if (!OpConst)
return OpCode::NumOpCodes;
uint64_t OpCodeVal = OpConst->getZExtValue();
if (OpCodeVal >= static_cast<uint64_t>(OP::OpCode::NumOpCodes))
return OP::OpCode::NumOpCodes;
return static_cast<OP::OpCode>(OpCodeVal);
}

OP::OpCode OP::GetDxilOpFuncCallInst(const llvm::Instruction *I) {
Expand Down Expand Up @@ -3525,9 +3524,7 @@ void OP::RefreshCache() {
CallInst *CI = cast<CallInst>(*F.user_begin());
OpCode OpCode = OP::GetDxilOpFuncCallInst(CI);
Type *pOverloadType = OP::GetOverloadType(OpCode, &F);
Function *OpFunc = GetOpFunc(OpCode, pOverloadType);
(void)(OpFunc);
DXASSERT_NOMSG(OpFunc == &F);
GetOpFunc(OpCode, pOverloadType);
}
}
}
Expand All @@ -3546,13 +3543,15 @@ void OP::FixOverloadNames() {
CallInst *CI = cast<CallInst>(*F.user_begin());
DXIL::OpCode opCode = OP::GetDxilOpFuncCallInst(CI);
llvm::Type *Ty = OP::GetOverloadType(opCode, &F);
if (isa<StructType>(Ty) || isa<PointerType>(Ty)) {
std::string funcName;
if (OP::ConstructOverloadName(Ty, opCode, funcName)
.compare(F.getName()) != 0) {
F.setName(funcName);
}
}
if (!OP::IsOverloadLegal(opCode, Ty))
continue;
if (!isa<StructType>(Ty) && !isa<PointerType>(Ty))
continue;

std::string funcName;
if (OP::ConstructOverloadName(Ty, opCode, funcName)
.compare(F.getName()) != 0)
F.setName(funcName);
}
}
}
Expand All @@ -3563,12 +3562,11 @@ void OP::UpdateCache(OpCodeClass opClass, Type *Ty, llvm::Function *F) {
}

Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
DXASSERT(0 <= (unsigned)opCode && opCode < OpCode::NumOpCodes,
"otherwise caller passed OOB OpCode");
assert(0 <= (unsigned)opCode && opCode < OpCode::NumOpCodes);
DXASSERT(IsOverloadLegal(opCode, pOverloadType),
"otherwise the caller requested illegal operation overload (eg HLSL "
"function with unsupported types for mapped intrinsic function)");
if (opCode == OpCode::NumOpCodes)
return nullptr;
if (!IsOverloadLegal(opCode, pOverloadType))
return nullptr;

OpCodeClass opClass = m_OpCodeProps[(unsigned)opCode].opCodeClass;
Function *&F =
m_OpCodeClassCache[(unsigned)opClass].pOverloads[pOverloadType];
Expand Down Expand Up @@ -5511,8 +5509,8 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
// and return values to ensure that ResRetType is constructed in the
// RefreshCache case.
if (Function *existF = m_pModule->getFunction(funcName)) {
DXASSERT(existF->getFunctionType() == pFT,
"existing function must have the expected function type");
if (existF->getFunctionType() != pFT)
return nullptr;
F = existF;
UpdateCache(opClass, pOverloadType, F);
return F;
Expand All @@ -5531,9 +5529,6 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {

const SmallMapVector<llvm::Type *, llvm::Function *, 8> &
OP::GetOpFuncList(OpCode opCode) const {
DXASSERT(0 <= (unsigned)opCode && opCode < OpCode::NumOpCodes,
"otherwise caller passed OOB OpCode");
assert(0 <= (unsigned)opCode && opCode < OpCode::NumOpCodes);
return m_OpCodeClassCache[(unsigned)m_OpCodeProps[(unsigned)opCode]
.opCodeClass]
.pOverloads;
Expand Down Expand Up @@ -5631,7 +5626,8 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
case OpCode::CallShader:
case OpCode::Pack4x8:
case OpCode::WaveMatrix_Fill:
DXASSERT_NOMSG(FT->getNumParams() > 2);
if (FT->getNumParams() <= 2)
return nullptr;
return FT->getParamType(2);
case OpCode::MinPrecXRegStore:
case OpCode::StoreOutput:
Expand All @@ -5641,7 +5637,8 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
case OpCode::StoreVertexOutput:
case OpCode::StorePrimitiveOutput:
case OpCode::DispatchMesh:
DXASSERT_NOMSG(FT->getNumParams() > 4);
if (FT->getNumParams() <= 4)
return nullptr;
return FT->getParamType(4);
case OpCode::IsNaN:
case OpCode::IsInf:
Expand All @@ -5659,22 +5656,27 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
case OpCode::WaveActiveAllEqual:
case OpCode::CreateHandleForLib:
case OpCode::WaveMatch:
DXASSERT_NOMSG(FT->getNumParams() > 1);
if (FT->getNumParams() <= 1)
return nullptr;
return FT->getParamType(1);
case OpCode::TextureStore:
case OpCode::TextureStoreSample:
DXASSERT_NOMSG(FT->getNumParams() > 5);
if (FT->getNumParams() <= 5)
return nullptr;
return FT->getParamType(5);
case OpCode::TraceRay:
DXASSERT_NOMSG(FT->getNumParams() > 15);
if (FT->getNumParams() <= 15)
return nullptr;
return FT->getParamType(15);
case OpCode::ReportHit:
case OpCode::WaveMatrix_ScalarOp:
DXASSERT_NOMSG(FT->getNumParams() > 3);
if (FT->getNumParams() <= 3)
return nullptr;
return FT->getParamType(3);
case OpCode::WaveMatrix_LoadGroupShared:
case OpCode::WaveMatrix_StoreGroupShared:
DXASSERT_NOMSG(FT->getNumParams() > 2);
if (FT->getNumParams() <= 2)
return nullptr;
return FT->getParamType(2)->getPointerElementType();
case OpCode::CreateHandle:
case OpCode::BufferUpdateCounter:
Expand Down
10 changes: 3 additions & 7 deletions lib/DXIL/DxilShaderFlags.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -585,13 +585,9 @@ ShaderFlags ShaderFlags::CollectShaderFlags(const Function *F,
if (const CallInst *CI = dyn_cast<CallInst>(&I)) {
if (!OP::IsDxilOpFunc(CI->getCalledFunction()))
continue;
Value *opcodeArg = CI->getArgOperand(DXIL::OperandIndex::kOpcodeIdx);
ConstantInt *opcodeConst = dyn_cast<ConstantInt>(opcodeArg);
DXASSERT(opcodeConst, "DXIL opcode arg must be immediate");
unsigned opcode = opcodeConst->getLimitedValue();
DXASSERT(opcode < static_cast<unsigned>(DXIL::OpCode::NumOpCodes),
"invalid DXIL opcode");
DXIL::OpCode dxilOp = static_cast<DXIL::OpCode>(opcode);
DXIL::OpCode dxilOp = hlsl::OP::getOpCode(CI);
if (dxilOp == DXIL::OpCode::NumOpCodes)
continue;
if (hlsl::OP::IsDxilOpWave(dxilOp))
hasWaveOps = true;
switch (dxilOp) {
Expand Down
24 changes: 24 additions & 0 deletions lib/HLSL/DxilValidation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3201,6 +3201,8 @@ static void ValidateFunctionBody(Function *F, ValidationContext &ValCtx) {
CallInst *setMeshOutputCounts = nullptr;
CallInst *getMeshPayload = nullptr;
CallInst *dispatchMesh = nullptr;
hlsl::OP *hlslOP = ValCtx.DxilMod.GetOP();

for (auto b = F->begin(), bend = F->end(); b != bend; ++b) {
for (auto i = b->begin(), iend = b->end(); i != iend; ++i) {
llvm::Instruction &I = *i;
Expand Down Expand Up @@ -3250,8 +3252,30 @@ static void ValidateFunctionBody(Function *F, ValidationContext &ValCtx) {
}

unsigned opcode = OpcodeConst->getLimitedValue();
if (opcode >= static_cast<unsigned>(DXIL::OpCode::NumOpCodes)) {
ValCtx.EmitInstrFormatError(
&I, ValidationRule::InstrIllegalDXILOpCode,
{std::to_string((unsigned)DXIL::OpCode::NumOpCodes),
std::to_string(opcode)});
continue;
}
DXIL::OpCode dxilOpcode = (DXIL::OpCode)opcode;

bool IllegalOpFunc = true;
for (auto &it : hlslOP->GetOpFuncList(dxilOpcode)) {
if (it.second == FCalled) {
IllegalOpFunc = false;
break;
}
}

if (IllegalOpFunc) {
ValCtx.EmitInstrFormatError(
&I, ValidationRule::InstrIllegalDXILOpFunction,
{FCalled->getName(), OP::GetOpCodeName(dxilOpcode)});
continue;
}

if (OP::IsDxilOpGradient(dxilOpcode)) {
gradientOps.push_back(CI);
}
Expand Down
120 changes: 120 additions & 0 deletions tools/clang/test/LitDXILValidation/illegalDXILOp.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
; REQUIRES: dxil-1-8
; RUN: not %dxv %s 2>&1 | FileCheck %s

target datalayout = "e-m:e-p:32:32-i1:32-i8:32-i16:32-i32:32-i64:64-f16:32-f32:32-f64:64-n8:16:32:64"
target triple = "dxil-ms-dx"

%dx.types.Handle = type { i8* }
%dx.types.ResBind = type { i32, i32, i32, i8 }
%dx.types.ResourceProperties = type { i32, i32 }
%dx.types.ResRet.f32 = type { float, float, float, float, i32 }
%"class.Texture2D<float>" = type { float, %"class.Texture2D<float>::mips_type" }
%"class.Texture2D<float>::mips_type" = type { i32 }
%struct.SamplerComparisonState = type { i32 }


define void @main() {
%1 = call %dx.types.Handle @dx.op.createHandleFromBinding(i32 217, %dx.types.ResBind zeroinitializer, i32 0, i1 false) ; CreateHandleFromBinding(bind,index,nonUniformIndex)
%2 = call %dx.types.Handle @dx.op.createHandleFromBinding(i32 217, %dx.types.ResBind { i32 0, i32 0, i32 0, i8 3 }, i32 0, i1 false) ; CreateHandleFromBinding(bind,index,nonUniformIndex)


; CHECK: error: 'dx.op.loadInput.f32' is not a DXILOpFuncition for DXILOpcode 'LoadInput'.
; CHECK: note: at '%3 = call float @dx.op.loadInput.f32(i32 4, i32 0, i32 0, i8 0)' in block '#0' of function 'main'.

%3 = call float @dx.op.loadInput.f32(i32 4, i32 0, i32 0, i8 0) ; LoadInput(inputSigId,rowIndex,colIndex,gsVertexAxis)

; CHECK: error: 'dx.op.loadInput.f32' is not a DXILOpFuncition for DXILOpcode 'LoadInput'.
; CHECK: note: at '%4 = call float @dx.op.loadInput.f32(i32 4, i32 0, i32 0, i8 1)' in block '#0' of function 'main'.

%4 = call float @dx.op.loadInput.f32(i32 4, i32 0, i32 0, i8 1) ; LoadInput(inputSigId,rowIndex,colIndex,gsVertexAxis)


; CHECK: error: 'dx.op.annotateHandle' is not a DXILOpFuncition for DXILOpcode 'MinPrecXRegStore'.
; CHECK: note: at '%5 = call %dx.types.Handle @dx.op.annotateHandle(i32 3, %dx.types.Handle %1, %dx.types.ResourceProperties { i32 2, i32 265 })' in block '#0' of function 'main'.

%5 = call %dx.types.Handle @dx.op.annotateHandle(i32 3, %dx.types.Handle %1, %dx.types.ResourceProperties { i32 2, i32 265 }) ; AnnotateHandle(res,props) resource: Texture2D<F32>

; CHECK: error: 'dx.op.annotateHandle2' is not a DXILOpFuncition for DXILOpcode 'AnnotateHandle'.
; CHECK: note: at '%6 = call %dx.types.Handle @dx.op.annotateHandle2(i32 216, %dx.types.Handle %2, %dx.types.ResourceProperties { i32 32782, i32 0 })' in block '#0' of function 'main'.

%6 = call %dx.types.Handle @dx.op.annotateHandle2(i32 216, %dx.types.Handle %2, %dx.types.ResourceProperties { i32 32782, i32 0 }) ; AnnotateHandle(res,props) resource: SamplerComparisonState

; CHECK: error: DXILOpCode must be [0..258]. 1999981 specified.
; CHECK: note: at '%7 = call float @dx.op.calculateLOD.f32(i32 1999981, %dx.types.Handle %5, %dx.types.Handle %6, float %3, float %4, float undef, i1 true)' in block '#0' of function 'main'.

%7 = call float @dx.op.calculateLOD.f32(i32 1999981, %dx.types.Handle %5, %dx.types.Handle %6, float %3, float %4, float undef, i1 true) ; CalculateLOD(handle,sampler,coord0,coord1,coord2,clamped)

%I = call i32 @dx.op.loadInput.i32(i32 4, i32 0, i32 0, i8 0, i32 undef) ; LoadInput(inputSigId,rowIndex,colIndex,gsVertexAxis)

; CHECK: error: Opcode of DXIL operation must be an immediate constant.
; CHECK: note: at 'call void @dx.op.storeOutput.f32(i32 %I, i32 0, i32 0, i8 0, float %7)' in block '#0' of function 'main'.
call void @dx.op.storeOutput.f32(i32 %I, i32 0, i32 0, i8 0, float %7) ; StoreOutput(outputSigId,rowIndex,colIndex,value)


; CHECK-DAG: error: Opcode SampleCmpBias not valid in shader model ps_6_7.
%CmpBias = call %dx.types.ResRet.f32 @dx.op.sampleCmpBias.f32(i32 255, %dx.types.Handle %5, %dx.types.Handle %6, float %3, float %4, float undef, float undef, i32 0, i32 0, i32 undef, float 5.000000e-01, float 5.000000e-01, float undef) ; SampleCmpBias(srv,sampler,coord0,coord1,coord2,coord3,offset0,offset1,offset2,compareValue,bias,clamp)

ret void
}


; CHECK-DAG: error: DXIL intrinsic overload must be valid.
; CHECK-DAG: note: at '%4 = call float @dx.op.loadInput.f32(i32 4, i32 0, i32 0, i8 1)' in block '#0' of function 'main'.
; CHECK-DAG: error: DXIL intrinsic overload must be valid.
; CHECK-DAG: note: at '%3 = call float @dx.op.loadInput.f32(i32 4, i32 0, i32 0, i8 0)' in block '#0' of function 'main'.
; CHECK-DAG: error: DXIL intrinsic overload must be valid.
; CHECK-DAG: note: at '%5 = call %dx.types.Handle @dx.op.annotateHandle(i32 3, %dx.types.Handle %1, %dx.types.ResourceProperties { i32 2, i32 265 })' in block '#0' of function 'main'.

; Function Attrs: nounwind readnone
declare float @dx.op.loadInput.f32(i32, i32, i32, i8) #0

; Function Attrs: nounwind
declare void @dx.op.storeOutput.f32(i32, i32, i32, i8, float) #1

; Function Attrs: nounwind readonly
declare float @dx.op.calculateLOD.f32(i32, %dx.types.Handle, %dx.types.Handle, float, float, float, i1) #2

; Function Attrs: nounwind readnone
declare %dx.types.Handle @dx.op.annotateHandle(i32, %dx.types.Handle, %dx.types.ResourceProperties) #0

declare %dx.types.Handle @dx.op.annotateHandle2(i32, %dx.types.Handle, %dx.types.ResourceProperties) #0

; Function Attrs: nounwind readnone
declare %dx.types.Handle @dx.op.createHandleFromBinding(i32, %dx.types.ResBind, i32, i1) #0

declare i32 @dx.op.loadInput.i32(i32, i32, i32, i8, i32) #0

; Function Attrs: nounwind readonly
declare %dx.types.ResRet.f32 @dx.op.sampleCmpBias.f32(i32, %dx.types.Handle, %dx.types.Handle, float, float, float, float, i32, i32, i32, float, float, float) #2

attributes #0 = { nounwind readnone }
attributes #1 = { nounwind }
attributes #2 = { nounwind readonly }

!llvm.ident = !{!0}
!dx.version = !{!1}
!dx.valver = !{!1}
!dx.shaderModel = !{!2}
!dx.resources = !{!3}
!dx.viewIdState = !{!9}
!dx.entryPoints = !{!10}

!0 = !{!"dxc(private) 1.7.0.4396 (test_time, 849f8b884-dirty)"}
!1 = !{i32 1, i32 7}
!2 = !{!"ps", i32 6, i32 7}
!3 = !{!4, null, null, !7}
!4 = !{!5}
!5 = !{i32 0, %"class.Texture2D<float>"* undef, !"", i32 0, i32 0, i32 1, i32 2, i32 0, !6}
!6 = !{i32 0, i32 9}
!7 = !{!8}
!8 = !{i32 0, %struct.SamplerComparisonState* undef, !"", i32 0, i32 0, i32 1, i32 1, null}
!9 = !{[4 x i32] [i32 2, i32 1, i32 1, i32 1]}
!10 = !{void ()* @main, !"main", !11, !3, null}
!11 = !{!12, !16, null}
!12 = !{!13}
!13 = !{i32 0, !"A", i8 9, i8 0, !14, i8 2, i32 1, i8 2, i32 0, i8 0, !15}
!14 = !{i32 0}
!15 = !{i32 3, i32 3}
!16 = !{!17}
!17 = !{i32 0, !"SV_Target", i8 9, i8 16, !14, i8 0, i32 1, i8 1, i32 0, i8 0, !18}
!18 = !{i32 3, i32 1}
7 changes: 7 additions & 0 deletions utils/hct/hctdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -7324,6 +7324,13 @@ def build_valrules(self):
"Instr.ImmBiasForSampleB",
"bias amount for sample_b must be in the range [%0,%1], but %2 was specified as an immediate.",
)
self.add_valrule(
"Instr.IllegalDXILOpCode", "DXILOpCode must be [0..%0]. %1 specified."
)
self.add_valrule(
"Instr.IllegalDXILOpFunction",
"'%0' is not a DXILOpFuncition for DXILOpcode '%1'.",
)
# If streams have not been declared, you must use cut instead of cut_stream in GS - is there an equivalent rule here?

# Need to clean up all error messages and actually implement.
Expand Down
Loading

0 comments on commit 5dee81c

Please sign in to comment.