Skip to content

Commit

Permalink
dxcopt: Support full container and restore extra data to module (micr…
Browse files Browse the repository at this point in the history
…osoft#4845)

This modifies IDxcOptimizer::RunOptimizier to accept full DxilContainer input. When full container input is used, this restores some data that is stripped from the module and placed in various other container parts.

Data restored:
  - Subobjects from RDAT
  - RootSignature from RTS0
  - ViewID and I/O dependency data from PSV0
  - Resource names and types/annotations from STAT

Serialization of these to metadata in module bitcode output still requires hlsl-dxilemit step.
  • Loading branch information
tex3d authored Dec 13, 2022
1 parent 21cf36a commit 2c3d965
Show file tree
Hide file tree
Showing 6 changed files with 1,597 additions and 68 deletions.
3 changes: 3 additions & 0 deletions include/dxc/DXIL/DxilModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,9 @@ class DxilModule {
void StripDebugRelatedCode();
void RemoveUnusedTypeAnnotations();

// Copy resource reflection back to this module's resources.
void RestoreResourceReflection(const DxilModule &SourceDM);

// Helper to remove dx.* metadata with source and compile options.
// If the parameter `bReplaceWithDummyData` is true, the named metadata
// are replaced with valid empty data that satisfy tools.
Expand Down
11 changes: 11 additions & 0 deletions include/dxc/DxilContainer/DxilContainerAssembler.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "llvm/ADT/StringRef.h"

struct IStream;
class DxilPipelineStateValidation;

namespace llvm {
class Module;
Expand Down Expand Up @@ -51,6 +52,16 @@ DxilPartWriter *NewFeatureInfoWriter(const DxilModule &M);
DxilPartWriter *NewPSVWriter(const DxilModule &M, uint32_t PSVVersion = UINT_MAX);
DxilPartWriter *NewRDATWriter(const DxilModule &M);

// Store serialized ViewID data from DxilModule to PipelineStateValidation.
void StoreViewIDStateToPSV(const uint32_t *pInputData,
unsigned InputSizeInUInts,
DxilPipelineStateValidation &PSV);
// Load ViewID state from PSV back to DxilModule view state vector.
// Pass nullptr for pOutputData to compute and return needed OutputSizeInUInts.
unsigned LoadViewIDStateFromPSV(unsigned *pOutputData,
unsigned OutputSizeInUInts,
const DxilPipelineStateValidation &PSV);

// Unaligned is for matching container for validator version < 1.7.
DxilContainerWriter *NewDxilContainerWriter(bool bUnaligned = false);

Expand Down
56 changes: 56 additions & 0 deletions lib/DXIL/DxilModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1873,6 +1873,62 @@ void DxilModule::RemoveUnusedTypeAnnotations() {
}


template <typename _T>
static void CopyResourceInfo(_T &TargetRes, const _T &SourceRes,
DxilTypeSystem &TargetTypeSys,
const DxilTypeSystem &SourceTypeSys) {
if (TargetRes.GetKind() != SourceRes.GetKind() ||
TargetRes.GetLowerBound() != SourceRes.GetLowerBound() ||
TargetRes.GetRangeSize() != SourceRes.GetRangeSize() ||
TargetRes.GetSpaceID() != SourceRes.GetSpaceID()) {
DXASSERT(false, "otherwise, resource details don't match");
return;
}

if (TargetRes.GetGlobalName().empty() && !SourceRes.GetGlobalName().empty()) {
TargetRes.SetGlobalName(SourceRes.GetGlobalName());
}

if (TargetRes.GetGlobalSymbol() && SourceRes.GetGlobalSymbol() &&
SourceRes.GetGlobalSymbol()->hasName()) {
TargetRes.GetGlobalSymbol()->setName(
SourceRes.GetGlobalSymbol()->getName());
}

Type *Ty = SourceRes.GetHLSLType();
TargetRes.SetHLSLType(Ty);
TargetTypeSys.CopyTypeAnnotation(Ty, SourceTypeSys);
}

void DxilModule::RestoreResourceReflection(const DxilModule &SourceDM) {
DxilTypeSystem &TargetTypeSys = GetTypeSystem();
const DxilTypeSystem &SourceTypeSys = SourceDM.GetTypeSystem();
if (GetCBuffers().size() != SourceDM.GetCBuffers().size() ||
GetSRVs().size() != SourceDM.GetSRVs().size() ||
GetUAVs().size() != SourceDM.GetUAVs().size() ||
GetSamplers().size() != SourceDM.GetSamplers().size()) {
DXASSERT(false, "otherwise, resource lists don't match");
return;
}
for (unsigned i = 0; i < GetCBuffers().size(); ++i) {
CopyResourceInfo(GetCBuffer(i), SourceDM.GetCBuffer(i), TargetTypeSys,
SourceTypeSys);
}
for (unsigned i = 0; i < GetSRVs().size(); ++i) {
CopyResourceInfo(GetSRV(i), SourceDM.GetSRV(i), TargetTypeSys,
SourceTypeSys);
}
for (unsigned i = 0; i < GetUAVs().size(); ++i) {
CopyResourceInfo(GetUAV(i), SourceDM.GetUAV(i), TargetTypeSys,
SourceTypeSys);
}
for (unsigned i = 0; i < GetSamplers().size(); ++i) {
CopyResourceInfo(GetSampler(i), SourceDM.GetSampler(i), TargetTypeSys,
SourceTypeSys);
}
}


void DxilModule::LoadDxilResources(const llvm::MDOperand &MDO) {
if (MDO.get() == nullptr)
return;
Expand Down
208 changes: 175 additions & 33 deletions lib/DxilContainer/DxilContainerAssembler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,180 @@ DxilPartWriter *hlsl::NewFeatureInfoWriter(const DxilModule &M) {
return new DxilFeatureInfoWriter(M);
}


//////////////////////////////////////////////////////////
// Utility code for serializing/deserializing ViewID state

// Code for ComputeSeriaizedViewIDStateSizeInUInts copied from
// ComputeViewIdState. It could be moved into some common location if this
// ViewID serialization/deserialization code were moved out of here.
static unsigned RoundUpToUINT(unsigned x) { return (x + 31) / 32; }
static unsigned ComputeSeriaizedViewIDStateSizeInUInts(
const PSVShaderKind SK, const bool bUsesViewID,
const unsigned InputScalars, const unsigned OutputScalars[4],
const unsigned PCScalars) {
// Compute serialized state size in UINTs.
unsigned NumStreams = SK == PSVShaderKind::Geometry ? 4 : 1;
unsigned Size = 0;
Size += 1; // #Inputs.
for (unsigned StreamId = 0; StreamId < NumStreams; StreamId++) {
Size += 1; // #Outputs for stream StreamId.
unsigned NumOutputs = OutputScalars[StreamId];
unsigned NumOutUINTs = RoundUpToUINT(NumOutputs);
if (bUsesViewID) {
Size += NumOutUINTs; // m_OutputsDependentOnViewId[StreamId]
}
Size += InputScalars * NumOutUINTs; // m_InputsContributingToOutputs[StreamId]
}
if (SK == PSVShaderKind::Hull || SK == PSVShaderKind::Domain || SK == PSVShaderKind::Mesh) {
Size += 1; // #PatchConstant.
unsigned NumPCUINTs = RoundUpToUINT(PCScalars);
if (SK == PSVShaderKind::Hull || SK == PSVShaderKind::Mesh) {
if (bUsesViewID) {
Size += NumPCUINTs; // m_PCOrPrimOutputsDependentOnViewId
}
Size += InputScalars * NumPCUINTs; // m_InputsContributingToPCOrPrimOutputs
} else {
unsigned NumOutputs = OutputScalars[0];
unsigned NumOutUINTs = RoundUpToUINT(NumOutputs);
Size += PCScalars * NumOutUINTs; // m_PCInputsContributingToOutputs
}
}
return Size;
}

static const uint32_t *CopyViewIDStateForOutputToPSV(
const uint32_t *pSrc, uint32_t InputScalars, uint32_t OutputScalars,
PSVComponentMask ViewIDMask, PSVDependencyTable IOTable) {
unsigned MaskDwords = PSVComputeMaskDwordsFromVectors(PSVALIGN4(OutputScalars) / 4);
if (ViewIDMask.IsValid()) {
DXASSERT_NOMSG(!IOTable.Table || ViewIDMask.NumVectors == IOTable.OutputVectors);
memcpy(ViewIDMask.Mask, pSrc, 4 * MaskDwords);
pSrc += MaskDwords;
}
if (IOTable.IsValid() && IOTable.InputVectors && IOTable.OutputVectors) {
DXASSERT_NOMSG((InputScalars <= IOTable.InputVectors * 4) && (IOTable.InputVectors * 4 - InputScalars < 4));
DXASSERT_NOMSG((OutputScalars <= IOTable.OutputVectors * 4) && (IOTable.OutputVectors * 4 - OutputScalars < 4));
memcpy(IOTable.Table, pSrc, 4 * MaskDwords * InputScalars);
pSrc += MaskDwords * InputScalars;
}
return pSrc;
}

static uint32_t *CopyViewIDStateForOutputFromPSV(uint32_t *pOutputData,
const unsigned InputScalars,
const unsigned OutputScalars,
PSVComponentMask ViewIDMask,
PSVDependencyTable IOTable) {
unsigned MaskDwords = PSVComputeMaskDwordsFromVectors(PSVALIGN4(OutputScalars) / 4);
if (ViewIDMask.IsValid()) {
DXASSERT_NOMSG(!IOTable.Table || ViewIDMask.NumVectors == IOTable.OutputVectors);
for (unsigned i = 0; i < MaskDwords; i++)
*(pOutputData++) = ViewIDMask.Mask[i];
}
if (IOTable.IsValid() && IOTable.InputVectors && IOTable.OutputVectors) {
DXASSERT_NOMSG((InputScalars <= IOTable.InputVectors * 4) && (IOTable.InputVectors * 4 - InputScalars < 4));
DXASSERT_NOMSG((OutputScalars <= IOTable.OutputVectors * 4) && (IOTable.OutputVectors * 4 - OutputScalars < 4));
for (unsigned i = 0; i < MaskDwords * InputScalars; i++)
*(pOutputData++) = IOTable.Table[i];
}
return pOutputData;
}

void hlsl::StoreViewIDStateToPSV(const uint32_t *pInputData,
unsigned InputSizeInUInts,
DxilPipelineStateValidation &PSV) {
PSVRuntimeInfo1 *pInfo1 = PSV.GetPSVRuntimeInfo1();
DXASSERT(pInfo1, "otherwise, PSV does not meet version requirement.");
PSVShaderKind SK = static_cast<PSVShaderKind>(pInfo1->ShaderStage);
const unsigned OutputStreams = SK == PSVShaderKind::Geometry ? 4 : 1;
const uint32_t *pSrc = pInputData;
const uint32_t InputScalars = *(pSrc++);
uint32_t OutputScalars[4];
for (unsigned streamIndex = 0; streamIndex < OutputStreams; streamIndex++) {
OutputScalars[streamIndex] = *(pSrc++);
pSrc = CopyViewIDStateForOutputToPSV(
pSrc, InputScalars, OutputScalars[streamIndex],
PSV.GetViewIDOutputMask(streamIndex),
PSV.GetInputToOutputTable(streamIndex));
}
if (SK == PSVShaderKind::Hull || SK == PSVShaderKind::Mesh) {
const uint32_t PCScalars = *(pSrc++);
pSrc = CopyViewIDStateForOutputToPSV(pSrc, InputScalars, PCScalars,
PSV.GetViewIDPCOutputMask(),
PSV.GetInputToPCOutputTable());
} else if (SK == PSVShaderKind::Domain) {
const uint32_t PCScalars = *(pSrc++);
pSrc = CopyViewIDStateForOutputToPSV(pSrc, PCScalars, OutputScalars[0],
PSVComponentMask(),
PSV.GetPCInputToOutputTable());
}
DXASSERT(pSrc - pInputData == InputSizeInUInts,
"otherwise, different amout of data written than expected.");
}

// This function is defined close to the serialization code in DxilPSVWriter to
// reduce the chance of a mismatch. It could be defined elsewhere, but it would
// make sense to move both the serialization and deserialization out of here and
// into a common location.
unsigned hlsl::LoadViewIDStateFromPSV(unsigned *pOutputData,
unsigned OutputSizeInUInts,
const DxilPipelineStateValidation &PSV) {
PSVRuntimeInfo1 *pInfo1 = PSV.GetPSVRuntimeInfo1();
if (!pInfo1) {
return 0;
}
PSVShaderKind SK = static_cast<PSVShaderKind>(pInfo1->ShaderStage);
const unsigned OutputStreams = SK == PSVShaderKind::Geometry ? 4 : 1;
const unsigned InputScalars = pInfo1->SigInputVectors * 4;
unsigned OutputScalars[4];
for (unsigned streamIndex = 0; streamIndex < OutputStreams; streamIndex++) {
OutputScalars[streamIndex] = pInfo1->SigOutputVectors[streamIndex] * 4;
}
unsigned PCScalars = 0;
if (SK == PSVShaderKind::Hull || SK == PSVShaderKind::Mesh ||
SK == PSVShaderKind::Domain) {
PCScalars = pInfo1->SigPatchConstOrPrimVectors * 4;
}
if (pOutputData == nullptr) {
return ComputeSeriaizedViewIDStateSizeInUInts(
SK, pInfo1->UsesViewID != 0, InputScalars, OutputScalars, PCScalars);
}

// Fill in serialized viewid buffer.
DXASSERT(ComputeSeriaizedViewIDStateSizeInUInts(
SK, pInfo1->UsesViewID != 0, InputScalars, OutputScalars,
PCScalars) == OutputSizeInUInts,
"otherwise, OutputSize doesn't match computed size.");
unsigned *pStartOutputData = pOutputData;
*(pOutputData++) = InputScalars;
for (unsigned streamIndex = 0; streamIndex < OutputStreams; streamIndex++) {
*(pOutputData++) = OutputScalars[streamIndex];
pOutputData = CopyViewIDStateForOutputFromPSV(
pOutputData, InputScalars, OutputScalars[streamIndex],
PSV.GetViewIDOutputMask(streamIndex),
PSV.GetInputToOutputTable(streamIndex));
}
if (SK == PSVShaderKind::Hull || SK == PSVShaderKind::Mesh) {
*(pOutputData++) = PCScalars;
pOutputData = CopyViewIDStateForOutputFromPSV(
pOutputData, InputScalars, PCScalars, PSV.GetViewIDPCOutputMask(),
PSV.GetInputToPCOutputTable());
} else if (SK == PSVShaderKind::Domain) {
*(pOutputData++) = PCScalars;
pOutputData = CopyViewIDStateForOutputFromPSV(
pOutputData, PCScalars, OutputScalars[0], PSVComponentMask(),
PSV.GetPCInputToOutputTable());
}
DXASSERT(pOutputData - pStartOutputData == OutputSizeInUInts,
"otherwise, OutputSizeInUInts didn't match size written.");
return pOutputData - pStartOutputData;
}


//////////////////////////////////////////////////////////
// DxilPSVWriter - Writes PSV0 part

class DxilPSVWriter : public DxilPartWriter {
private:
const DxilModule &m_Module;
Expand Down Expand Up @@ -509,22 +683,6 @@ class DxilPSVWriter : public DxilPartWriter {
E.DynamicMaskAndStream |= (SE.GetDynIdxCompMask()) & 0xF;
}

const uint32_t *CopyViewIDState(const uint32_t *pSrc, uint32_t InputScalars, uint32_t OutputScalars, PSVComponentMask ViewIDMask, PSVDependencyTable IOTable) {
unsigned MaskDwords = PSVComputeMaskDwordsFromVectors(PSVALIGN4(OutputScalars) / 4);
if (ViewIDMask.IsValid()) {
DXASSERT_NOMSG(!IOTable.Table || ViewIDMask.NumVectors == IOTable.OutputVectors);
memcpy(ViewIDMask.Mask, pSrc, 4 * MaskDwords);
pSrc += MaskDwords;
}
if (IOTable.IsValid() && IOTable.InputVectors && IOTable.OutputVectors) {
DXASSERT_NOMSG((InputScalars <= IOTable.InputVectors * 4) && (IOTable.InputVectors * 4 - InputScalars < 4));
DXASSERT_NOMSG((OutputScalars <= IOTable.OutputVectors * 4) && (IOTable.OutputVectors * 4 - OutputScalars < 4));
memcpy(IOTable.Table, pSrc, 4 * MaskDwords * InputScalars);
pSrc += MaskDwords * InputScalars;
}
return pSrc;
}

public:
DxilPSVWriter(const DxilModule &mod, uint32_t PSVVersion = UINT_MAX)
: m_Module(mod),
Expand Down Expand Up @@ -840,23 +998,7 @@ class DxilPSVWriter : public DxilPartWriter {
// Gather ViewID dependency information
auto &viewState = m_Module.GetSerializedViewIdState();
if (!viewState.empty()) {
const uint32_t *pSrc = viewState.data();
const uint32_t InputScalars = *(pSrc++);
uint32_t OutputScalars[4];
for (unsigned streamIndex = 0; streamIndex < 4; streamIndex++) {
OutputScalars[streamIndex] = *(pSrc++);
pSrc = CopyViewIDState(pSrc, InputScalars, OutputScalars[streamIndex], m_PSV.GetViewIDOutputMask(streamIndex), m_PSV.GetInputToOutputTable(streamIndex));
if (!SM->IsGS())
break;
}
if (SM->IsHS() || SM->IsMS()) {
const uint32_t PCScalars = *(pSrc++);
pSrc = CopyViewIDState(pSrc, InputScalars, PCScalars, m_PSV.GetViewIDPCOutputMask(), m_PSV.GetInputToPCOutputTable());
} else if (SM->IsDS()) {
const uint32_t PCScalars = *(pSrc++);
pSrc = CopyViewIDState(pSrc, PCScalars, OutputScalars[0], PSVComponentMask(), m_PSV.GetPCInputToOutputTable());
}
DXASSERT_NOMSG(viewState.data() + viewState.size() == pSrc);
StoreViewIDStateToPSV(viewState.data(), (unsigned)viewState.size(), m_PSV);
}
}

Expand Down
Loading

0 comments on commit 2c3d965

Please sign in to comment.