Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "Fix crash in DXIL.dll caused by illegal DXIL intrinsic. (#6302) (#6342)" #6418

Merged
merged 1 commit into from
Mar 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions docs/DXIL.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3080,8 +3080,6 @@ INSTR.EVALINTERPOLATIONMODE Interpolation mode on %0 used with eva
INSTR.EXTRACTVALUE ExtractValue should only be used on dxil struct types and cmpxchg.
INSTR.FAILTORESLOVETGSMPOINTER TGSM pointers must originate from an unambiguous TGSM global variable.
INSTR.HANDLENOTFROMCREATEHANDLE Resource handle should returned by createHandle.
INSTR.ILLEGALDXILOPCODE DXILOpCode must be [0..%0]. %1 specified.
INSTR.ILLEGALDXILOPFUNCTION '%0' is not a DXILOpFuncition for DXILOpcode '%1'.
INSTR.IMMBIASFORSAMPLEB bias amount for sample_b must be in the range [%0,%1], but %2 was specified as an immediate.
INSTR.INBOUNDSACCESS Access to out-of-bounds memory is disallowed.
INSTR.MINPRECISIONNOTPRECISE Instructions marked precise may not refer to minprecision values.
Expand Down
4 changes: 3 additions & 1 deletion lib/DXIL/DxilCounters.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,9 @@ void CountInstructions(llvm::Module &M, DxilCounters &counters) {
}
} else if (CallInst *CI = dyn_cast<CallInst>(I)) {
if (hlsl::OP::IsDxilOpFuncCallInst(CI)) {
unsigned opcode = static_cast<unsigned>(hlsl::OP::getOpCode(CI));
unsigned opcode =
(unsigned)llvm::cast<llvm::ConstantInt>(I->getOperand(0))
->getZExtValue();
CountDxilOp(opcode, counters);
}
} else if (isa<LoadInst>(I) || isa<StoreInst>(I)) {
Expand Down
82 changes: 40 additions & 42 deletions lib/DXIL/DxilOperations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2705,6 +2705,8 @@ llvm::StringRef OP::ConstructOverloadName(Type *Ty, DXIL::OpCode opCode,
}

const char *OP::GetOpCodeName(OpCode opCode) {
DXASSERT(0 <= (unsigned)opCode && opCode < OpCode::NumOpCodes,
"otherwise caller passed OOB index");
return m_OpCodeProps[(unsigned)opCode].pOpCodeName;
}

Expand All @@ -2717,22 +2719,26 @@ const char *OP::GetAtomicOpName(DXIL::AtomicBinOpCode OpCode) {
}

OP::OpCodeClass OP::GetOpCodeClass(OpCode opCode) {
DXASSERT(0 <= (unsigned)opCode && opCode < OpCode::NumOpCodes,
"otherwise caller passed OOB index");
return m_OpCodeProps[(unsigned)opCode].opCodeClass;
}

const char *OP::GetOpCodeClassName(OpCode opCode) {
DXASSERT(0 <= (unsigned)opCode && opCode < OpCode::NumOpCodes,
"otherwise caller passed OOB index");
return m_OpCodeProps[(unsigned)opCode].pOpCodeClassName;
}

llvm::Attribute::AttrKind OP::GetMemAccessAttr(OpCode opCode) {
DXASSERT(0 <= (unsigned)opCode && opCode < OpCode::NumOpCodes,
"otherwise caller passed OOB index");
return m_OpCodeProps[(unsigned)opCode].FuncAttr;
}

bool OP::IsOverloadLegal(OpCode opCode, Type *pType) {
if (!pType)
return false;
if (opCode == OpCode::NumOpCodes)
return false;
DXASSERT(0 <= (unsigned)opCode && opCode < OpCode::NumOpCodes,
"otherwise caller passed OOB index");
unsigned TypeSlot = GetTypeSlot(pType);
return TypeSlot != UINT_MAX &&
m_OpCodeProps[(unsigned)opCode].bAllowOverload[TypeSlot];
Expand Down Expand Up @@ -2808,13 +2814,8 @@ bool OP::IsDxilOpFuncCallInst(const llvm::Instruction *I, OpCode opcode) {
}

OP::OpCode OP::getOpCode(const llvm::Instruction *I) {
auto *OpConst = llvm::dyn_cast<llvm::ConstantInt>(I->getOperand(0));
if (!OpConst)
return OpCode::NumOpCodes;
uint64_t OpCodeVal = OpConst->getZExtValue();
if (OpCodeVal >= static_cast<uint64_t>(OP::OpCode::NumOpCodes))
return OP::OpCode::NumOpCodes;
return static_cast<OP::OpCode>(OpCodeVal);
return (OP::OpCode)llvm::cast<llvm::ConstantInt>(I->getOperand(0))
->getZExtValue();
}

OP::OpCode OP::GetDxilOpFuncCallInst(const llvm::Instruction *I) {
Expand Down Expand Up @@ -3524,7 +3525,9 @@ void OP::RefreshCache() {
CallInst *CI = cast<CallInst>(*F.user_begin());
OpCode OpCode = OP::GetDxilOpFuncCallInst(CI);
Type *pOverloadType = OP::GetOverloadType(OpCode, &F);
GetOpFunc(OpCode, pOverloadType);
Function *OpFunc = GetOpFunc(OpCode, pOverloadType);
(void)(OpFunc);
DXASSERT_NOMSG(OpFunc == &F);
}
}
}
Expand All @@ -3543,15 +3546,13 @@ void OP::FixOverloadNames() {
CallInst *CI = cast<CallInst>(*F.user_begin());
DXIL::OpCode opCode = OP::GetDxilOpFuncCallInst(CI);
llvm::Type *Ty = OP::GetOverloadType(opCode, &F);
if (!OP::IsOverloadLegal(opCode, Ty))
continue;
if (!isa<StructType>(Ty) && !isa<PointerType>(Ty))
continue;

std::string funcName;
if (OP::ConstructOverloadName(Ty, opCode, funcName)
.compare(F.getName()) != 0)
F.setName(funcName);
if (isa<StructType>(Ty) || isa<PointerType>(Ty)) {
std::string funcName;
if (OP::ConstructOverloadName(Ty, opCode, funcName)
.compare(F.getName()) != 0) {
F.setName(funcName);
}
}
}
}
}
Expand All @@ -3562,11 +3563,12 @@ void OP::UpdateCache(OpCodeClass opClass, Type *Ty, llvm::Function *F) {
}

Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
if (opCode == OpCode::NumOpCodes)
return nullptr;
if (!IsOverloadLegal(opCode, pOverloadType))
return nullptr;

