Skip to content

Commit

Permalink
Revert "Fix crash in DXIL.dll caused by illegal DXIL intrinsic. (#6302)…
Browse files Browse the repository at this point in the history
… (#6342)" (#6418)

This file deleted with conflicts from subsequent changes:
  tools/clang/test/LitDXILValidation/illegalDXILOp.ll

This reverts commit 487080f.

Fixes #6419.
  • Loading branch information
tex3d authored Mar 14, 2024
1 parent c9660a8 commit 11ee8de
Show file tree
Hide file tree
Showing 8 changed files with 52 additions and 211 deletions.
2 changes: 0 additions & 2 deletions docs/DXIL.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3080,8 +3080,6 @@ INSTR.EVALINTERPOLATIONMODE Interpolation mode on %0 used with eva
INSTR.EXTRACTVALUE ExtractValue should only be used on dxil struct types and cmpxchg.
INSTR.FAILTORESLOVETGSMPOINTER TGSM pointers must originate from an unambiguous TGSM global variable.
INSTR.HANDLENOTFROMCREATEHANDLE Resource handle should returned by createHandle.
INSTR.ILLEGALDXILOPCODE DXILOpCode must be [0..%0]. %1 specified.
INSTR.ILLEGALDXILOPFUNCTION '%0' is not a DXILOpFuncition for DXILOpcode '%1'.
INSTR.IMMBIASFORSAMPLEB bias amount for sample_b must be in the range [%0,%1], but %2 was specified as an immediate.
INSTR.INBOUNDSACCESS Access to out-of-bounds memory is disallowed.
INSTR.MINPRECISIONNOTPRECISE Instructions marked precise may not refer to minprecision values.
Expand Down
4 changes: 3 additions & 1 deletion lib/DXIL/DxilCounters.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,9 @@ void CountInstructions(llvm::Module &M, DxilCounters &counters) {
}
} else if (CallInst *CI = dyn_cast<CallInst>(I)) {
if (hlsl::OP::IsDxilOpFuncCallInst(CI)) {
unsigned opcode = static_cast<unsigned>(hlsl::OP::getOpCode(CI));
unsigned opcode =
(unsigned)llvm::cast<llvm::ConstantInt>(I->getOperand(0))
->getZExtValue();
CountDxilOp(opcode, counters);
}
} else if (isa<LoadInst>(I) || isa<StoreInst>(I)) {
Expand Down
82 changes: 40 additions & 42 deletions lib/DXIL/DxilOperations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2705,6 +2705,8 @@ llvm::StringRef OP::ConstructOverloadName(Type *Ty, DXIL::OpCode opCode,
}

const char *OP::GetOpCodeName(OpCode opCode) {
DXASSERT(0 <= (unsigned)opCode && opCode < OpCode::NumOpCodes,
"otherwise caller passed OOB index");
return m_OpCodeProps[(unsigned)opCode].pOpCodeName;
}

Expand All @@ -2717,22 +2719,26 @@ const char *OP::GetAtomicOpName(DXIL::AtomicBinOpCode OpCode) {
}

OP::OpCodeClass OP::GetOpCodeClass(OpCode opCode) {
DXASSERT(0 <= (unsigned)opCode && opCode < OpCode::NumOpCodes,
"otherwise caller passed OOB index");
return m_OpCodeProps[(unsigned)opCode].opCodeClass;
}

const char *OP::GetOpCodeClassName(OpCode opCode) {
DXASSERT(0 <= (unsigned)opCode && opCode < OpCode::NumOpCodes,
"otherwise caller passed OOB index");
return m_OpCodeProps[(unsigned)opCode].pOpCodeClassName;
}

llvm::Attribute::AttrKind OP::GetMemAccessAttr(OpCode opCode) {
DXASSERT(0 <= (unsigned)opCode && opCode < OpCode::NumOpCodes,
"otherwise caller passed OOB index");
return m_OpCodeProps[(unsigned)opCode].FuncAttr;
}

bool OP::IsOverloadLegal(OpCode opCode, Type *pType) {
if (!pType)
return false;
if (opCode == OpCode::NumOpCodes)
return false;
DXASSERT(0 <= (unsigned)opCode && opCode < OpCode::NumOpCodes,
"otherwise caller passed OOB index");
unsigned TypeSlot = GetTypeSlot(pType);
return TypeSlot != UINT_MAX &&
m_OpCodeProps[(unsigned)opCode].bAllowOverload[TypeSlot];
Expand Down Expand Up @@ -2808,13 +2814,8 @@ bool OP::IsDxilOpFuncCallInst(const llvm::Instruction *I, OpCode opcode) {
}

OP::OpCode OP::getOpCode(const llvm::Instruction *I) {
auto *OpConst = llvm::dyn_cast<llvm::ConstantInt>(I->getOperand(0));
if (!OpConst)
return OpCode::NumOpCodes;
uint64_t OpCodeVal = OpConst->getZExtValue();
if (OpCodeVal >= static_cast<uint64_t>(OP::OpCode::NumOpCodes))
return OP::OpCode::NumOpCodes;
return static_cast<OP::OpCode>(OpCodeVal);
return (OP::OpCode)llvm::cast<llvm::ConstantInt>(I->getOperand(0))
->getZExtValue();
}

OP::OpCode OP::GetDxilOpFuncCallInst(const llvm::Instruction *I) {
Expand Down Expand Up @@ -3524,7 +3525,9 @@ void OP::RefreshCache() {
CallInst *CI = cast<CallInst>(*F.user_begin());
OpCode OpCode = OP::GetDxilOpFuncCallInst(CI);
Type *pOverloadType = OP::GetOverloadType(OpCode, &F);
GetOpFunc(OpCode, pOverloadType);
Function *OpFunc = GetOpFunc(OpCode, pOverloadType);
(void)(OpFunc);
DXASSERT_NOMSG(OpFunc == &F);
}
}
}
Expand All @@ -3543,15 +3546,13 @@ void OP::FixOverloadNames() {
CallInst *CI = cast<CallInst>(*F.user_begin());
DXIL::OpCode opCode = OP::GetDxilOpFuncCallInst(CI);
llvm::Type *Ty = OP::GetOverloadType(opCode, &F);
if (!OP::IsOverloadLegal(opCode, Ty))
continue;
if (!isa<StructType>(Ty) && !isa<PointerType>(Ty))
continue;

std::string funcName;
if (OP::ConstructOverloadName(Ty, opCode, funcName)
.compare(F.getName()) != 0)
F.setName(funcName);
if (isa<StructType>(Ty) || isa<PointerType>(Ty)) {
std::string funcName;
if (OP::ConstructOverloadName(Ty, opCode, funcName)
.compare(F.getName()) != 0) {
F.setName(funcName);
}
}
}
}
}
Expand All @@ -3562,11 +3563,12 @@ void OP::UpdateCache(OpCodeClass opClass, Type *Ty, llvm::Function *F) {
}

Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
if (opCode == OpCode::NumOpCodes)
return nullptr;
if (!IsOverloadLegal(opCode, pOverloadType))
return nullptr;

