Skip to content

Commit

Permalink
Fix attribute collision for HL intrinsics (#5451)
Browse files Browse the repository at this point in the history
HL Intrinsic functions share declarations with those that match group
and function signature, regardless of the original intrinsic name. This
means that intrinsics with differing attributes can be collapsed into
the same HL functions, leading to incorrect attributes for some HL
intrinsics.

This fixes this issue by adding the attributes to the HL operation
mangling, the same way this issue was fixed for the HLWaveSensitive
attribute before.

Fixes #3505

---------

Co-authored-by: Joshua Batista <jbatista@microsoft.com>
  • Loading branch information
tex3d and bob80905 authored Aug 9, 2023
1 parent a6497c3 commit d9c07e9
Show file tree
Hide file tree
Showing 8 changed files with 299 additions and 67 deletions.
11 changes: 8 additions & 3 deletions lib/HLSL/HLMatrixLowerPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1517,10 +1517,15 @@ void HLMatrixLowerPass::lowerHLMatSubscript(CallInst *Call, Value *MatPtr, Small
HLMatLoadStoreOpcode Opcode = (HLSubscriptOpcode)GetHLOpcode(Call) == HLSubscriptOpcode::RowMatSubscript ?
HLMatLoadStoreOpcode::RowMatLoad : HLMatLoadStoreOpcode::ColMatLoad;
HLMatrixType MatTy = HLMatrixType::cast(MatPtr->getType()->getPointerElementType());
// Don't pass attributes from subscript (ReadNone) - load is ReadOnly.
// Attributes will be set when HL function is created.
// FIXME: This seems to indicate a potential bug, since the load should be
// placed where pointer users would have loaded from the pointer.
LoweredMatrix = callHLFunction(
*m_pModule, HLOpcodeGroup::HLMatLoadStore, static_cast<unsigned>(Opcode),
MatTy.getLoweredVectorTypeForReg(), { CallBuilder.getInt32((uint32_t)Opcode), MatPtr },
Call->getCalledFunction()->getAttributes().getFnAttributes(), CallBuilder);
*m_pModule, HLOpcodeGroup::HLMatLoadStore,
static_cast<unsigned>(Opcode), MatTy.getLoweredVectorTypeForReg(),
{CallBuilder.getInt32((uint32_t)Opcode), MatPtr}, AttributeSet(),
CallBuilder);
}
// For global variables, we can GEP directly into the lowered vector pointer.
// This is necessary to support group shared memory atomics and the likes.
Expand Down
185 changes: 143 additions & 42 deletions lib/HLSL/HLOperations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,10 @@ bool IsHLWaveSensitive(Function *F) {
return attrSet.hasAttribute(AttributeSet::FunctionIndex, HLWaveSensitive);
}

std::string GetHLFullName(HLOpcodeGroup op, unsigned opcode) {
static std::string GetHLFunctionAttributeMangling(const AttributeSet &attribs);

std::string GetHLFullName(HLOpcodeGroup op, unsigned opcode,
const AttributeSet &attribs = AttributeSet()) {
assert(op != HLOpcodeGroup::HLExtIntrinsic && "else table name should be used");
std::string opName = GetHLOpcodeGroupFullName(op).str() + ".";

Expand All @@ -321,22 +324,26 @@ std::string GetHLFullName(HLOpcodeGroup op, unsigned opcode) {
case HLOpcodeGroup::HLIntrinsic: {
// intrinsic with same signature will share the funciton now
// The opcode is in arg0.
return opName;
return opName + GetHLFunctionAttributeMangling(attribs);
}
case HLOpcodeGroup::HLMatLoadStore: {
HLMatLoadStoreOpcode matOp = static_cast<HLMatLoadStoreOpcode>(opcode);
return opName + GetHLOpcodeName(matOp).str();
}
case HLOpcodeGroup::HLSubscript: {
HLSubscriptOpcode subOp = static_cast<HLSubscriptOpcode>(opcode);
return opName + GetHLOpcodeName(subOp).str();
return opName + GetHLOpcodeName(subOp).str() + "." +
GetHLFunctionAttributeMangling(attribs);
}
case HLOpcodeGroup::HLCast: {
HLCastOpcode castOp = static_cast<HLCastOpcode>(opcode);
return opName + GetHLOpcodeName(castOp).str();
}
default:
case HLOpcodeGroup::HLCreateHandle:
case HLOpcodeGroup::HLAnnotateHandle:
return opName;
default:
return opName + GetHLFunctionAttributeMangling(attribs);
}
}

Expand Down Expand Up @@ -417,38 +424,59 @@ HLBinaryOpcode GetUnsignedOpcode(HLBinaryOpcode opcode) {
}
}

