Skip to content

Commit

Permalink
Add alignment to NodeRecordType including DXIL metadata update (micro…
Browse files Browse the repository at this point in the history
…soft#6279)

This change adds NodeRecordType alignment field to RDAT to make it
possible to validate pointer and stride alignment in the runtime.

This includes a change to DXIL metadata to preserve the record alignment
without requiring recovery by looking for GetNodeRecordPtr.

Fixes microsoft#6270
  • Loading branch information
tex3d authored Feb 29, 2024
1 parent bbcbb2d commit 66ba5a1
Show file tree
Hide file tree
Showing 18 changed files with 333 additions and 118 deletions.
2 changes: 2 additions & 0 deletions include/dxc/DXIL/DxilMetadataHelper.h
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,7 @@ class DxilMDHelper {
// Node Record Type
static const unsigned kDxilNodeRecordSizeTag = 0;
static const unsigned kDxilNodeSVDispatchGridTag = 1;
static const unsigned kDxilNodeRecordAlignmentTag = 2;

// GSState.
static const unsigned kDxilGSStateNumFields = 5;
Expand Down Expand Up @@ -624,6 +625,7 @@ class DxilMDHelper {
unsigned &payloadSizeInBytes);

llvm::MDTuple *EmitDxilNodeIOState(const NodeIOProperties &Node);
llvm::MDTuple *EmitDxilNodeRecordType(const NodeRecordType &RecordType);
hlsl::NodeIOProperties LoadDxilNodeIOState(const llvm::MDOperand &MDO);
hlsl::NodeRecordType LoadDxilNodeRecordType(const llvm::MDOperand &MDO);

Expand Down
1 change: 1 addition & 0 deletions include/dxc/DXIL/DxilNodeProps.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ struct SVDispatchGrid {
//
struct NodeRecordType {
unsigned size;
unsigned alignment;
SVDispatchGrid SV_DispatchGrid;
};

Expand Down
5 changes: 5 additions & 0 deletions include/dxc/DxilContainer/RDAT_LibraryTypes.inl
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,7 @@ RDAT_ENUM_START(NodeAttribKind, uint32_t)
RDAT_ENUM_VALUE(RecordDispatchGrid, 5)
RDAT_ENUM_VALUE(OutputArraySize, 6)
RDAT_ENUM_VALUE(AllowSparseNodes, 7)
RDAT_ENUM_VALUE(RecordAlignmentInBytes, 8)
RDAT_ENUM_VALUE_NODEF(LastValue)
RDAT_ENUM_END()

Expand Down Expand Up @@ -407,6 +408,10 @@ RDAT_STRUCT_TABLE(NodeShaderIOAttrib, NodeShaderIOAttribTable)
getAttribKind() ==
hlsl::RDAT::NodeAttribKind::AllowSparseNodes)
RDAT_VALUE(uint32_t, AllowSparseNodes)
RDAT_UNION_ELIF(RecordAlignmentInBytes,
getAttribKind() ==
hlsl::RDAT::NodeAttribKind::RecordAlignmentInBytes)
RDAT_VALUE(uint32_t, RecordAlignmentInBytes)
RDAT_UNION_ENDIF()
RDAT_UNION_END()
RDAT_STRUCT_END()
Expand Down
58 changes: 40 additions & 18 deletions lib/DXIL/DxilMetadataHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1967,6 +1967,7 @@ void DxilMDHelper::SerializeNodeProps(SmallVectorImpl<llvm::Metadata *> &MDVals,
nodeinput.RecordType.SV_DispatchGrid.ComponentType)));
MDVals.push_back(
Uint32ToConstMD(nodeinput.RecordType.SV_DispatchGrid.NumComponents));
MDVals.push_back(Uint32ToConstMD(nodeinput.RecordType.alignment));
}
for (auto &nodeoutput : props->OutputNodes) {
MDVals.push_back(Uint32ToConstMD(nodeoutput.Flags));
Expand All @@ -1983,6 +1984,7 @@ void DxilMDHelper::SerializeNodeProps(SmallVectorImpl<llvm::Metadata *> &MDVals,
MDVals.push_back(Int32ToConstMD(nodeoutput.MaxRecordsSharedWith));
MDVals.push_back(Uint32ToConstMD(nodeoutput.OutputArraySize));
MDVals.push_back(BoolToConstMD(nodeoutput.AllowSparseNodes));
MDVals.push_back(Uint32ToConstMD(nodeoutput.RecordType.alignment));
}
}

