Skip to content

[DirectX] Split resource info into type and binding info. NFC #119773

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
207 changes: 149 additions & 58 deletions llvm/include/llvm/Analysis/DXILResource.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ class LLVMContext;
class MDTuple;
class Value;

class DXILResourceTypeMap;

namespace dxil {

/// The dx.RawBuffer target extension type
Expand Down Expand Up @@ -197,27 +199,8 @@ class SamplerExtType : public TargetExtType {

//===----------------------------------------------------------------------===//

class ResourceInfo {
class ResourceTypeInfo {
public:
struct ResourceBinding {
uint32_t RecordID;
uint32_t Space;
uint32_t LowerBound;
uint32_t Size;

bool operator==(const ResourceBinding &RHS) const {
return std::tie(RecordID, Space, LowerBound, Size) ==
std::tie(RHS.RecordID, RHS.Space, RHS.LowerBound, RHS.Size);
}
bool operator!=(const ResourceBinding &RHS) const {
return !(*this == RHS);
}
bool operator<(const ResourceBinding &RHS) const {
return std::tie(RecordID, Space, LowerBound, Size) <
std::tie(RHS.RecordID, RHS.Space, RHS.LowerBound, RHS.Size);
}
};

struct UAVInfo {
bool GloballyCoherent;
bool HasCounter;
Expand Down Expand Up @@ -267,22 +250,25 @@ class ResourceInfo {
};

private:
ResourceBinding Binding;
TargetExtType *HandleTy;

// GloballyCoherent and HasCounter aren't really part of the type and need to
// be determined by analysis, so they're just provided directly when we
// construct these.
// be determined by analysis, so they're just provided directly by the
// DXILResourceTypeMap when we construct these.
bool GloballyCoherent;
bool HasCounter;

dxil::ResourceClass RC;
dxil::ResourceKind Kind;

public:
ResourceInfo(uint32_t RecordID, uint32_t Space, uint32_t LowerBound,
uint32_t Size, TargetExtType *HandleTy,
bool GloballyCoherent = false, bool HasCounter = false);
ResourceTypeInfo(TargetExtType *HandleTy, const dxil::ResourceClass RC,
const dxil::ResourceKind Kind, bool GloballyCoherent = false,
bool HasCounter = false);
ResourceTypeInfo(TargetExtType *HandleTy, bool GloballyCoherent = false,
bool HasCounter = false)
: ResourceTypeInfo(HandleTy, {}, dxil::ResourceKind::Invalid,
GloballyCoherent, HasCounter) {}

TargetExtType *getHandleTy() const { return HandleTy; }

Expand All @@ -304,44 +290,145 @@ class ResourceInfo {
dxil::SamplerFeedbackType getFeedbackType() const;
uint32_t getMultiSampleCount() const;

StringRef getName() const {
// TODO: Get the name from the symbol once we include one here.
return "";
}
dxil::ResourceClass getResourceClass() const { return RC; }
dxil::ResourceKind getResourceKind() const { return Kind; }

void setGloballyCoherent(bool V) { GloballyCoherent = V; }
void setHasCounter(bool V) { HasCounter = V; }

bool operator==(const ResourceTypeInfo &RHS) const;
bool operator!=(const ResourceTypeInfo &RHS) const { return !(*this == RHS); }
bool operator<(const ResourceTypeInfo &RHS) const;

void print(raw_ostream &OS, const DataLayout &DL) const;
};

//===----------------------------------------------------------------------===//

class ResourceBindingInfo {
public:
struct ResourceBinding {
uint32_t RecordID;
uint32_t Space;
uint32_t LowerBound;
uint32_t Size;

bool operator==(const ResourceBinding &RHS) const {
return std::tie(RecordID, Space, LowerBound, Size) ==
std::tie(RHS.RecordID, RHS.Space, RHS.LowerBound, RHS.Size);
}
bool operator!=(const ResourceBinding &RHS) const {
return !(*this == RHS);
}
bool operator<(const ResourceBinding &RHS) const {
return std::tie(RecordID, Space, LowerBound, Size) <
std::tie(RHS.RecordID, RHS.Space, RHS.LowerBound, RHS.Size);
}
};

private:
ResourceBinding Binding;
TargetExtType *HandleTy;

public:
ResourceBindingInfo(uint32_t RecordID, uint32_t Space, uint32_t LowerBound,
uint32_t Size, TargetExtType *HandleTy)
: Binding{RecordID, Space, LowerBound, Size}, HandleTy(HandleTy) {}

void setBindingID(unsigned ID) { Binding.RecordID = ID; }

const ResourceBinding &getBinding() const { return Binding; }
TargetExtType *getHandleTy() const { return HandleTy; }
const StringRef getName() const {
// TODO: Get the name from the symbol once we include one here.
return "";
}

MDTuple *getAsMetadata(Module &M) const;
std::pair<uint32_t, uint32_t> getAnnotateProps(Module &M) const;
MDTuple *getAsMetadata(Module &M, dxil::ResourceTypeInfo &RTI) const;

bool operator==(const ResourceInfo &RHS) const;
bool operator!=(const ResourceInfo &RHS) const { return !(*this == RHS); }
bool operator<(const ResourceInfo &RHS) const;
std::pair<uint32_t, uint32_t>
getAnnotateProps(Module &M, dxil::ResourceTypeInfo &RTI) const;

void print(raw_ostream &OS, const DataLayout &DL) const;
bool operator==(const ResourceBindingInfo &RHS) const {
return std::tie(Binding, HandleTy) == std::tie(RHS.Binding, RHS.HandleTy);
}
bool operator!=(const ResourceBindingInfo &RHS) const {
return !(*this == RHS);
}
bool operator<(const ResourceBindingInfo &RHS) const {
return Binding < RHS.Binding;
}

void print(raw_ostream &OS, dxil::ResourceTypeInfo &RTI,
const DataLayout &DL) const;
};

} // namespace dxil

//===----------------------------------------------------------------------===//

class DXILResourceMap {
SmallVector<dxil::ResourceInfo> Infos;
class DXILResourceTypeMap {
DenseMap<TargetExtType *, dxil::ResourceTypeInfo> Infos;

public:
bool invalidate(Module &M, const PreservedAnalyses &PA,
ModuleAnalysisManager::Invalidator &Inv);

dxil::ResourceTypeInfo &operator[](TargetExtType *Ty) {
auto It = Infos.find(Ty);
if (It != Infos.end())
return It->second;
auto [NewIt, Inserted] = Infos.try_emplace(Ty, Ty);
return NewIt->second;
}
};

class DXILResourceTypeAnalysis
: public AnalysisInfoMixin<DXILResourceTypeAnalysis> {
friend AnalysisInfoMixin<DXILResourceTypeAnalysis>;

static AnalysisKey Key;

public:
using Result = DXILResourceTypeMap;

DXILResourceTypeMap run(Module &M, ModuleAnalysisManager &AM) {
// Running the pass just generates an empty map, which will be filled when
// users of the pass query the results.
return Result();
}
};

class DXILResourceTypeWrapperPass : public ImmutablePass {
DXILResourceTypeMap DRTM;

virtual void anchor();

public:
static char ID;
DXILResourceTypeWrapperPass();

DXILResourceTypeMap &getResourceTypeMap() { return DRTM; }
const DXILResourceTypeMap &getResourceTypeMap() const { return DRTM; }
};

ModulePass *createDXILResourceTypeWrapperPassPass();

//===----------------------------------------------------------------------===//

class DXILBindingMap {
SmallVector<dxil::ResourceBindingInfo> Infos;
DenseMap<CallInst *, unsigned> CallMap;
unsigned FirstUAV = 0;
unsigned FirstCBuffer = 0;
unsigned FirstSampler = 0;

/// Populate the map given the resource binding calls in the given module.
void populate(Module &M);
void populate(Module &M, DXILResourceTypeMap &DRTM);

public:
using iterator = SmallVector<dxil::ResourceInfo>::iterator;
using const_iterator = SmallVector<dxil::ResourceInfo>::const_iterator;
using iterator = SmallVector<dxil::ResourceBindingInfo>::iterator;
using const_iterator = SmallVector<dxil::ResourceBindingInfo>::const_iterator;

iterator begin() { return Infos.begin(); }
const_iterator begin() const { return Infos.begin(); }
Expand Down Expand Up @@ -400,47 +487,51 @@ class DXILResourceMap {
return make_range(sampler_begin(), sampler_end());
}

void print(raw_ostream &OS, const DataLayout &DL) const;
void print(raw_ostream &OS, DXILResourceTypeMap &DRTM,
const DataLayout &DL) const;

friend class DXILResourceAnalysis;
friend class DXILResourceWrapperPass;
friend class DXILResourceBindingAnalysis;
friend class DXILResourceBindingWrapperPass;
};

class DXILResourceAnalysis : public AnalysisInfoMixin<DXILResourceAnalysis> {
friend AnalysisInfoMixin<DXILResourceAnalysis>;
class DXILResourceBindingAnalysis
: public AnalysisInfoMixin<DXILResourceBindingAnalysis> {
friend AnalysisInfoMixin<DXILResourceBindingAnalysis>;

static AnalysisKey Key;

public:
using Result = DXILResourceMap;
using Result = DXILBindingMap;

/// Gather resource info for the module \c M.
DXILResourceMap run(Module &M, ModuleAnalysisManager &AM);
DXILBindingMap run(Module &M, ModuleAnalysisManager &AM);
};

/// Printer pass for the \c DXILResourceAnalysis results.
class DXILResourcePrinterPass : public PassInfoMixin<DXILResourcePrinterPass> {
/// Printer pass for the \c DXILResourceBindingAnalysis results.
class DXILResourceBindingPrinterPass
: public PassInfoMixin<DXILResourceBindingPrinterPass> {
raw_ostream &OS;

public:
explicit DXILResourcePrinterPass(raw_ostream &OS) : OS(OS) {}
explicit DXILResourceBindingPrinterPass(raw_ostream &OS) : OS(OS) {}

PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM);

static bool isRequired() { return true; }
};

class DXILResourceWrapperPass : public ModulePass {
std::unique_ptr<DXILResourceMap> Map;
class DXILResourceBindingWrapperPass : public ModulePass {
std::unique_ptr<DXILBindingMap> Map;
DXILResourceTypeMap *DRTM;

public:
static char ID; // Class identification, replacement for typeinfo

DXILResourceWrapperPass();
~DXILResourceWrapperPass() override;
DXILResourceBindingWrapperPass();
~DXILResourceBindingWrapperPass() override;

const DXILResourceMap &getResourceMap() const { return *Map; }
DXILResourceMap &getResourceMap() { return *Map; }
const DXILBindingMap &getBindingMap() const { return *Map; }
DXILBindingMap &getBindingMap() { return *Map; }

void getAnalysisUsage(AnalysisUsage &AU) const override;
bool runOnModule(Module &M) override;
Expand All @@ -450,7 +541,7 @@ class DXILResourceWrapperPass : public ModulePass {
void dump() const;
};

ModulePass *createDXILResourceWrapperPassPass();
ModulePass *createDXILResourceBindingWrapperPassPass();

} // namespace llvm

Expand Down
3 changes: 2 additions & 1 deletion llvm/include/llvm/InitializePasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ void initializeDAHPass(PassRegistry &);
void initializeDCELegacyPassPass(PassRegistry &);
void initializeDXILMetadataAnalysisWrapperPassPass(PassRegistry &);
void initializeDXILMetadataAnalysisWrapperPrinterPass(PassRegistry &);
void initializeDXILResourceWrapperPassPass(PassRegistry &);
void initializeDXILResourceBindingWrapperPassPass(PassRegistry &);
void initializeDXILResourceTypeWrapperPassPass(PassRegistry &);
void initializeDeadMachineInstructionElimPass(PassRegistry &);
void initializeDebugifyMachineModulePass(PassRegistry &);
void initializeDependenceAnalysisWrapperPassPass(PassRegistry &);
Expand Down
3 changes: 2 additions & 1 deletion llvm/include/llvm/LinkAllPasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ struct ForcePassLinking {
(void)llvm::createCallGraphViewerPass();
(void)llvm::createCFGSimplificationPass();
(void)llvm::createStructurizeCFGPass();
(void)llvm::createDXILResourceWrapperPassPass();
(void)llvm::createDXILResourceBindingWrapperPassPass();
(void)llvm::createDXILResourceTypeWrapperPassPass();
(void)llvm::createDeadArgEliminationPass();
(void)llvm::createDeadCodeEliminationPass();
(void)llvm::createDependenceAnalysisWrapperPass();
Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/Analysis/Analysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ void llvm::initializeAnalysis(PassRegistry &Registry) {
initializeCallGraphDOTPrinterPass(Registry);
initializeCallGraphViewerPass(Registry);
initializeCycleInfoWrapperPassPass(Registry);
initializeDXILResourceWrapperPassPass(Registry);
initializeDXILResourceBindingWrapperPassPass(Registry);
initializeDXILResourceTypeWrapperPassPass(Registry);
initializeDependenceAnalysisWrapperPassPass(Registry);
initializeDominanceFrontierWrapperPassPass(Registry);
initializeDomViewerWrapperPassPass(Registry);
Expand Down
Loading
Loading