DXASSERT(0 <= (unsigned)opCode && opCode < OpCode::NumOpCodes,
"otherwise caller passed OOB OpCode");
assert(0 <= (unsigned)opCode && opCode < OpCode::NumOpCodes);
DXASSERT(IsOverloadLegal(opCode, pOverloadType),
"otherwise the caller requested illegal operation overload (eg HLSL "
"function with unsupported types for mapped intrinsic function)");
OpCodeClass opClass = m_OpCodeProps[(unsigned)opCode].opCodeClass;
Function *&F =
m_OpCodeClassCache[(unsigned)opClass].pOverloads[pOverloadType];
Expand Down Expand Up @@ -5509,8 +5511,8 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
// and return values to ensure that ResRetType is constructed in the
// RefreshCache case.
if (Function *existF = m_pModule->getFunction(funcName)) {
if (existF->getFunctionType() != pFT)
return nullptr;
DXASSERT(existF->getFunctionType() == pFT,
"existing function must have the expected function type");
F = existF;
UpdateCache(opClass, pOverloadType, F);
return F;
Expand All @@ -5529,6 +5531,9 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {

const SmallMapVector<llvm::Type *, llvm::Function *, 8> &
OP::GetOpFuncList(OpCode opCode) const {
DXASSERT(0 <= (unsigned)opCode && opCode < OpCode::NumOpCodes,
"otherwise caller passed OOB OpCode");
assert(0 <= (unsigned)opCode && opCode < OpCode::NumOpCodes);
return m_OpCodeClassCache[(unsigned)m_OpCodeProps[(unsigned)opCode]
.opCodeClass]
.pOverloads;
Expand Down Expand Up @@ -5626,8 +5631,7 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
case OpCode::CallShader:
case OpCode::Pack4x8:
case OpCode::WaveMatrix_Fill:
if (FT->getNumParams() <= 2)
return nullptr;
DXASSERT_NOMSG(FT->getNumParams() > 2);
return FT->getParamType(2);
case OpCode::MinPrecXRegStore:
case OpCode::StoreOutput:
Expand All @@ -5637,8 +5641,7 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
case OpCode::StoreVertexOutput:
case OpCode::StorePrimitiveOutput:
case OpCode::DispatchMesh:
if (FT->getNumParams() <= 4)
return nullptr;
DXASSERT_NOMSG(FT->getNumParams() > 4);
return FT->getParamType(4);
case OpCode::IsNaN:
case OpCode::IsInf:
Expand All @@ -5656,27 +5659,22 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
case OpCode::WaveActiveAllEqual:
case OpCode::CreateHandleForLib:
case OpCode::WaveMatch:
if (FT->getNumParams() <= 1)
return nullptr;
DXASSERT_NOMSG(FT->getNumParams() > 1);
return FT->getParamType(1);
case OpCode::TextureStore:
case OpCode::TextureStoreSample:
if (FT->getNumParams() <= 5)
return nullptr;
DXASSERT_NOMSG(FT->getNumParams() > 5);
return FT->getParamType(5);
case OpCode::TraceRay:
if (FT->getNumParams() <= 15)
return nullptr;
DXASSERT_NOMSG(FT->getNumParams() > 15);
return FT->getParamType(15);
case OpCode::ReportHit:
case OpCode::WaveMatrix_ScalarOp:
if (FT->getNumParams() <= 3)
return nullptr;
DXASSERT_NOMSG(FT->getNumParams() > 3);
return FT->getParamType(3);
case OpCode::WaveMatrix_LoadGroupShared:
case OpCode::WaveMatrix_StoreGroupShared:
if (FT->getNumParams() <= 2)
return nullptr;
DXASSERT_NOMSG(FT->getNumParams() > 2);
return FT->getParamType(2)->getPointerElementType();
case OpCode::CreateHandle:
case OpCode::BufferUpdateCounter:
Expand Down
10 changes: 7 additions & 3 deletions lib/DXIL/DxilShaderFlags.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -585,9 +585,13 @@ ShaderFlags ShaderFlags::CollectShaderFlags(const Function *F,
if (const CallInst *CI = dyn_cast<CallInst>(&I)) {
if (!OP::IsDxilOpFunc(CI->getCalledFunction()))
continue;
DXIL::OpCode dxilOp = hlsl::OP::getOpCode(CI);
if (dxilOp == DXIL::OpCode::NumOpCodes)
continue;
Value *opcodeArg = CI->getArgOperand(DXIL::OperandIndex::kOpcodeIdx);
ConstantInt *opcodeConst = dyn_cast<ConstantInt>(opcodeArg);
DXASSERT(opcodeConst, "DXIL opcode arg must be immediate");
unsigned opcode = opcodeConst->getLimitedValue();
DXASSERT(opcode < static_cast<unsigned>(DXIL::OpCode::NumOpCodes),
"invalid DXIL opcode");
DXIL::OpCode dxilOp = static_cast<DXIL::OpCode>(opcode);
if (hlsl::OP::IsDxilOpWave(dxilOp))
hasWaveOps = true;
switch (dxilOp) {
Expand Down
24 changes: 0 additions & 24 deletions lib/HLSL/DxilValidation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3208,8 +3208,6 @@ static void ValidateFunctionBody(Function *F, ValidationContext &ValCtx) {
CallInst *setMeshOutputCounts = nullptr;
CallInst *getMeshPayload = nullptr;
CallInst *dispatchMesh = nullptr;
hlsl::OP *hlslOP = ValCtx.DxilMod.GetOP();

for (auto b = F->begin(), bend = F->end(); b != bend; ++b) {
for (auto i = b->begin(), iend = b->end(); i != iend; ++i) {
llvm::Instruction &I = *i;
Expand Down Expand Up @@ -3259,30 +3257,8 @@ static void ValidateFunctionBody(Function *F, ValidationContext &ValCtx) {
}

unsigned opcode = OpcodeConst->getLimitedValue();
if (opcode >= static_cast<unsigned>(DXIL::OpCode::NumOpCodes)) {
ValCtx.EmitInstrFormatError(
&I, ValidationRule::InstrIllegalDXILOpCode,
{std::to_string((unsigned)DXIL::OpCode::NumOpCodes),
std::to_string(opcode)});
continue;
}
DXIL::OpCode dxilOpcode = (DXIL::OpCode)opcode;

bool IllegalOpFunc = true;
for (auto &it : hlslOP->GetOpFuncList(dxilOpcode)) {
if (it.second == FCalled) {
IllegalOpFunc = false;
break;
}
}

if (IllegalOpFunc) {
ValCtx.EmitInstrFormatError(
&I, ValidationRule::InstrIllegalDXILOpFunction,
{FCalled->getName(), OP::GetOpCodeName(dxilOpcode)});
continue;
}

if (OP::IsDxilOpGradient(dxilOpcode)) {
gradientOps.push_back(CI);
}
Expand Down
120 changes: 0 additions & 120 deletions tools/clang/test/LitDXILValidation/illegalDXILOp.ll

This file was deleted.

7 changes: 0 additions & 7 deletions utils/hct/hctdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -7324,13 +7324,6 @@ def build_valrules(self):
"Instr.ImmBiasForSampleB",
"bias amount for sample_b must be in the range [%0,%1], but %2 was specified as an immediate.",
)
self.add_valrule(
"Instr.IllegalDXILOpCode", "DXILOpCode must be [0..%0]. %1 specified."
)
self.add_valrule(
"Instr.IllegalDXILOpFunction",
"'%0' is not a DXILOpFuncition for DXILOpcode '%1'.",
)
# If streams have not been declared, you must use cut instead of cut_stream in GS - is there an equivalent rule here?

# Need to clean up all error messages and actually implement.
Expand Down
Loading

0 comments on commit 11ee8de

Please sign in to comment.