Expand Down Expand Up @@ -2019,6 +2021,10 @@ void DxilMDHelper::DeserializeNodeProps(const MDTuple *pProps, unsigned &idx,
ConstMDToUint32(pProps->getOperand(idx++)));
nodeinput.RecordType.SV_DispatchGrid.NumComponents =
ConstMDToUint32(pProps->getOperand(idx++));
if (pProps->getNumOperands() > idx) {
nodeinput.RecordType.alignment =
ConstMDToUint32(pProps->getOperand(idx++));
}
}

for (auto &nodeoutput : props->OutputNodes) {
Expand All @@ -2037,6 +2043,10 @@ void DxilMDHelper::DeserializeNodeProps(const MDTuple *pProps, unsigned &idx,
nodeoutput.MaxRecordsSharedWith = ConstMDToInt32(pProps->getOperand(idx++));
nodeoutput.OutputArraySize = ConstMDToUint32(pProps->getOperand(idx++));
nodeoutput.AllowSparseNodes = ConstMDToBool(pProps->getOperand(idx++));
if (pProps->getNumOperands() > idx) {
nodeoutput.RecordType.alignment =
ConstMDToUint32(pProps->getOperand(idx++));
}
}
}

Expand Down Expand Up @@ -2755,6 +2765,32 @@ void DxilMDHelper::EmitDxilNodeState(std::vector<llvm::Metadata *> &MDVals,
}
}

llvm::MDTuple *
DxilMDHelper::EmitDxilNodeRecordType(const NodeRecordType &RecordType) {
vector<Metadata *> MDVals;
MDVals.emplace_back(Uint32ToConstMD(DxilMDHelper::kDxilNodeRecordSizeTag));
MDVals.emplace_back(Uint32ToConstMD(RecordType.size));

if (RecordType.SV_DispatchGrid.NumComponents) {
MDVals.emplace_back(
Uint32ToConstMD(DxilMDHelper::kDxilNodeSVDispatchGridTag));
vector<Metadata *> SVDispatchGridVals;
SVDispatchGridVals.emplace_back(
Uint32ToConstMD(RecordType.SV_DispatchGrid.ByteOffset));
SVDispatchGridVals.emplace_back(Uint32ToConstMD(
static_cast<unsigned>(RecordType.SV_DispatchGrid.ComponentType)));
SVDispatchGridVals.emplace_back(
Uint32ToConstMD(RecordType.SV_DispatchGrid.NumComponents));
MDVals.emplace_back(MDNode::get(m_Ctx, SVDispatchGridVals));
}
if (RecordType.alignment) {
MDVals.emplace_back(
Uint32ToConstMD(DxilMDHelper::kDxilNodeRecordAlignmentTag));
MDVals.emplace_back(Uint32ToConstMD(RecordType.alignment));
}
return MDNode::get(m_Ctx, MDVals);
}