static void SetHLFunctionAttribute(Function *F, HLOpcodeGroup group,
unsigned opcode) {
F->addFnAttr(Attribute::NoUnwind);
static AttributeSet
GetHLFunctionAttributes(LLVMContext &C, FunctionType *funcTy,
const AttributeSet &origAttribs,
HLOpcodeGroup group, unsigned opcode) {
// Always add nounwind
AttributeSet attribs =
AttributeSet::get(C, AttributeSet::FunctionIndex,
ArrayRef<Attribute::AttrKind>({Attribute::NoUnwind}));

auto addAttr = [&](Attribute::AttrKind Attr) {
if (!attribs.hasAttribute(AttributeSet::FunctionIndex, Attr))
attribs = attribs.addAttribute(C, AttributeSet::FunctionIndex, Attr);
};
auto copyAttr = [&](Attribute::AttrKind Attr) {
if (origAttribs.hasAttribute(AttributeSet::FunctionIndex, Attr))
addAttr(Attr);
};
auto copyStrAttr = [&](StringRef Kind) {
if (origAttribs.hasAttribute(AttributeSet::FunctionIndex, Kind))
attribs = attribs.addAttribute(
C, AttributeSet::FunctionIndex, Kind,
origAttribs.getAttribute(AttributeSet::FunctionIndex, Kind)
.getValueAsString());
};

// Copy attributes we preserve from the original function.
copyAttr(Attribute::ReadOnly);
copyAttr(Attribute::ReadNone);
copyStrAttr(HLWaveSensitive);

switch (group) {
case HLOpcodeGroup::HLUnOp:
case HLOpcodeGroup::HLBinOp:
case HLOpcodeGroup::HLCast:
case HLOpcodeGroup::HLSubscript:
if (!F->hasFnAttribute(Attribute::ReadNone)) {
F->addFnAttr(Attribute::ReadNone);
}
addAttr(Attribute::ReadNone);
break;
case HLOpcodeGroup::HLInit:
if (!F->hasFnAttribute(Attribute::ReadNone))
if (!F->getReturnType()->isVoidTy()) {
F->addFnAttr(Attribute::ReadNone);
}
if (!funcTy->getReturnType()->isVoidTy()) {
addAttr(Attribute::ReadNone);
}
break;
case HLOpcodeGroup::HLMatLoadStore: {
HLMatLoadStoreOpcode matOp = static_cast<HLMatLoadStoreOpcode>(opcode);
if (matOp == HLMatLoadStoreOpcode::ColMatLoad ||
matOp == HLMatLoadStoreOpcode::RowMatLoad)
if (!F->hasFnAttribute(Attribute::ReadOnly)) {
F->addFnAttr(Attribute::ReadOnly);
}
addAttr(Attribute::ReadOnly);
} break;
case HLOpcodeGroup::HLCreateHandle: {
F->addFnAttr(Attribute::ReadNone);
addAttr(Attribute::ReadNone);
} break;
case HLOpcodeGroup::HLAnnotateHandle: {
F->addFnAttr(Attribute::ReadNone);
addAttr(Attribute::ReadNone);
} break;
case HLOpcodeGroup::HLIntrinsic: {
IntrinsicOp intrinsicOp = static_cast<IntrinsicOp>(opcode);
Expand All @@ -461,7 +489,7 @@ static void SetHLFunctionAttribute(Function *F, HLOpcodeGroup group,
case IntrinsicOp::IOP_GroupMemoryBarrier:
case IntrinsicOp::IOP_AllMemoryBarrierWithGroupSync:
case IntrinsicOp::IOP_AllMemoryBarrier:
F->addFnAttr(Attribute::NoDuplicate);
addAttr(Attribute::NoDuplicate);
break;
}
} break;
Expand All @@ -472,6 +500,75 @@ static void SetHLFunctionAttribute(Function *F, HLOpcodeGroup group,
// No default attributes for these opcodes.
break;
}
assert(!(attribs.hasAttribute(AttributeSet::FunctionIndex,
Attribute::ReadNone) &&
attribs.hasAttribute(AttributeSet::FunctionIndex,
Attribute::ReadOnly)) &&
"conflicting ReadNone and ReadOnly attributes");
return attribs;
}

static std::string GetHLFunctionAttributeMangling(const AttributeSet &attribs) {
std::string mangledName;
raw_string_ostream mangledNameStr(mangledName);

// Capture for adding in canonical order later.
bool ReadNone = false;
bool ReadOnly = false;
bool NoDuplicate = false;
bool WaveSensitive = false;

// Ensure every function attribute is recognized.
for (unsigned Slot = 0; Slot < attribs.getNumSlots(); Slot++) {
if (attribs.getSlotIndex(Slot) == AttributeSet::FunctionIndex) {
for (auto it = attribs.begin(Slot), e = attribs.end(Slot); it != e;
it++) {
if (it->isEnumAttribute()) {
switch (it->getKindAsEnum()) {
case Attribute::ReadNone:
ReadNone = true;
break;
case Attribute::ReadOnly:
ReadOnly = true;
break;
case Attribute::NoDuplicate:
NoDuplicate = true;
break;
case Attribute::NoUnwind:
// All intrinsics have this attribute, so mangling is unaffected.
break;
default:
assert(false && "unexpected attribute for HLOperation");
}
} else if (it->isStringAttribute()) {
StringRef Kind = it->getKindAsString();
if (Kind == HLWaveSensitive) {
assert(it->getValueAsString() == "y" &&
"otherwise, unexpected value for WaveSensitive attribute");
WaveSensitive = true;
} else {
assert(false &&
"unexpected string function attribute for HLOperation");
}
}
}
}
}

// Validate attribute combinations.
assert(!(ReadNone && ReadOnly) &&
"ReadNone and ReadOnly are mutually exclusive");

// Add mangling in canonical order
if (NoDuplicate)
mangledNameStr << "nd";
if (ReadNone)
mangledNameStr << "rn";
if (ReadOnly)
mangledNameStr << "ro";
if (WaveSensitive)
mangledNameStr << "wave";
return mangledName;
}


Expand All @@ -497,7 +594,11 @@ Function *GetOrCreateHLFunction(Module &M, FunctionType *funcTy,
Function *GetOrCreateHLFunction(Module &M, FunctionType *funcTy,
HLOpcodeGroup group, StringRef *groupName,
StringRef *fnName, unsigned opcode,
const AttributeSet &attribs) {
const AttributeSet &origAttribs) {
// Set/transfer all common attributes
AttributeSet attribs = GetHLFunctionAttributes(
M.getContext(), funcTy, origAttribs, group, opcode);

std::string mangledName;
raw_string_ostream mangledNameStr(mangledName);
if (group == HLOpcodeGroup::HLExtIntrinsic) {
Expand All @@ -506,33 +607,31 @@ Function *GetOrCreateHLFunction(Module &M, FunctionType *funcTy,
mangledNameStr << *groupName;
mangledNameStr << '.';
mangledNameStr << *fnName;
attribs = attribs.addAttribute(M.getContext(), AttributeSet::FunctionIndex,
hlsl::HLPrefix, *groupName);
}
else {
mangledNameStr << GetHLFullName(group, opcode);
// Need to add wave sensitivity to name to prevent clashes with non-wave intrinsic
if(attribs.hasAttribute(AttributeSet::FunctionIndex, HLWaveSensitive))
mangledNameStr << "wave";
mangledNameStr << GetHLFullName(group, opcode, attribs);
mangledNameStr << '.';
funcTy->print(mangledNameStr);
}

mangledNameStr.flush();

Function *F = cast<Function>(M.getOrInsertFunction(mangledName, funcTy));
if (group == HLOpcodeGroup::HLExtIntrinsic) {
F->addFnAttr(hlsl::HLPrefix, *groupName);
// Avoid getOrInsertFunction to verify attributes and type without casting.
Function *F = cast_or_null<Function>(M.getNamedValue(mangledName));
if (F) {
assert(F->getFunctionType() == funcTy &&
"otherwise, function type mismatch not captured by mangling");
// Compare attribute mangling to ensure function attributes are as expected.
assert(
GetHLFunctionAttributeMangling(F->getAttributes().getFnAttributes()) ==
GetHLFunctionAttributeMangling(attribs) &&
"otherwise, function attribute mismatch not captured by mangling");
} else {
F = cast<Function>(M.getOrInsertFunction(mangledName, funcTy, attribs));
}

SetHLFunctionAttribute(F, group, opcode);

// Copy attributes
if (attribs.hasAttribute(AttributeSet::FunctionIndex, Attribute::ReadNone))
F->addFnAttr(Attribute::ReadNone);
if (attribs.hasAttribute(AttributeSet::FunctionIndex, Attribute::ReadOnly))
F->addFnAttr(Attribute::ReadOnly);
if (attribs.hasAttribute(AttributeSet::FunctionIndex, HLWaveSensitive))
F->addFnAttr(HLWaveSensitive, "y");

return F;
}

Expand All @@ -541,15 +640,17 @@ Function *GetOrCreateHLFunction(Module &M, FunctionType *funcTy,
Function *GetOrCreateHLFunctionWithBody(Module &M, FunctionType *funcTy,
HLOpcodeGroup group, unsigned opcode,
StringRef name) {
std::string operatorName = GetHLFullName(group, opcode);
// Set/transfer all common attributes
AttributeSet attribs = GetHLFunctionAttributes(
M.getContext(), funcTy, AttributeSet(), group, opcode);

std::string operatorName = GetHLFullName(group, opcode, attribs);
std::string mangledName = operatorName + "." + name.str();
raw_string_ostream mangledNameStr(mangledName);
funcTy->print(mangledNameStr);
mangledNameStr.flush();

Function *F = cast<Function>(M.getOrInsertFunction(mangledName, funcTy));

SetHLFunctionAttribute(F, group, opcode);
Function *F = cast<Function>(M.getOrInsertFunction(mangledName, funcTy, attribs));

F->setLinkage(llvm::GlobalValue::LinkageTypes::InternalLinkage);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,25 +70,25 @@

// One HL call from each function
// 18 functions for HL lib due to entry cloning
// CHECKHLLIB: call i1 @"dx.hl.op..i1 (i32)"(i32 [[id:.*]])
// CHECKHLLIB: call i1 @"dx.hl.op..i1 (i32)"(i32 [[id]])
// CHECKHLLIB: call i1 @"dx.hl.op..i1 (i32)"(i32 [[id]])
// CHECKHLLIB: call i1 @"dx.hl.op..i1 (i32)"(i32 [[id]])
// CHECKHLLIB: call i1 @"dx.hl.op..i1 (i32)"(i32 [[id]])
// CHECKHLLIB: call i1 @"dx.hl.op..i1 (i32)"(i32 [[id]])
// CHECKHLLIB: call i1 @"dx.hl.op..i1 (i32)"(i32 [[id]])
// CHECKHLLIB: call i1 @"dx.hl.op..i1 (i32)"(i32 [[id]])
// CHECKHLLIB: call i1 @"dx.hl.op..i1 (i32)"(i32 [[id]])
// CHECKHLLIB: call i1 @"dx.hl.op..i1 (i32)"(i32 [[id]])
// CHECKHLLIB: call i1 @"dx.hl.op..i1 (i32)"(i32 [[id]])
// CHECKHLLIB: call i1 @"dx.hl.op..i1 (i32)"(i32 [[id]])
// CHECKHLLIB: call i1 @"dx.hl.op..i1 (i32)"(i32 [[id]])
// CHECKHLLIB: call i1 @"dx.hl.op..i1 (i32)"(i32 [[id]])
// CHECKHLLIB: call i1 @"dx.hl.op..i1 (i32)"(i32 [[id]])
// CHECKHLLIB: call i1 @"dx.hl.op..i1 (i32)"(i32 [[id]])
// CHECKHLLIB: call i1 @"dx.hl.op..i1 (i32)"(i32 [[id]])
// CHECKHLLIB: call i1 @"dx.hl.op..i1 (i32)"(i32 [[id]])
// CHECKHLLIB-NOT: call i1 @"dx.hl.op..i1 (i32)"(i32 [[id]])
// CHECKHLLIB: call i1 @"dx.hl.op.ro.i1 (i32)"(i32 [[id:.*]])
// CHECKHLLIB: call i1 @"dx.hl.op.ro.i1 (i32)"(i32 [[id]])
// CHECKHLLIB: call i1 @"dx.hl.op.ro.i1 (i32)"(i32 [[id]])
// CHECKHLLIB: call i1 @"dx.hl.op.ro.i1 (i32)"(i32 [[id]])
// CHECKHLLIB: call i1 @"dx.hl.op.ro.i1 (i32)"(i32 [[id]])
// CHECKHLLIB: call i1 @"dx.hl.op.ro.i1 (i32)"(i32 [[id]])
// CHECKHLLIB: call i1 @"dx.hl.op.ro.i1 (i32)"(i32 [[id]])
// CHECKHLLIB: call i1 @"dx.hl.op.ro.i1 (i32)"(i32 [[id]])
// CHECKHLLIB: call i1 @"dx.hl.op.ro.i1 (i32)"(i32 [[id]])
// CHECKHLLIB: call i1 @"dx.hl.op.ro.i1 (i32)"(i32 [[id]])
// CHECKHLLIB: call i1 @"dx.hl.op.ro.i1 (i32)"(i32 [[id]])
// CHECKHLLIB: call i1 @"dx.hl.op.ro.i1 (i32)"(i32 [[id]])
// CHECKHLLIB: call i1 @"dx.hl.op.ro.i1 (i32)"(i32 [[id]])
// CHECKHLLIB: call i1 @"dx.hl.op.ro.i1 (i32)"(i32 [[id]])
// CHECKHLLIB: call i1 @"dx.hl.op.ro.i1 (i32)"(i32 [[id]])
// CHECKHLLIB: call i1 @"dx.hl.op.ro.i1 (i32)"(i32 [[id]])
// CHECKHLLIB: call i1 @"dx.hl.op.ro.i1 (i32)"(i32 [[id]])
// CHECKHLLIB: call i1 @"dx.hl.op.ro.i1 (i32)"(i32 [[id]])
// CHECKHLLIB-NOT: call i1 @"dx.hl.op.ro.i1 (i32)"(i32 [[id]])


// CHECKGV: %[[cov:.*]] = call i32 @dx.op.coverage.i32(i32 91) ; Coverage()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ float4 main(uint i:I) : SV_Target {
// FCGL-SAME: %dx.types.ResourceProperties { i32 {{[0-9]+}}, i32 {{[0-9]+}} },
// FCGL-SAME: %"class.Buffer<vector<float, 4> >" undef)

// FCGL: {{%.+}} = call <4 x float>* @"dx.hl.subscript.[].<4 x float>* (i32, %dx.types.Handle, i32)"
// FCGL: {{%.+}} = call <4 x float>* @"dx.hl.subscript.[].rn.<4 x float>* (i32, %dx.types.Handle, i32)"
// FCGL-SAME: (i32 0, %dx.types.Handle [[AnnHandle]], i32 2)


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// RUN: %dxc -T cs_6_5 -E CS -fcgl %s | FileCheck %s
// RUN: %dxc -T cs_6_5 -E CS %s | FileCheck %s -check-prefix=CHECKDXIL

// Proceed called before CommittedTriangleFrontFace.
// Don't be sensitive to HL Opcode because those can change.
// CHECK: call i1 [[HLProceed:@"[^"]+"]](i32
// CHECK: call i1 [[HLCommittedTriangleFrontFace:@"[^".]+\.[^.]+\.[^.]+\.ro[^"]+"]](i32
// ^ matches call i1 @"dx.hl.op.ro.i1 (i32, %\22class.RayQuery<5>\22*)"(i32
// CHECK-LABEL: ret void,

// Ensure HL declarations are not collapsed when attributes differ
// CHECK-DAG: declare i1 [[HLProceed]]({{.*}}) #[[AttrProceed:[0-9]+]]
// CHECK-DAG: declare i1 [[HLCommittedTriangleFrontFace]]({{.*}}) #[[AttrCommittedTriangleFrontFace:[0-9]+]]

// Ensure correct attributes for each HL intrinsic
// CHECK-DAG: attributes #[[AttrProceed]] = { nounwind }
// CHECK-DAG: attributes #[[AttrCommittedTriangleFrontFace]] = { nounwind readonly }

// Ensure Proceed not eliminated in final DXIL:
// CHECKDXIL: call i1 @dx.op.rayQuery_Proceed.i1(i32 180,
// CHECKDXIL: call i1 @dx.op.rayQuery_StateScalar.i1(i32 192,

RaytracingAccelerationStructure AccelerationStructure : register(t0);
RWByteAddressBuffer log : register(u0);

[numThreads(1,1,1)]
void CS()
{
RayQuery<RAY_FLAG_FORCE_OPAQUE|RAY_FLAG_ACCEPT_FIRST_HIT_AND_END_SEARCH> q;
RayDesc ray = { {0.0f, 0.0f, 0.0f}, {1.0f, 0.0f, 0.0f}, 0.0f, 9999.0f};
q.TraceRayInline(AccelerationStructure, RAY_FLAG_NONE, 0xFF, ray);

q.Proceed();

if(q.CommittedTriangleFrontFace())
{
log.Store(0,1);
}
}
Loading

0 comments on commit d9c07e9

Please sign in to comment.