DXASSERT(0 <= (unsigned)opCode && opCode < OpCode::NumOpCodes,
"otherwise caller passed OOB OpCode");
assert(0 <= (unsigned)opCode && opCode < OpCode::NumOpCodes);
DXASSERT(IsOverloadLegal(opCode, pOverloadType),
"otherwise the caller requested illegal operation overload (eg HLSL "
"function with unsupported types for mapped intrinsic function)");
OpCodeClass opClass = m_OpCodeProps[(unsigned)opCode].opCodeClass;
Function *&F =
m_OpCodeClassCache[(unsigned)opClass].pOverloads[pOverloadType];
Expand Down Expand Up @@ -5509,8 +5511,8 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
// and return values to ensure that ResRetType is constructed in the
// RefreshCache case.
if (Function *existF = m_pModule->getFunction(funcName)) {
if (existF->getFunctionType() != pFT)
return nullptr;
DXASSERT(existF->getFunctionType() == pFT,
"existing function must have the expected function type");
F = existF;
UpdateCache(opClass, pOverloadType, F);
return F;
Expand All @@ -5529,6 +5531,9 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {

const SmallMapVector<llvm::Type *, llvm::Function *, 8> &
OP::GetOpFuncList(OpCode opCode) const {
DXASSERT(0 <= (unsigned)opCode && opCode < OpCode::NumOpCodes,
"otherwise caller passed OOB OpCode");
assert(0 <= (unsigned)opCode && opCode < OpCode::NumOpCodes);
return m_OpCodeClassCache[(unsigned)m_OpCodeProps[(unsigned)opCode]
.opCodeClass]
.pOverloads;
Expand Down Expand Up @@ -5626,8 +5631,7 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
case OpCode::CallShader:
case OpCode::Pack4x8:
case OpCode::WaveMatrix_Fill:
if (FT->getNumParams() <= 2)
return nullptr;
DXASSERT_NOMSG(FT->getNumParams() > 2);
return FT->getParamType(2);
case OpCode::MinPrecXRegStore:
case OpCode::StoreOutput:
Expand All @@ -5637,8 +5641,7 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
case OpCode::StoreVertexOutput:
case OpCode::StorePrimitiveOutput:
case OpCode::DispatchMesh:
if (FT->getNumParams() <= 4)
return nullptr;
DXASSERT_NOMSG(FT->getNumParams() > 4);
return FT->getParamType(4);
case OpCode::IsNaN:
case OpCode::IsInf:
Expand All @@ -5656,27 +5659,22 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
case OpCode::WaveActiveAllEqual:
case OpCode::CreateHandleForLib:
case OpCode::WaveMatch:
if (FT->getNumParams() <= 1)
return nullptr;
DXASSERT_NOMSG(FT->getNumParams() > 1);
return FT->getParamType(1);
case OpCode::TextureStore:
case OpCode::TextureStoreSample:
if (FT->getNumParams() <= 5)
return nullptr;
DXASSERT_NOMSG(FT->getNumParams() > 5);
return FT->getParamType(5);
case OpCode::TraceRay:
if (FT->getNumParams() <= 15)
return nullptr;
DXASSERT_NOMSG(FT->getNumParams() > 15);
return FT->getParamType(15);
case OpCode::ReportHit:
case OpCode::WaveMatrix_ScalarOp:
if (FT->getNumParams() <= 3)
return nullptr;
DXASSERT_NOMSG(FT->getNumParams() > 3);
return FT->getParamType(3);
case OpCode::WaveMatrix_LoadGroupShared:
case OpCode::WaveMatrix_StoreGroupShared:
if (FT->getNumParams() <= 2)
return nullptr;
DXASSERT_NOMSG(FT->getNumParams() > 2);
return FT->getParamType(2)->getPointerElementType();
case OpCode::CreateHandle:
case OpCode::BufferUpdateCounter:
Expand Down
10 changes: 7 additions & 3 deletions lib/DXIL/DxilShaderFlags.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -585,9 +585,13 @@ ShaderFlags ShaderFlags::CollectShaderFlags(const Function *F,
if (const CallInst *CI = dyn_cast<CallInst>(&I)) {
if (!OP::IsDxilOpFunc(CI->getCalledFunction()))
continue;
DXIL::OpCode dxilOp = hlsl::OP::getOpCode(CI);
if (dxilOp == DXIL::OpCode::NumOpCodes)
continue;
Value *opcodeArg = CI->getArgOperand(DXIL::OperandIndex::kOpcodeIdx);
ConstantInt *opcodeConst = dyn_cast<ConstantInt>(opcodeArg);
DXASSERT(opcodeConst, "DXIL opcode arg must be immediate");
unsigned opcode = opcodeConst->getLimitedValue();
DXASSERT(opcode < static_cast<unsigned>(DXIL::OpCode::NumOpCodes),
"invalid DXIL opcode");
DXIL::OpCode dxilOp = static_cast<DXIL::OpCode>(opcode);
if (hlsl::OP::IsDxilOpWave(dxilOp))
hasWaveOps = true;
switch (dxilOp) {
Expand Down
24 changes: 0 additions & 24 deletions lib/HLSL/DxilValidation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3208,8 +3208,6 @@ static void ValidateFunctionBody(Function *F, ValidationContext &ValCtx) {
CallInst *setMeshOutputCounts = nullptr;
CallInst *getMeshPayload = nullptr;
CallInst *dispatchMesh = nullptr;
hlsl::OP *hlslOP = ValCtx.DxilMod.GetOP();

for (auto b = F->begin(), bend = F->end(); b != bend; ++b) {
for (auto i = b->begin(), iend = b->end(); i != iend; ++i) {
llvm::Instruction &I = *i;
Expand Down Expand Up @@ -3259,30 +3257,8 @@ static void ValidateFunctionBody(Function *F, ValidationContext &ValCtx) {
}

unsigned opcode = OpcodeConst->getLimitedValue();
if (opcode >= static_cast<unsigned>(DXIL::OpCode::NumOpCodes)) {
ValCtx.EmitInstrFormatError(
&I, ValidationRule::InstrIllegalDXILOpCode,
{std::to_string((unsigned)DXIL::OpCode::NumOpCodes),
std::to_string(opcode)});
continue;
}
DXIL::OpCode dxilOpcode = (DXIL::OpCode)opcode;

bool IllegalOpFunc = true;
for (auto &it : hlslOP->GetOpFuncList(dxilOpcode)) {
if (it.second == FCalled) {
IllegalOpFunc = false;
break;
}
}

if (IllegalOpFunc) {
ValCtx.EmitInstrFormatError(
&I, ValidationRule::InstrIllegalDXILOpFunction,
{FCalled->getName(), OP::GetOpCodeName(dxilOpcode)});
continue;
}

if (OP::IsDxilOpGradient(dxilOpcode)) {
gradientOps.push_back(CI);
}
Expand Down
120 changes: 0 additions & 120 deletions tools/clang/test/LitDXILValidation/illegalDXILOp.ll

This file was deleted.

7 changes: 0 additions & 7 deletions utils/hct/hctdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -7324,13 +7324,6 @@ def build_valrules(self):
"Instr.ImmBiasForSampleB",
"bias amount for sample_b must be in the range [%0,%1], but %2 was specified as an immediate.",
)
self.add_valrule(
"Instr.IllegalDXILOpCode", "DXILOpCode must be [0..%0]. %1 specified."
)
self.add_valrule(
"Instr.IllegalDXILOpFunction",
"'%0' is not a DXILOpFuncition for DXILOpcode '%1'.",
)
# If streams have not been declared, you must use cut instead of cut_stream in GS - is there an equivalent rule here?

# Need to clean up all error messages and actually implement.
Expand Down
Loading
Loading