llvm::MDTuple *
DxilMDHelper::EmitDxilNodeIOState(const hlsl::NodeIOProperties &Node) {
vector<Metadata *> MDVals;
Expand All @@ -2763,24 +2799,7 @@ DxilMDHelper::EmitDxilNodeIOState(const hlsl::NodeIOProperties &Node) {

if (Node.RecordType.size) {
MDVals.emplace_back(Uint32ToConstMD(DxilMDHelper::kDxilNodeRecordTypeTag));
vector<Metadata *> NodeRecordTypeVals;
NodeRecordTypeVals.emplace_back(
Uint32ToConstMD(DxilMDHelper::kDxilNodeRecordSizeTag));
NodeRecordTypeVals.emplace_back(Uint32ToConstMD(Node.RecordType.size));
// If the record has a SV_DispatchGrid field
if (Node.RecordType.SV_DispatchGrid.NumComponents) {
NodeRecordTypeVals.emplace_back(
Uint32ToConstMD(DxilMDHelper::kDxilNodeSVDispatchGridTag));
vector<Metadata *> SVDispatchGridVals;
SVDispatchGridVals.emplace_back(
Uint32ToConstMD(Node.RecordType.SV_DispatchGrid.ByteOffset));
SVDispatchGridVals.emplace_back(Uint32ToConstMD(static_cast<unsigned>(
Node.RecordType.SV_DispatchGrid.ComponentType)));
SVDispatchGridVals.emplace_back(
Uint32ToConstMD(Node.RecordType.SV_DispatchGrid.NumComponents));
NodeRecordTypeVals.emplace_back(MDNode::get(m_Ctx, SVDispatchGridVals));
}
MDVals.emplace_back(MDNode::get(m_Ctx, NodeRecordTypeVals));
MDVals.emplace_back(EmitDxilNodeRecordType(Node.RecordType));
}

if (Node.Flags.IsOutputNode()) {
Expand Down Expand Up @@ -2856,6 +2875,9 @@ DxilMDHelper::LoadDxilNodeRecordType(const llvm::MDOperand &MDO) {
Record.SV_DispatchGrid.NumComponents =
ConstMDToUint32(pSVDTupleMD->getOperand(2));
} break;
case DxilMDHelper::kDxilNodeRecordAlignmentTag: {
Record.alignment = ConstMDToUint32(MDO);
} break;
default:
m_bExtraMetadata = true;
break;
Expand Down
7 changes: 7 additions & 0 deletions lib/DxilContainer/DxilContainerAssembler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1438,6 +1438,13 @@ class DxilRDATWriter : public DxilPartWriter {
N.RecordType.SV_DispatchGrid.NumComponents);
nodeAttribs.push_back(Builder.InsertRecord(nAttrib));
}

if (N.RecordType.alignment) {
nAttrib = {};
nAttrib.AttribKind = (uint32_t)NodeAttribKind::RecordAlignmentInBytes;
nAttrib.RecordAlignmentInBytes = N.RecordType.alignment;
nodeAttribs.push_back(Builder.InsertRecord(nAttrib));
}
}

ioNode.Attribs =
Expand Down
4 changes: 3 additions & 1 deletion tools/clang/lib/CodeGen/CGHLSLMS.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2620,8 +2620,10 @@ void CGMSHLSLRuntime::AddHLSLNodeRecordTypeInfo(
}

// Ex: For DispatchNodeInputRecord<MY_RECORD>, set size =
// size(MY_RECORD)
// size(MY_RECORD), alignment = alignof(MY_RECORD)
node.RecordType.size = CGM.getDataLayout().getTypeAllocSize(Type);
node.RecordType.alignment =
CGM.getDataLayout().getABITypeAlignment(Type);
// Iterate over fields of the MY_RECORD(example) struct
for (auto fieldDecl : RD->fields()) {
// Check if any of the fields have a semantic annotation =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
// CHECK: ![[NodeDispatchGrid]] = !{i32 1, i32 1, i32 1}
// CHECK: ![[NodeInputs]] = !{![[Input0:[0-9]+]]}
// CHECK: ![[Input0]] = !{i32 1, i32 97, i32 2, ![[NodeRecordType:[0-9]+]]}
// CHECK: ![[NodeRecordType]] = !{i32 0, i32 68}
// CHECK: ![[NodeRecordType]] = !{i32 0, i32 68, i32 2, i32 4}
// CHECK: ![[NumThreads]] = !{i32 32, i32 1, i32 1}
// CHECK: ![[AutoBindingSpace]] = !{i32 0}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
// CHECK: ![[NodeDispatchGrid]] = !{i32 1, i32 1, i32 1}
// CHECK: ![[NodeInputs]] = !{![[Input0:[0-9]+]]}
// CHECK: ![[Input0]] = !{i32 1, i32 97, i32 2, ![[NodeRecordType:[0-9]+]]}
// CHECK: ![[NodeRecordType]] = !{i32 0, i32 68}
// CHECK: ![[NodeRecordType]] = !{i32 0, i32 68, i32 2, i32 4}
// CHECK: ![[NumThreads]] = !{i32 32, i32 1, i32 1}
// CHECK: ![[AutoBindingSpace]] = !{i32 0}

Expand Down
Loading

0 comments on commit 66ba5a1

Please sign in to comment.