diff --git a/llvm/include/llvm/Analysis/DXILResource.h b/llvm/include/llvm/Analysis/DXILResource.h index 6b577c02f0545..b59798a3d64a4 100644 --- a/llvm/include/llvm/Analysis/DXILResource.h +++ b/llvm/include/llvm/Analysis/DXILResource.h @@ -11,6 +11,7 @@ #include "llvm/ADT/MapVector.h" #include "llvm/ADT/StringRef.h" +#include "llvm/IR/GlobalVariable.h" #include "llvm/IR/PassManager.h" #include "llvm/Pass.h" #include "llvm/Support/Alignment.h" @@ -18,33 +19,187 @@ namespace llvm { class CallInst; +class DataLayout; class LLVMContext; class MDTuple; +class TargetExtType; class Value; +class DXILResourceTypeMap; + namespace dxil { -class ResourceInfo { +/// The dx.RawBuffer target extension type +/// +/// `target("dx.RawBuffer", Type, IsWriteable, IsROV)` +class RawBufferExtType : public TargetExtType { public: - struct ResourceBinding { - uint32_t RecordID; - uint32_t Space; - uint32_t LowerBound; - uint32_t Size; + RawBufferExtType() = delete; + RawBufferExtType(const RawBufferExtType &) = delete; + RawBufferExtType &operator=(const RawBufferExtType &) = delete; + + bool isStructured() const { + // TODO: We need to be more prescriptive here, but since there's some debate + // over whether byte address buffer should have a void type or an i8 type, + // accept either for now. + Type *Ty = getTypeParameter(0); + return !Ty->isVoidTy() && !Ty->isIntegerTy(8); + } - bool operator==(const ResourceBinding &RHS) const { - return std::tie(RecordID, Space, LowerBound, Size) == - std::tie(RHS.RecordID, RHS.Space, RHS.LowerBound, RHS.Size); - } - bool operator!=(const ResourceBinding &RHS) const { - return !(*this == RHS); - } - bool operator<(const ResourceBinding &RHS) const { - return std::tie(RecordID, Space, LowerBound, Size) < - std::tie(RHS.RecordID, RHS.Space, RHS.LowerBound, RHS.Size); - } - }; + Type *getResourceType() const { + return isStructured() ? getTypeParameter(0) : nullptr; + } + bool isWriteable() const { return getIntParameter(0); } + bool isROV() const { return getIntParameter(1); } + static bool classof(const TargetExtType *T) { + return T->getName() == "dx.RawBuffer"; + } + static bool classof(const Type *T) { + return isa(T) && classof(cast(T)); + } +}; + +/// The dx.TypedBuffer target extension type +/// +/// `target("dx.TypedBuffer", Type, IsWriteable, IsROV, IsSigned)` +class TypedBufferExtType : public TargetExtType { +public: + TypedBufferExtType() = delete; + TypedBufferExtType(const TypedBufferExtType &) = delete; + TypedBufferExtType &operator=(const TypedBufferExtType &) = delete; + + Type *getResourceType() const { return getTypeParameter(0); } + bool isWriteable() const { return getIntParameter(0); } + bool isROV() const { return getIntParameter(1); } + bool isSigned() const { return getIntParameter(2); } + + static bool classof(const TargetExtType *T) { + return T->getName() == "dx.TypedBuffer"; + } + static bool classof(const Type *T) { + return isa(T) && classof(cast(T)); + } +}; + +/// The dx.Texture target extension type +/// +/// `target("dx.Texture", Type, IsWriteable, IsROV, IsSigned, Dimension)` +class TextureExtType : public TargetExtType { +public: + TextureExtType() = delete; + TextureExtType(const TextureExtType &) = delete; + TextureExtType &operator=(const TextureExtType &) = delete; + + Type *getResourceType() const { return getTypeParameter(0); } + bool isWriteable() const { return getIntParameter(0); } + bool isROV() const { return getIntParameter(1); } + bool isSigned() const { return getIntParameter(2); } + dxil::ResourceKind getDimension() const { + return static_cast(getIntParameter(3)); + } + + static bool classof(const TargetExtType *T) { + return T->getName() == "dx.Texture"; + } + static bool classof(const Type *T) { + return isa(T) && classof(cast(T)); + } +}; + +/// The dx.MSTexture target extension type +/// +/// `target("dx.MSTexture", Type, IsWriteable, Samples, IsSigned, Dimension)` +class MSTextureExtType : public TargetExtType { +public: + MSTextureExtType() = delete; + MSTextureExtType(const MSTextureExtType &) = delete; + MSTextureExtType &operator=(const MSTextureExtType &) = delete; + + Type *getResourceType() const { return getTypeParameter(0); } + bool isWriteable() const { return getIntParameter(0); } + uint32_t getSampleCount() const { return getIntParameter(1); } + bool isSigned() const { return getIntParameter(2); } + dxil::ResourceKind getDimension() const { + return static_cast(getIntParameter(3)); + } + + static bool classof(const TargetExtType *T) { + return T->getName() == "dx.MSTexture"; + } + static bool classof(const Type *T) { + return isa(T) && classof(cast(T)); + } +}; + +/// The dx.FeedbackTexture target extension type +/// +/// `target("dx.FeedbackTexture", FeedbackType, Dimension)` +class FeedbackTextureExtType : public TargetExtType { +public: + FeedbackTextureExtType() = delete; + FeedbackTextureExtType(const FeedbackTextureExtType &) = delete; + FeedbackTextureExtType &operator=(const FeedbackTextureExtType &) = delete; + + dxil::SamplerFeedbackType getFeedbackType() const { + return static_cast(getIntParameter(0)); + } + dxil::ResourceKind getDimension() const { + return static_cast(getIntParameter(1)); + } + + static bool classof(const TargetExtType *T) { + return T->getName() == "dx.FeedbackTexture"; + } + static bool classof(const Type *T) { + return isa(T) && classof(cast(T)); + } +}; + +/// The dx.CBuffer target extension type +/// +/// `target("dx.CBuffer", , ...)` +class CBufferExtType : public TargetExtType { +public: + CBufferExtType() = delete; + CBufferExtType(const CBufferExtType &) = delete; + CBufferExtType &operator=(const CBufferExtType &) = delete; + + Type *getResourceType() const { return getTypeParameter(0); } + + static bool classof(const TargetExtType *T) { + return T->getName() == "dx.CBuffer"; + } + static bool classof(const Type *T) { + return isa(T) && classof(cast(T)); + } +}; + +/// The dx.Sampler target extension type +/// +/// `target("dx.Sampler", SamplerType)` +class SamplerExtType : public TargetExtType { +public: + SamplerExtType() = delete; + SamplerExtType(const SamplerExtType &) = delete; + SamplerExtType &operator=(const SamplerExtType &) = delete; + + dxil::SamplerType getSamplerType() const { + return static_cast(getIntParameter(0)); + } + + static bool classof(const TargetExtType *T) { + return T->getName() == "dx.Sampler"; + } + static bool classof(const Type *T) { + return isa(T) && classof(cast(T)); + } +}; + +//===----------------------------------------------------------------------===// + +class ResourceTypeInfo { +public: struct UAVInfo { bool GloballyCoherent; bool HasCounter; @@ -93,55 +248,31 @@ class ResourceInfo { } }; - struct MSInfo { - uint32_t Count; - - bool operator==(const MSInfo &RHS) const { return Count == RHS.Count; } - bool operator!=(const MSInfo &RHS) const { return !(*this == RHS); } - bool operator<(const MSInfo &RHS) const { return Count < RHS.Count; } - }; - - struct FeedbackInfo { - dxil::SamplerFeedbackType Type; - - bool operator==(const FeedbackInfo &RHS) const { return Type == RHS.Type; } - bool operator!=(const FeedbackInfo &RHS) const { return !(*this == RHS); } - bool operator<(const FeedbackInfo &RHS) const { return Type < RHS.Type; } - }; - private: - // Universal properties. - Value *Symbol; - StringRef Name; + TargetExtType *HandleTy; + + // GloballyCoherent and HasCounter aren't really part of the type and need to + // be determined by analysis, so they're just provided directly by the + // DXILResourceTypeMap when we construct these. + bool GloballyCoherent; + bool HasCounter; dxil::ResourceClass RC; dxil::ResourceKind Kind; - ResourceBinding Binding = {}; - - // Resource class dependent properties. - // CBuffer, Sampler, and RawBuffer end here. - union { - UAVInfo UAVFlags; // UAV - uint32_t CBufferSize; // CBuffer - dxil::SamplerType SamplerTy; // Sampler - }; - - // Resource kind dependent properties. - union { - StructInfo Struct; // StructuredBuffer - TypedInfo Typed; // All SRV/UAV except Raw/StructuredBuffer - FeedbackInfo Feedback; // FeedbackTexture - }; - - MSInfo MultiSample; - public: - ResourceInfo(dxil::ResourceClass RC, dxil::ResourceKind Kind, Value *Symbol, - StringRef Name) - : Symbol(Symbol), Name(Name), RC(RC), Kind(Kind) {} - - // Conditions to check before accessing union members. + ResourceTypeInfo(TargetExtType *HandleTy, const dxil::ResourceClass RC, + const dxil::ResourceKind Kind, bool GloballyCoherent = false, + bool HasCounter = false); + ResourceTypeInfo(TargetExtType *HandleTy, bool GloballyCoherent = false, + bool HasCounter = false) + : ResourceTypeInfo(HandleTy, {}, dxil::ResourceKind::Invalid, + GloballyCoherent, HasCounter) {} + + TargetExtType *getHandleTy() const { return HandleTy; } + StructType *createElementStruct(); + + // Conditions to check before accessing specific views. bool isUAV() const; bool isCBuffer() const; bool isSampler() const; @@ -150,148 +281,185 @@ class ResourceInfo { bool isFeedback() const; bool isMultiSample() const; - void bind(uint32_t RecordID, uint32_t Space, uint32_t LowerBound, - uint32_t Size) { - Binding.RecordID = RecordID; - Binding.Space = Space; - Binding.LowerBound = LowerBound; - Binding.Size = Size; - } + // Views into the type. + UAVInfo getUAV() const; + uint32_t getCBufferSize(const DataLayout &DL) const; + dxil::SamplerType getSamplerType() const; + StructInfo getStruct(const DataLayout &DL) const; + TypedInfo getTyped() const; + dxil::SamplerFeedbackType getFeedbackType() const; + uint32_t getMultiSampleCount() const; + + dxil::ResourceClass getResourceClass() const { return RC; } + dxil::ResourceKind getResourceKind() const { return Kind; } + + bool operator==(const ResourceTypeInfo &RHS) const; + bool operator!=(const ResourceTypeInfo &RHS) const { return !(*this == RHS); } + bool operator<(const ResourceTypeInfo &RHS) const; + + void print(raw_ostream &OS, const DataLayout &DL) const; +}; + +//===----------------------------------------------------------------------===// + +class ResourceBindingInfo { +public: + struct ResourceBinding { + uint32_t RecordID; + uint32_t Space; + uint32_t LowerBound; + uint32_t Size; + + bool operator==(const ResourceBinding &RHS) const { + return std::tie(RecordID, Space, LowerBound, Size) == + std::tie(RHS.RecordID, RHS.Space, RHS.LowerBound, RHS.Size); + } + bool operator!=(const ResourceBinding &RHS) const { + return !(*this == RHS); + } + bool operator<(const ResourceBinding &RHS) const { + return std::tie(RecordID, Space, LowerBound, Size) < + std::tie(RHS.RecordID, RHS.Space, RHS.LowerBound, RHS.Size); + } + }; + +private: + ResourceBinding Binding; + TargetExtType *HandleTy; + GlobalVariable *Symbol = nullptr; + +public: + ResourceBindingInfo(uint32_t RecordID, uint32_t Space, uint32_t LowerBound, + uint32_t Size, TargetExtType *HandleTy, + GlobalVariable *Symbol = nullptr) + : Binding{RecordID, Space, LowerBound, Size}, HandleTy(HandleTy), + Symbol(Symbol) {} + + void setBindingID(unsigned ID) { Binding.RecordID = ID; } + const ResourceBinding &getBinding() const { return Binding; } - void setUAV(bool GloballyCoherent, bool HasCounter, bool IsROV) { - assert(isUAV() && "Not a UAV"); - UAVFlags.GloballyCoherent = GloballyCoherent; - UAVFlags.HasCounter = HasCounter; - UAVFlags.IsROV = IsROV; - } - const UAVInfo &getUAV() const { - assert(isUAV() && "Not a UAV"); - return UAVFlags; - } - void setCBuffer(uint32_t Size) { - assert(isCBuffer() && "Not a CBuffer"); - CBufferSize = Size; + TargetExtType *getHandleTy() const { return HandleTy; } + const StringRef getName() const { return Symbol ? Symbol->getName() : ""; } + + bool hasSymbol() const { return Symbol; } + GlobalVariable *createSymbol(Module &M, StructType *Ty, StringRef Name = ""); + MDTuple *getAsMetadata(Module &M, DXILResourceTypeMap &DRTM) const; + MDTuple *getAsMetadata(Module &M, dxil::ResourceTypeInfo RTI) const; + + std::pair + getAnnotateProps(Module &M, DXILResourceTypeMap &DRTM) const; + std::pair + getAnnotateProps(Module &M, dxil::ResourceTypeInfo RTI) const; + + bool operator==(const ResourceBindingInfo &RHS) const { + return std::tie(Binding, HandleTy, Symbol) == + std::tie(RHS.Binding, RHS.HandleTy, RHS.Symbol); } - void setSampler(dxil::SamplerType Ty) { SamplerTy = Ty; } - void setStruct(uint32_t Stride, MaybeAlign Alignment) { - assert(isStruct() && "Not a Struct"); - Struct.Stride = Stride; - Struct.AlignLog2 = Alignment ? Log2(*Alignment) : 0; + bool operator!=(const ResourceBindingInfo &RHS) const { + return !(*this == RHS); } - void setTyped(dxil::ElementType ElementTy, uint32_t ElementCount) { - assert(isTyped() && "Not Typed"); - Typed.ElementTy = ElementTy; - Typed.ElementCount = ElementCount; + bool operator<(const ResourceBindingInfo &RHS) const { + return Binding < RHS.Binding; } - const TypedInfo &getTyped() const { - assert(isTyped() && "Not typed"); - return Typed; + + void print(raw_ostream &OS, DXILResourceTypeMap &DRTM, + const DataLayout &DL) const; + void print(raw_ostream &OS, dxil::ResourceTypeInfo RTI, + const DataLayout &DL) const; +}; + +} // namespace dxil + +//===----------------------------------------------------------------------===// + +class DXILResourceTypeMap { + struct Info { + dxil::ResourceClass RC; + dxil::ResourceKind Kind; + bool GloballyCoherent; + bool HasCounter; + }; + DenseMap Infos; + +public: + bool invalidate(Module &M, const PreservedAnalyses &PA, + ModuleAnalysisManager::Invalidator &Inv); + + dxil::ResourceTypeInfo operator[](TargetExtType *Ty) { + Info I = Infos[Ty]; + return dxil::ResourceTypeInfo(Ty, I.RC, I.Kind, I.GloballyCoherent, + I.HasCounter); } - void setFeedback(dxil::SamplerFeedbackType Type) { - assert(isFeedback() && "Not Feedback"); - Feedback.Type = Type; + + void setGloballyCoherent(TargetExtType *Ty, bool GloballyCoherent) { + Infos[Ty].GloballyCoherent = GloballyCoherent; } - void setMultiSample(uint32_t Count) { - assert(isMultiSample() && "Not MultiSampled"); - MultiSample.Count = Count; + + void setHasCounter(TargetExtType *Ty, bool HasCounter) { + Infos[Ty].HasCounter = HasCounter; } - const MSInfo &getMultiSample() const { - assert(isMultiSample() && "Not MultiSampled"); - return MultiSample; +}; + +class DXILResourceTypeAnalysis + : public AnalysisInfoMixin { + friend AnalysisInfoMixin; + + static AnalysisKey Key; + +public: + using Result = DXILResourceTypeMap; + + DXILResourceTypeMap run(Module &M, ModuleAnalysisManager &AM) { + return Result(); } +}; - StringRef getName() const { return Name; } - dxil::ResourceClass getResourceClass() const { return RC; } - dxil::ResourceKind getResourceKind() const { return Kind; } +class DXILResourceTypeWrapperPass : public ImmutablePass { + DXILResourceTypeMap DRTM; - bool operator==(const ResourceInfo &RHS) const; - bool operator!=(const ResourceInfo &RHS) const { return !(*this == RHS); } - bool operator<(const ResourceInfo &RHS) const; - - static ResourceInfo SRV(Value *Symbol, StringRef Name, - dxil::ElementType ElementTy, uint32_t ElementCount, - dxil::ResourceKind Kind); - static ResourceInfo RawBuffer(Value *Symbol, StringRef Name); - static ResourceInfo StructuredBuffer(Value *Symbol, StringRef Name, - uint32_t Stride, MaybeAlign Alignment); - static ResourceInfo Texture2DMS(Value *Symbol, StringRef Name, - dxil::ElementType ElementTy, - uint32_t ElementCount, uint32_t SampleCount); - static ResourceInfo Texture2DMSArray(Value *Symbol, StringRef Name, - dxil::ElementType ElementTy, - uint32_t ElementCount, - uint32_t SampleCount); - - static ResourceInfo UAV(Value *Symbol, StringRef Name, - dxil::ElementType ElementTy, uint32_t ElementCount, - bool GloballyCoherent, bool IsROV, - dxil::ResourceKind Kind); - static ResourceInfo RWRawBuffer(Value *Symbol, StringRef Name, - bool GloballyCoherent, bool IsROV); - static ResourceInfo RWStructuredBuffer(Value *Symbol, StringRef Name, - uint32_t Stride, MaybeAlign Alignment, - bool GloballyCoherent, bool IsROV, - bool HasCounter); - static ResourceInfo RWTexture2DMS(Value *Symbol, StringRef Name, - dxil::ElementType ElementTy, - uint32_t ElementCount, uint32_t SampleCount, - bool GloballyCoherent); - static ResourceInfo RWTexture2DMSArray(Value *Symbol, StringRef Name, - dxil::ElementType ElementTy, - uint32_t ElementCount, - uint32_t SampleCount, - bool GloballyCoherent); - static ResourceInfo FeedbackTexture2D(Value *Symbol, StringRef Name, - dxil::SamplerFeedbackType FeedbackTy); - static ResourceInfo - FeedbackTexture2DArray(Value *Symbol, StringRef Name, - dxil::SamplerFeedbackType FeedbackTy); - - static ResourceInfo CBuffer(Value *Symbol, StringRef Name, uint32_t Size); - - static ResourceInfo Sampler(Value *Symbol, StringRef Name, - dxil::SamplerType SamplerTy); - - MDTuple *getAsMetadata(LLVMContext &Ctx) const; - - std::pair getAnnotateProps() const; - - void print(raw_ostream &OS) const; + virtual void anchor(); + +public: + static char ID; + DXILResourceTypeWrapperPass(); + + DXILResourceTypeMap &getResourceTypeMap() { return DRTM; } + const DXILResourceTypeMap &getResourceTypeMap() const { return DRTM; } }; -} // namespace dxil +ModulePass *createDXILResourceTypeWrapperPassPass(); -class DXILResourceMap { - SmallVector Resources; +//===----------------------------------------------------------------------===// + +class DXILBindingMap { + SmallVector Infos; DenseMap CallMap; unsigned FirstUAV = 0; unsigned FirstCBuffer = 0; unsigned FirstSampler = 0; -public: - using iterator = SmallVector::iterator; - using const_iterator = SmallVector::const_iterator; + /// Populate the map given the resource binding calls in the given module. + void populate(Module &M, DXILResourceTypeMap &DRTM); - DXILResourceMap( - SmallVectorImpl> &&CIToRI); +public: + using iterator = SmallVector::iterator; + using const_iterator = SmallVector::const_iterator; - iterator begin() { return Resources.begin(); } - const_iterator begin() const { return Resources.begin(); } - iterator end() { return Resources.end(); } - const_iterator end() const { return Resources.end(); } + iterator begin() { return Infos.begin(); } + const_iterator begin() const { return Infos.begin(); } + iterator end() { return Infos.end(); } + const_iterator end() const { return Infos.end(); } - bool empty() const { return Resources.empty(); } + bool empty() const { return Infos.empty(); } iterator find(const CallInst *Key) { auto Pos = CallMap.find(Key); - return Pos == CallMap.end() ? Resources.end() - : (Resources.begin() + Pos->second); + return Pos == CallMap.end() ? Infos.end() : (Infos.begin() + Pos->second); } const_iterator find(const CallInst *Key) const { auto Pos = CallMap.find(Key); - return Pos == CallMap.end() ? Resources.end() - : (Resources.begin() + Pos->second); + return Pos == CallMap.end() ? Infos.end() : (Infos.begin() + Pos->second); } iterator srv_begin() { return begin(); } @@ -334,44 +502,51 @@ class DXILResourceMap { return make_range(sampler_begin(), sampler_end()); } - void print(raw_ostream &OS) const; + void print(raw_ostream &OS, DXILResourceTypeMap &DRTM, + const DataLayout &DL) const; + + friend class DXILResourceBindingAnalysis; + friend class DXILResourceBindingWrapperPass; }; -class DXILResourceAnalysis : public AnalysisInfoMixin { - friend AnalysisInfoMixin; +class DXILResourceBindingAnalysis + : public AnalysisInfoMixin { + friend AnalysisInfoMixin; static AnalysisKey Key; public: - using Result = DXILResourceMap; + using Result = DXILBindingMap; /// Gather resource info for the module \c M. - DXILResourceMap run(Module &M, ModuleAnalysisManager &AM); + DXILBindingMap run(Module &M, ModuleAnalysisManager &AM); }; -/// Printer pass for the \c DXILResourceAnalysis results. -class DXILResourcePrinterPass : public PassInfoMixin { +/// Printer pass for the \c DXILResourceBindingAnalysis results. +class DXILResourceBindingPrinterPass + : public PassInfoMixin { raw_ostream &OS; public: - explicit DXILResourcePrinterPass(raw_ostream &OS) : OS(OS) {} + explicit DXILResourceBindingPrinterPass(raw_ostream &OS) : OS(OS) {} PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM); static bool isRequired() { return true; } }; -class DXILResourceWrapperPass : public ModulePass { - std::unique_ptr ResourceMap; +class DXILResourceBindingWrapperPass : public ModulePass { + std::unique_ptr Map; + DXILResourceTypeMap *DRTM; public: static char ID; // Class identification, replacement for typeinfo - DXILResourceWrapperPass(); - ~DXILResourceWrapperPass() override; + DXILResourceBindingWrapperPass(); + ~DXILResourceBindingWrapperPass() override; - const DXILResourceMap &getResourceMap() const { return *ResourceMap; } - DXILResourceMap &getResourceMap() { return *ResourceMap; } + const DXILBindingMap &getBindingMap() const { return *Map; } + DXILBindingMap &getBindingMap() { return *Map; } void getAnalysisUsage(AnalysisUsage &AU) const override; bool runOnModule(Module &M) override; @@ -381,7 +556,7 @@ class DXILResourceWrapperPass : public ModulePass { void dump() const; }; -ModulePass *createDXILResourceWrapperPassPass(); +ModulePass *createDXILResourceBindingWrapperPassPass(); } // namespace llvm diff --git a/llvm/include/llvm/InitializePasses.h b/llvm/include/llvm/InitializePasses.h index 7d829cf5b9b01..1cb9013bc48cc 100644 --- a/llvm/include/llvm/InitializePasses.h +++ b/llvm/include/llvm/InitializePasses.h @@ -84,7 +84,8 @@ void initializeDAHPass(PassRegistry &); void initializeDCELegacyPassPass(PassRegistry &); void initializeDXILMetadataAnalysisWrapperPassPass(PassRegistry &); void initializeDXILMetadataAnalysisWrapperPrinterPass(PassRegistry &); -void initializeDXILResourceWrapperPassPass(PassRegistry &); +void initializeDXILResourceBindingWrapperPassPass(PassRegistry &); +void initializeDXILResourceTypeWrapperPassPass(PassRegistry &); void initializeDeadMachineInstructionElimPass(PassRegistry &); void initializeDebugifyMachineModulePass(PassRegistry &); void initializeDependenceAnalysisWrapperPassPass(PassRegistry &); diff --git a/llvm/include/llvm/LinkAllPasses.h b/llvm/include/llvm/LinkAllPasses.h index 54245ca0b7022..ac1970334de0c 100644 --- a/llvm/include/llvm/LinkAllPasses.h +++ b/llvm/include/llvm/LinkAllPasses.h @@ -70,7 +70,8 @@ struct ForcePassLinking { (void)llvm::createCallGraphViewerPass(); (void)llvm::createCFGSimplificationPass(); (void)llvm::createStructurizeCFGPass(); - (void)llvm::createDXILResourceWrapperPassPass(); + (void)llvm::createDXILResourceBindingWrapperPassPass(); + (void)llvm::createDXILResourceTypeWrapperPassPass(); (void)llvm::createDeadArgEliminationPass(); (void)llvm::createDeadCodeEliminationPass(); (void)llvm::createDependenceAnalysisWrapperPass(); diff --git a/llvm/lib/Analysis/Analysis.cpp b/llvm/lib/Analysis/Analysis.cpp index 58723469f21ca..bc2b8a57f83a7 100644 --- a/llvm/lib/Analysis/Analysis.cpp +++ b/llvm/lib/Analysis/Analysis.cpp @@ -25,7 +25,8 @@ void llvm::initializeAnalysis(PassRegistry &Registry) { initializeCallGraphDOTPrinterPass(Registry); initializeCallGraphViewerPass(Registry); initializeCycleInfoWrapperPassPass(Registry); - initializeDXILResourceWrapperPassPass(Registry); + initializeDXILResourceBindingWrapperPassPass(Registry); + initializeDXILResourceTypeWrapperPassPass(Registry); initializeDependenceAnalysisWrapperPassPass(Registry); initializeDominanceFrontierWrapperPassPass(Registry); initializeDomViewerWrapperPassPass(Registry); diff --git a/llvm/lib/Analysis/DXILResource.cpp b/llvm/lib/Analysis/DXILResource.cpp index 2802480481690..9f992ee1a8277 100644 --- a/llvm/lib/Analysis/DXILResource.cpp +++ b/llvm/lib/Analysis/DXILResource.cpp @@ -8,6 +8,7 @@ #include "llvm/Analysis/DXILResource.h" #include "llvm/ADT/APInt.h" +#include "llvm/ADT/SmallString.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/DiagnosticInfo.h" @@ -17,6 +18,7 @@ #include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" #include "llvm/InitializePasses.h" +#include "llvm/Support/FormatVariadic.h" #define DEBUG_TYPE "dxil-resource" @@ -148,17 +150,159 @@ static StringRef getSamplerFeedbackTypeName(SamplerFeedbackType SFT) { llvm_unreachable("Unhandled SamplerFeedbackType"); } -bool ResourceInfo::isUAV() const { return RC == ResourceClass::UAV; } +static dxil::ElementType toDXILElementType(Type *Ty, bool IsSigned) { + // TODO: Handle unorm, snorm, and packed. + Ty = Ty->getScalarType(); + + if (Ty->isIntegerTy()) { + switch (Ty->getIntegerBitWidth()) { + case 16: + return IsSigned ? ElementType::I16 : ElementType::U16; + case 32: + return IsSigned ? ElementType::I32 : ElementType::U32; + case 64: + return IsSigned ? ElementType::I64 : ElementType::U64; + case 1: + default: + return ElementType::Invalid; + } + } else if (Ty->isFloatTy()) { + return ElementType::F32; + } else if (Ty->isDoubleTy()) { + return ElementType::F64; + } else if (Ty->isHalfTy()) { + return ElementType::F16; + } + + return ElementType::Invalid; +} + +ResourceTypeInfo::ResourceTypeInfo(TargetExtType *HandleTy, + const dxil::ResourceClass RC_, + const dxil::ResourceKind Kind_, + bool GloballyCoherent, bool HasCounter) + : HandleTy(HandleTy), GloballyCoherent(GloballyCoherent), + HasCounter(HasCounter) { + // If we're provided a resource class and kind, trust them. + if (Kind_ != dxil::ResourceKind::Invalid) { + RC = RC_; + Kind = Kind_; + return; + } + + if (auto *Ty = dyn_cast(HandleTy)) { + RC = Ty->isWriteable() ? ResourceClass::UAV : ResourceClass::SRV; + Kind = Ty->isStructured() ? ResourceKind::StructuredBuffer + : ResourceKind::RawBuffer; + } else if (auto *Ty = dyn_cast(HandleTy)) { + RC = Ty->isWriteable() ? ResourceClass::UAV : ResourceClass::SRV; + Kind = ResourceKind::TypedBuffer; + } else if (auto *Ty = dyn_cast(HandleTy)) { + RC = Ty->isWriteable() ? ResourceClass::UAV : ResourceClass::SRV; + Kind = Ty->getDimension(); + } else if (auto *Ty = dyn_cast(HandleTy)) { + RC = Ty->isWriteable() ? ResourceClass::UAV : ResourceClass::SRV; + Kind = Ty->getDimension(); + } else if (auto *Ty = dyn_cast(HandleTy)) { + RC = ResourceClass::UAV; + Kind = Ty->getDimension(); + } else if (isa(HandleTy)) { + RC = ResourceClass::CBuffer; + Kind = ResourceKind::CBuffer; + } else if (isa(HandleTy)) { + RC = ResourceClass::Sampler; + Kind = ResourceKind::Sampler; + } else + llvm_unreachable("Unknown handle type"); +} + +static void formatTypeName(SmallString<64> &Dest, StringRef Name, + bool isWriteable, bool isROV) { + Dest = isWriteable ? (isROV ? "RasterizerOrdered" : "RW") : ""; + Dest += Name; +} + +StructType *ResourceTypeInfo::createElementStruct() { + SmallString<64> TypeName; + + switch (Kind) { + case ResourceKind::Texture1D: + case ResourceKind::Texture2D: + case ResourceKind::Texture3D: + case ResourceKind::TextureCube: + case ResourceKind::Texture1DArray: + case ResourceKind::Texture2DArray: + case ResourceKind::TextureCubeArray: { + auto *RTy = cast(HandleTy); + formatTypeName(TypeName, getResourceKindName(Kind), RTy->isWriteable(), + RTy->isROV()); + return StructType::create(RTy->getResourceType(), TypeName); + } + case ResourceKind::Texture2DMS: + case ResourceKind::Texture2DMSArray: { + auto *RTy = cast(HandleTy); + formatTypeName(TypeName, getResourceKindName(Kind), RTy->isWriteable(), + /*IsROV=*/false); + return StructType::create(RTy->getResourceType(), TypeName); + } + case ResourceKind::TypedBuffer: { + auto *RTy = cast(HandleTy); + formatTypeName(TypeName, getResourceKindName(Kind), RTy->isWriteable(), + RTy->isROV()); + return StructType::create(RTy->getResourceType(), TypeName); + } + case ResourceKind::RawBuffer: { + auto *RTy = cast(HandleTy); + formatTypeName(TypeName, "ByteAddressBuffer", RTy->isWriteable(), + RTy->isROV()); + return StructType::create(HandleTy->getContext(), TypeName); + } + case ResourceKind::StructuredBuffer: { + auto *RTy = cast(HandleTy); + formatTypeName(TypeName, "StructuredBuffer", RTy->isWriteable(), + RTy->isROV()); + return StructType::create(RTy->getResourceType(), TypeName); + } + case ResourceKind::FeedbackTexture2D: + case ResourceKind::FeedbackTexture2DArray: { + auto *RTy = cast(HandleTy); + TypeName = formatv("{0}<{1}>", getResourceKindName(Kind), + llvm::to_underlying(RTy->getFeedbackType())); + return StructType::create(HandleTy->getContext(), TypeName); + } + case ResourceKind::CBuffer: + return StructType::create(HandleTy->getContext(), "cbuffer"); + case ResourceKind::Sampler: { + auto *RTy = cast(HandleTy); + TypeName = formatv("SamplerState<{0}>", + llvm::to_underlying(RTy->getSamplerType())); + return StructType::create(HandleTy->getContext(), TypeName); + } + case ResourceKind::TBuffer: + case ResourceKind::RTAccelerationStructure: + llvm_unreachable("Unhandled resource kind"); + case ResourceKind::Invalid: + case ResourceKind::NumEntries: + llvm_unreachable("Invalid resource kind"); + } + llvm_unreachable("Unhandled ResourceKind enum"); +} -bool ResourceInfo::isCBuffer() const { return RC == ResourceClass::CBuffer; } +bool ResourceTypeInfo::isUAV() const { return RC == ResourceClass::UAV; } -bool ResourceInfo::isSampler() const { return RC == ResourceClass::Sampler; } +bool ResourceTypeInfo::isCBuffer() const { + return RC == ResourceClass::CBuffer; +} + +bool ResourceTypeInfo::isSampler() const { + return RC == ResourceClass::Sampler; +} -bool ResourceInfo::isStruct() const { +bool ResourceTypeInfo::isStruct() const { return Kind == ResourceKind::StructuredBuffer; } -bool ResourceInfo::isTyped() const { +bool ResourceTypeInfo::isTyped() const { switch (Kind) { case ResourceKind::Texture1D: case ResourceKind::Texture2D: @@ -187,194 +331,216 @@ bool ResourceInfo::isTyped() const { llvm_unreachable("Unhandled ResourceKind enum"); } -bool ResourceInfo::isFeedback() const { +bool ResourceTypeInfo::isFeedback() const { return Kind == ResourceKind::FeedbackTexture2D || Kind == ResourceKind::FeedbackTexture2DArray; } -bool ResourceInfo::isMultiSample() const { +bool ResourceTypeInfo::isMultiSample() const { return Kind == ResourceKind::Texture2DMS || Kind == ResourceKind::Texture2DMSArray; } -ResourceInfo ResourceInfo::SRV(Value *Symbol, StringRef Name, - ElementType ElementTy, uint32_t ElementCount, - ResourceKind Kind) { - ResourceInfo RI(ResourceClass::SRV, Kind, Symbol, Name); - assert(RI.isTyped() && !(RI.isStruct() || RI.isMultiSample()) && - "Invalid ResourceKind for SRV constructor."); - RI.setTyped(ElementTy, ElementCount); - return RI; -} - -ResourceInfo ResourceInfo::RawBuffer(Value *Symbol, StringRef Name) { - ResourceInfo RI(ResourceClass::SRV, ResourceKind::RawBuffer, Symbol, Name); - return RI; -} - -ResourceInfo ResourceInfo::StructuredBuffer(Value *Symbol, StringRef Name, - uint32_t Stride, - MaybeAlign Alignment) { - ResourceInfo RI(ResourceClass::SRV, ResourceKind::StructuredBuffer, Symbol, - Name); - RI.setStruct(Stride, Alignment); - return RI; -} - -ResourceInfo ResourceInfo::Texture2DMS(Value *Symbol, StringRef Name, - ElementType ElementTy, - uint32_t ElementCount, - uint32_t SampleCount) { - ResourceInfo RI(ResourceClass::SRV, ResourceKind::Texture2DMS, Symbol, Name); - RI.setTyped(ElementTy, ElementCount); - RI.setMultiSample(SampleCount); - return RI; -} - -ResourceInfo ResourceInfo::Texture2DMSArray(Value *Symbol, StringRef Name, - ElementType ElementTy, - uint32_t ElementCount, - uint32_t SampleCount) { - ResourceInfo RI(ResourceClass::SRV, ResourceKind::Texture2DMSArray, Symbol, - Name); - RI.setTyped(ElementTy, ElementCount); - RI.setMultiSample(SampleCount); - return RI; -} - -ResourceInfo ResourceInfo::UAV(Value *Symbol, StringRef Name, - ElementType ElementTy, uint32_t ElementCount, - bool GloballyCoherent, bool IsROV, - ResourceKind Kind) { - ResourceInfo RI(ResourceClass::UAV, Kind, Symbol, Name); - assert(RI.isTyped() && !(RI.isStruct() || RI.isMultiSample()) && - "Invalid ResourceKind for UAV constructor."); - RI.setTyped(ElementTy, ElementCount); - RI.setUAV(GloballyCoherent, /*HasCounter=*/false, IsROV); - return RI; -} - -ResourceInfo ResourceInfo::RWRawBuffer(Value *Symbol, StringRef Name, - bool GloballyCoherent, bool IsROV) { - ResourceInfo RI(ResourceClass::UAV, ResourceKind::RawBuffer, Symbol, Name); - RI.setUAV(GloballyCoherent, /*HasCounter=*/false, IsROV); - return RI; -} - -ResourceInfo ResourceInfo::RWStructuredBuffer(Value *Symbol, StringRef Name, - uint32_t Stride, - MaybeAlign Alignment, - bool GloballyCoherent, bool IsROV, - bool HasCounter) { - ResourceInfo RI(ResourceClass::UAV, ResourceKind::StructuredBuffer, Symbol, - Name); - RI.setStruct(Stride, Alignment); - RI.setUAV(GloballyCoherent, HasCounter, IsROV); - return RI; -} - -ResourceInfo ResourceInfo::RWTexture2DMS(Value *Symbol, StringRef Name, - ElementType ElementTy, - uint32_t ElementCount, - uint32_t SampleCount, - bool GloballyCoherent) { - ResourceInfo RI(ResourceClass::UAV, ResourceKind::Texture2DMS, Symbol, Name); - RI.setTyped(ElementTy, ElementCount); - RI.setUAV(GloballyCoherent, /*HasCounter=*/false, /*IsROV=*/false); - RI.setMultiSample(SampleCount); - return RI; -} - -ResourceInfo ResourceInfo::RWTexture2DMSArray(Value *Symbol, StringRef Name, - ElementType ElementTy, - uint32_t ElementCount, - uint32_t SampleCount, - bool GloballyCoherent) { - ResourceInfo RI(ResourceClass::UAV, ResourceKind::Texture2DMSArray, Symbol, - Name); - RI.setTyped(ElementTy, ElementCount); - RI.setUAV(GloballyCoherent, /*HasCounter=*/false, /*IsROV=*/false); - RI.setMultiSample(SampleCount); - return RI; -} - -ResourceInfo ResourceInfo::FeedbackTexture2D(Value *Symbol, StringRef Name, - SamplerFeedbackType FeedbackTy) { - ResourceInfo RI(ResourceClass::UAV, ResourceKind::FeedbackTexture2D, Symbol, - Name); - RI.setUAV(/*GloballyCoherent=*/false, /*HasCounter=*/false, /*IsROV=*/false); - RI.setFeedback(FeedbackTy); - return RI; -} - -ResourceInfo -ResourceInfo::FeedbackTexture2DArray(Value *Symbol, StringRef Name, - SamplerFeedbackType FeedbackTy) { - ResourceInfo RI(ResourceClass::UAV, ResourceKind::FeedbackTexture2DArray, - Symbol, Name); - RI.setUAV(/*GloballyCoherent=*/false, /*HasCounter=*/false, /*IsROV=*/false); - RI.setFeedback(FeedbackTy); - return RI; -} - -ResourceInfo ResourceInfo::CBuffer(Value *Symbol, StringRef Name, - uint32_t Size) { - ResourceInfo RI(ResourceClass::CBuffer, ResourceKind::CBuffer, Symbol, Name); - RI.setCBuffer(Size); - return RI; -} - -ResourceInfo ResourceInfo::Sampler(Value *Symbol, StringRef Name, - SamplerType SamplerTy) { - ResourceInfo RI(ResourceClass::Sampler, ResourceKind::Sampler, Symbol, Name); - RI.setSampler(SamplerTy); - return RI; -} - -bool ResourceInfo::operator==(const ResourceInfo &RHS) const { - if (std::tie(Symbol, Name, Binding, RC, Kind) != - std::tie(RHS.Symbol, RHS.Name, RHS.Binding, RHS.RC, RHS.Kind)) - return false; - if (isCBuffer() && RHS.isCBuffer() && CBufferSize != RHS.CBufferSize) - return false; - if (isSampler() && RHS.isSampler() && SamplerTy != RHS.SamplerTy) - return false; - if (isUAV() && RHS.isUAV() && UAVFlags != RHS.UAVFlags) - return false; - if (isStruct() && RHS.isStruct() && Struct != RHS.Struct) - return false; - if (isFeedback() && RHS.isFeedback() && Feedback != RHS.Feedback) - return false; - if (isTyped() && RHS.isTyped() && Typed != RHS.Typed) - return false; - if (isMultiSample() && RHS.isMultiSample() && MultiSample != RHS.MultiSample) +static bool isROV(dxil::ResourceKind Kind, TargetExtType *Ty) { + switch (Kind) { + case ResourceKind::Texture1D: + case ResourceKind::Texture2D: + case ResourceKind::Texture3D: + case ResourceKind::TextureCube: + case ResourceKind::Texture1DArray: + case ResourceKind::Texture2DArray: + case ResourceKind::TextureCubeArray: + return cast(Ty)->isROV(); + case ResourceKind::TypedBuffer: + return cast(Ty)->isROV(); + case ResourceKind::RawBuffer: + case ResourceKind::StructuredBuffer: + return cast(Ty)->isROV(); + case ResourceKind::Texture2DMS: + case ResourceKind::Texture2DMSArray: + case ResourceKind::FeedbackTexture2D: + case ResourceKind::FeedbackTexture2DArray: return false; - return true; + case ResourceKind::CBuffer: + case ResourceKind::Sampler: + case ResourceKind::TBuffer: + case ResourceKind::RTAccelerationStructure: + case ResourceKind::Invalid: + case ResourceKind::NumEntries: + llvm_unreachable("Resource cannot be ROV"); + } + llvm_unreachable("Unhandled ResourceKind enum"); +} + +ResourceTypeInfo::UAVInfo ResourceTypeInfo::getUAV() const { + assert(isUAV() && "Not a UAV"); + return {GloballyCoherent, HasCounter, isROV(Kind, HandleTy)}; +} + +uint32_t ResourceTypeInfo::getCBufferSize(const DataLayout &DL) const { + assert(isCBuffer() && "Not a CBuffer"); + Type *Ty = cast(HandleTy)->getResourceType(); + return DL.getTypeSizeInBits(Ty) / 8; +} + +dxil::SamplerType ResourceTypeInfo::getSamplerType() const { + assert(isSampler() && "Not a Sampler"); + return cast(HandleTy)->getSamplerType(); +} + +ResourceTypeInfo::StructInfo +ResourceTypeInfo::getStruct(const DataLayout &DL) const { + assert(isStruct() && "Not a Struct"); + + Type *ElTy = cast(HandleTy)->getResourceType(); + + uint32_t Stride = DL.getTypeAllocSize(ElTy); + MaybeAlign Alignment; + if (auto *STy = dyn_cast(ElTy)) + Alignment = DL.getStructLayout(STy)->getAlignment(); + uint32_t AlignLog2 = Alignment ? Log2(*Alignment) : 0; + return {Stride, AlignLog2}; +} + +static std::pair getTypedElementType(dxil::ResourceKind Kind, + TargetExtType *Ty) { + switch (Kind) { + case ResourceKind::Texture1D: + case ResourceKind::Texture2D: + case ResourceKind::Texture3D: + case ResourceKind::TextureCube: + case ResourceKind::Texture1DArray: + case ResourceKind::Texture2DArray: + case ResourceKind::TextureCubeArray: { + auto *RTy = cast(Ty); + return {RTy->getResourceType(), RTy->isSigned()}; + } + case ResourceKind::Texture2DMS: + case ResourceKind::Texture2DMSArray: { + auto *RTy = cast(Ty); + return {RTy->getResourceType(), RTy->isSigned()}; + } + case ResourceKind::TypedBuffer: { + auto *RTy = cast(Ty); + return {RTy->getResourceType(), RTy->isSigned()}; + } + case ResourceKind::RawBuffer: + case ResourceKind::StructuredBuffer: + case ResourceKind::FeedbackTexture2D: + case ResourceKind::FeedbackTexture2DArray: + case ResourceKind::CBuffer: + case ResourceKind::Sampler: + case ResourceKind::TBuffer: + case ResourceKind::RTAccelerationStructure: + case ResourceKind::Invalid: + case ResourceKind::NumEntries: + llvm_unreachable("Resource is not typed"); + } + llvm_unreachable("Unhandled ResourceKind enum"); +} + +ResourceTypeInfo::TypedInfo ResourceTypeInfo::getTyped() const { + assert(isTyped() && "Not typed"); + + auto [ElTy, IsSigned] = getTypedElementType(Kind, HandleTy); + dxil::ElementType ET = toDXILElementType(ElTy, IsSigned); + uint32_t Count = 1; + if (auto *VTy = dyn_cast(ElTy)) + Count = VTy->getNumElements(); + return {ET, Count}; } -bool ResourceInfo::operator<(const ResourceInfo &RHS) const { - // Skip the symbol to avoid non-determinism, and the name to keep a consistent - // ordering even when we strip reflection data. - if (std::tie(Binding, RC, Kind) < std::tie(RHS.Binding, RHS.RC, RHS.Kind)) +dxil::SamplerFeedbackType ResourceTypeInfo::getFeedbackType() const { + assert(isFeedback() && "Not Feedback"); + return cast(HandleTy)->getFeedbackType(); +} +uint32_t ResourceTypeInfo::getMultiSampleCount() const { + assert(isMultiSample() && "Not MultiSampled"); + return cast(HandleTy)->getSampleCount(); +} + +bool ResourceTypeInfo::operator==(const ResourceTypeInfo &RHS) const { + return std::tie(HandleTy, GloballyCoherent, HasCounter) == + std::tie(RHS.HandleTy, RHS.GloballyCoherent, RHS.HasCounter); +} + +bool ResourceTypeInfo::operator<(const ResourceTypeInfo &RHS) const { + // An empty datalayout is sufficient for sorting purposes. + DataLayout DummyDL; + if (std::tie(RC, Kind) < std::tie(RHS.RC, RHS.Kind)) return true; - if (isCBuffer() && RHS.isCBuffer() && CBufferSize < RHS.CBufferSize) + if (isCBuffer() && RHS.isCBuffer() && + getCBufferSize(DummyDL) < RHS.getCBufferSize(DummyDL)) return true; - if (isSampler() && RHS.isSampler() && SamplerTy < RHS.SamplerTy) + if (isSampler() && RHS.isSampler() && getSamplerType() < RHS.getSamplerType()) return true; - if (isUAV() && RHS.isUAV() && UAVFlags < RHS.UAVFlags) + if (isUAV() && RHS.isUAV() && getUAV() < RHS.getUAV()) return true; - if (isStruct() && RHS.isStruct() && Struct < RHS.Struct) + if (isStruct() && RHS.isStruct() && + getStruct(DummyDL) < RHS.getStruct(DummyDL)) return true; - if (isFeedback() && RHS.isFeedback() && Feedback < RHS.Feedback) + if (isFeedback() && RHS.isFeedback() && + getFeedbackType() < RHS.getFeedbackType()) return true; - if (isTyped() && RHS.isTyped() && Typed < RHS.Typed) + if (isTyped() && RHS.isTyped() && getTyped() < RHS.getTyped()) return true; - if (isMultiSample() && RHS.isMultiSample() && MultiSample < RHS.MultiSample) + if (isMultiSample() && RHS.isMultiSample() && + getMultiSampleCount() < RHS.getMultiSampleCount()) return true; return false; } -MDTuple *ResourceInfo::getAsMetadata(LLVMContext &Ctx) const { +void ResourceTypeInfo::print(raw_ostream &OS, const DataLayout &DL) const { + OS << " Class: " << getResourceClassName(RC) << "\n" + << " Kind: " << getResourceKindName(Kind) << "\n"; + + if (isCBuffer()) { + OS << " CBuffer size: " << getCBufferSize(DL) << "\n"; + } else if (isSampler()) { + OS << " Sampler Type: " << getSamplerTypeName(getSamplerType()) << "\n"; + } else { + if (isUAV()) { + UAVInfo UAVFlags = getUAV(); + OS << " Globally Coherent: " << UAVFlags.GloballyCoherent << "\n" + << " HasCounter: " << UAVFlags.HasCounter << "\n" + << " IsROV: " << UAVFlags.IsROV << "\n"; + } + if (isMultiSample()) + OS << " Sample Count: " << getMultiSampleCount() << "\n"; + + if (isStruct()) { + StructInfo Struct = getStruct(DL); + OS << " Buffer Stride: " << Struct.Stride << "\n"; + OS << " Alignment: " << Struct.AlignLog2 << "\n"; + } else if (isTyped()) { + TypedInfo Typed = getTyped(); + OS << " Element Type: " << getElementTypeName(Typed.ElementTy) << "\n" + << " Element Count: " << Typed.ElementCount << "\n"; + } else if (isFeedback()) + OS << " Feedback Type: " << getSamplerFeedbackTypeName(getFeedbackType()) + << "\n"; + } +} + +GlobalVariable *ResourceBindingInfo::createSymbol(Module &M, StructType *Ty, + StringRef Name) { + assert(!Symbol && "Symbol has already been created"); + Symbol = new GlobalVariable(M, Ty, /*isConstant=*/true, + GlobalValue::ExternalLinkage, + /*Initializer=*/nullptr, Name); + return Symbol; +} + +MDTuple *ResourceBindingInfo::getAsMetadata(Module &M, + DXILResourceTypeMap &DRTM) const { + return getAsMetadata(M, DRTM[getHandleTy()]); +} + +MDTuple *ResourceBindingInfo::getAsMetadata(Module &M, + dxil::ResourceTypeInfo RTI) const { + LLVMContext &Ctx = M.getContext(); + const DataLayout &DL = M.getDataLayout(); + SmallVector MDVals; Type *I32Ty = Type::getInt32Ty(Ctx); @@ -389,22 +555,24 @@ MDTuple *ResourceInfo::getAsMetadata(LLVMContext &Ctx) const { }; MDVals.push_back(getIntMD(Binding.RecordID)); + assert(Symbol && "Cannot yet create useful resource metadata without symbol"); MDVals.push_back(ValueAsMetadata::get(Symbol)); - MDVals.push_back(MDString::get(Ctx, Name)); + MDVals.push_back(MDString::get(Ctx, Symbol->getName())); MDVals.push_back(getIntMD(Binding.Space)); MDVals.push_back(getIntMD(Binding.LowerBound)); MDVals.push_back(getIntMD(Binding.Size)); - if (isCBuffer()) { - MDVals.push_back(getIntMD(CBufferSize)); + if (RTI.isCBuffer()) { + MDVals.push_back(getIntMD(RTI.getCBufferSize(DL))); MDVals.push_back(nullptr); - } else if (isSampler()) { - MDVals.push_back(getIntMD(llvm::to_underlying(SamplerTy))); + } else if (RTI.isSampler()) { + MDVals.push_back(getIntMD(llvm::to_underlying(RTI.getSamplerType()))); MDVals.push_back(nullptr); } else { - MDVals.push_back(getIntMD(llvm::to_underlying(Kind))); + MDVals.push_back(getIntMD(llvm::to_underlying(RTI.getResourceKind()))); - if (isUAV()) { + if (RTI.isUAV()) { + ResourceTypeInfo::UAVInfo UAVFlags = RTI.getUAV(); MDVals.push_back(getBoolMD(UAVFlags.GloballyCoherent)); MDVals.push_back(getBoolMD(UAVFlags.HasCounter)); MDVals.push_back(getBoolMD(UAVFlags.IsROV)); @@ -412,23 +580,24 @@ MDTuple *ResourceInfo::getAsMetadata(LLVMContext &Ctx) const { // All SRVs include sample count in the metadata, but it's only meaningful // for multi-sampled textured. Also, UAVs can be multisampled in SM6.7+, // but this just isn't reflected in the metadata at all. - uint32_t SampleCount = isMultiSample() ? MultiSample.Count : 0; + uint32_t SampleCount = + RTI.isMultiSample() ? RTI.getMultiSampleCount() : 0; MDVals.push_back(getIntMD(SampleCount)); } // Further properties are attached to a metadata list of tag-value pairs. SmallVector Tags; - if (isStruct()) { + if (RTI.isStruct()) { Tags.push_back( getIntMD(llvm::to_underlying(ExtPropTags::StructuredBufferStride))); - Tags.push_back(getIntMD(Struct.Stride)); - } else if (isTyped()) { + Tags.push_back(getIntMD(RTI.getStruct(DL).Stride)); + } else if (RTI.isTyped()) { Tags.push_back(getIntMD(llvm::to_underlying(ExtPropTags::ElementType))); - Tags.push_back(getIntMD(llvm::to_underlying(Typed.ElementTy))); - } else if (isFeedback()) { + Tags.push_back(getIntMD(llvm::to_underlying(RTI.getTyped().ElementTy))); + } else if (RTI.isFeedback()) { Tags.push_back( getIntMD(llvm::to_underlying(ExtPropTags::SamplerFeedbackKind))); - Tags.push_back(getIntMD(llvm::to_underlying(Feedback.Type))); + Tags.push_back(getIntMD(llvm::to_underlying(RTI.getFeedbackType()))); } MDVals.push_back(Tags.empty() ? nullptr : MDNode::get(Ctx, Tags)); } @@ -436,17 +605,29 @@ MDTuple *ResourceInfo::getAsMetadata(LLVMContext &Ctx) const { return MDNode::get(Ctx, MDVals); } -std::pair ResourceInfo::getAnnotateProps() const { - uint32_t ResourceKind = llvm::to_underlying(Kind); - uint32_t AlignLog2 = isStruct() ? Struct.AlignLog2 : 0; - bool IsUAV = isUAV(); +std::pair +ResourceBindingInfo::getAnnotateProps(Module &M, + DXILResourceTypeMap &DRTM) const { + return getAnnotateProps(M, DRTM[getHandleTy()]); +} + +std::pair +ResourceBindingInfo::getAnnotateProps(Module &M, + dxil::ResourceTypeInfo RTI) const { + const DataLayout &DL = M.getDataLayout(); + + uint32_t ResourceKind = llvm::to_underlying(RTI.getResourceKind()); + uint32_t AlignLog2 = RTI.isStruct() ? RTI.getStruct(DL).AlignLog2 : 0; + bool IsUAV = RTI.isUAV(); + ResourceTypeInfo::UAVInfo UAVFlags = + IsUAV ? RTI.getUAV() : ResourceTypeInfo::UAVInfo{}; bool IsROV = IsUAV && UAVFlags.IsROV; bool IsGloballyCoherent = IsUAV && UAVFlags.GloballyCoherent; uint8_t SamplerCmpOrHasCounter = 0; if (IsUAV) SamplerCmpOrHasCounter = UAVFlags.HasCounter; - else if (isSampler()) - SamplerCmpOrHasCounter = SamplerTy == SamplerType::Comparison; + else if (RTI.isSampler()) + SamplerCmpOrHasCounter = RTI.getSamplerType() == SamplerType::Comparison; // TODO: Document this format. Currently the only reference is the // implementation of dxc's DxilResourceProperties struct. @@ -459,16 +640,17 @@ std::pair ResourceInfo::getAnnotateProps() const { Word0 |= (SamplerCmpOrHasCounter & 1) << 15; uint32_t Word1 = 0; - if (isStruct()) - Word1 = Struct.Stride; - else if (isCBuffer()) - Word1 = CBufferSize; - else if (isFeedback()) - Word1 = llvm::to_underlying(Feedback.Type); - else if (isTyped()) { + if (RTI.isStruct()) + Word1 = RTI.getStruct(DL).Stride; + else if (RTI.isCBuffer()) + Word1 = RTI.getCBufferSize(DL); + else if (RTI.isFeedback()) + Word1 = llvm::to_underlying(RTI.getFeedbackType()); + else if (RTI.isTyped()) { + ResourceTypeInfo::TypedInfo Typed = RTI.getTyped(); uint32_t CompType = llvm::to_underlying(Typed.ElementTy); uint32_t CompCount = Typed.ElementCount; - uint32_t SampleCount = isMultiSample() ? MultiSample.Count : 0; + uint32_t SampleCount = RTI.isMultiSample() ? RTI.getMultiSampleCount() : 0; Word1 |= (CompType & 0xFF) << 0; Word1 |= (CompCount & 0xFF) << 8; @@ -478,276 +660,119 @@ std::pair ResourceInfo::getAnnotateProps() const { return {Word0, Word1}; } -void ResourceInfo::print(raw_ostream &OS) const { - OS << " Symbol: "; - Symbol->printAsOperand(OS); - OS << "\n"; +void ResourceBindingInfo::print(raw_ostream &OS, DXILResourceTypeMap &DRTM, + const DataLayout &DL) const { + print(OS, DRTM[getHandleTy()], DL); +} + +void ResourceBindingInfo::print(raw_ostream &OS, dxil::ResourceTypeInfo RTI, + const DataLayout &DL) const { + if (Symbol) { + OS << " Symbol: "; + Symbol->printAsOperand(OS); + OS << "\n"; + } - OS << " Name: \"" << Name << "\"\n" - << " Binding:\n" + OS << " Binding:\n" << " Record ID: " << Binding.RecordID << "\n" << " Space: " << Binding.Space << "\n" << " Lower Bound: " << Binding.LowerBound << "\n" - << " Size: " << Binding.Size << "\n" - << " Class: " << getResourceClassName(RC) << "\n" - << " Kind: " << getResourceKindName(Kind) << "\n"; + << " Size: " << Binding.Size << "\n"; - if (isCBuffer()) { - OS << " CBuffer size: " << CBufferSize << "\n"; - } else if (isSampler()) { - OS << " Sampler Type: " << getSamplerTypeName(SamplerTy) << "\n"; - } else { - if (isUAV()) { - OS << " Globally Coherent: " << UAVFlags.GloballyCoherent << "\n" - << " HasCounter: " << UAVFlags.HasCounter << "\n" - << " IsROV: " << UAVFlags.IsROV << "\n"; - } - if (isMultiSample()) - OS << " Sample Count: " << MultiSample.Count << "\n"; - - if (isStruct()) { - OS << " Buffer Stride: " << Struct.Stride << "\n"; - OS << " Alignment: " << Struct.AlignLog2 << "\n"; - } else if (isTyped()) { - OS << " Element Type: " << getElementTypeName(Typed.ElementTy) << "\n" - << " Element Count: " << Typed.ElementCount << "\n"; - } else if (isFeedback()) - OS << " Feedback Type: " << getSamplerFeedbackTypeName(Feedback.Type) - << "\n"; - } + RTI.print(OS, DL); } //===----------------------------------------------------------------------===// -// ResourceMapper - -static dxil::ElementType toDXILElementType(Type *Ty, bool IsSigned) { - // TODO: Handle unorm, snorm, and packed. - Ty = Ty->getScalarType(); - - if (Ty->isIntegerTy()) { - switch (Ty->getIntegerBitWidth()) { - case 16: - return IsSigned ? ElementType::I16 : ElementType::U16; - case 32: - return IsSigned ? ElementType::I32 : ElementType::U32; - case 64: - return IsSigned ? ElementType::I64 : ElementType::U64; - case 1: - default: - return ElementType::Invalid; - } - } else if (Ty->isFloatTy()) { - return ElementType::F32; - } else if (Ty->isDoubleTy()) { - return ElementType::F64; - } else if (Ty->isHalfTy()) { - return ElementType::F16; - } - return ElementType::Invalid; +bool DXILResourceTypeMap::invalidate(Module &M, const PreservedAnalyses &PA, + ModuleAnalysisManager::Invalidator &Inv) { + // Passes that introduce resource types must explicitly invalidate this pass. + auto PAC = PA.getChecker(); + return !PAC.preservedWhenStateless(); } -namespace { - -class ResourceMapper { - Module &M; - LLVMContext &Context; - SmallVector> Resources; - -public: - ResourceMapper(Module &M) : M(M), Context(M.getContext()) {} - - void diagnoseHandle(CallInst *CI, const Twine &Msg, - DiagnosticSeverity Severity = DS_Error) { - std::string S; - raw_string_ostream SS(S); - CI->printAsOperand(SS); - DiagnosticInfoUnsupported Diag(*CI->getFunction(), Msg + ": " + SS.str(), - CI->getDebugLoc(), Severity); - Context.diagnose(Diag); - } - - ResourceInfo *mapBufferType(CallInst *CI, TargetExtType *HandleTy, - bool IsTyped) { - if (HandleTy->getNumTypeParameters() != 1 || - HandleTy->getNumIntParameters() != (IsTyped ? 3 : 2)) { - diagnoseHandle(CI, Twine("Invalid buffer target type")); - return nullptr; - } - - Type *ElTy = HandleTy->getTypeParameter(0); - unsigned IsWriteable = HandleTy->getIntParameter(0); - unsigned IsROV = HandleTy->getIntParameter(1); - bool IsSigned = IsTyped && HandleTy->getIntParameter(2); - - ResourceClass RC = IsWriteable ? ResourceClass::UAV : ResourceClass::SRV; - ResourceKind Kind; - if (IsTyped) - Kind = ResourceKind::TypedBuffer; - else if (ElTy->isIntegerTy(8)) - Kind = ResourceKind::RawBuffer; - else - Kind = ResourceKind::StructuredBuffer; - - // TODO: We need to lower to a typed pointer, can we smuggle the type - // through? - Value *Symbol = UndefValue::get(PointerType::getUnqual(Context)); - // TODO: We don't actually keep track of the name right now... - StringRef Name = ""; - - // Note that we return a pointer into the vector's storage. This is okay as - // long as we don't add more elements until we're done with the pointer. - auto &Pair = - Resources.emplace_back(CI, ResourceInfo{RC, Kind, Symbol, Name}); - ResourceInfo *RI = &Pair.second; - - if (RI->isUAV()) - // TODO: We need analysis for GloballyCoherent and HasCounter - RI->setUAV(false, false, IsROV); - - if (RI->isTyped()) { - dxil::ElementType ET = toDXILElementType(ElTy, IsSigned); - uint32_t Count = 1; - if (auto *VTy = dyn_cast(ElTy)) - Count = VTy->getNumElements(); - RI->setTyped(ET, Count); - } else if (RI->isStruct()) { - const DataLayout &DL = M.getDataLayout(); - - // This mimics what DXC does. Notably, we only ever set the alignment if - // the type is actually a struct type. - uint32_t Stride = DL.getTypeAllocSize(ElTy); - MaybeAlign Alignment; - if (auto *STy = dyn_cast(ElTy)) - Alignment = DL.getStructLayout(STy)->getAlignment(); - RI->setStruct(Stride, Alignment); - } - - return RI; - } - - ResourceInfo *mapHandleIntrin(CallInst *CI) { - FunctionType *FTy = CI->getFunctionType(); - Type *RetTy = FTy->getReturnType(); - auto *HandleTy = dyn_cast(RetTy); - if (!HandleTy) { - diagnoseHandle(CI, "dx.handle.fromBinding requires target type"); - return nullptr; - } - - StringRef TypeName = HandleTy->getName(); - if (TypeName == "dx.TypedBuffer") { - return mapBufferType(CI, HandleTy, /*IsTyped=*/true); - } else if (TypeName == "dx.RawBuffer") { - return mapBufferType(CI, HandleTy, /*IsTyped=*/false); - } else if (TypeName == "dx.CBuffer") { - // TODO: implement - diagnoseHandle(CI, "dx.CBuffer handles are not implemented yet"); - return nullptr; - } else if (TypeName == "dx.Sampler") { - // TODO: implement - diagnoseHandle(CI, "dx.Sampler handles are not implemented yet"); - return nullptr; - } else if (TypeName == "dx.Texture") { - // TODO: implement - diagnoseHandle(CI, "dx.Texture handles are not implemented yet"); - return nullptr; - } - - diagnoseHandle(CI, "Invalid target(dx) type"); - return nullptr; - } - - ResourceInfo *mapHandleFromBinding(CallInst *CI) { - assert(CI->getIntrinsicID() == Intrinsic::dx_handle_fromBinding && - "Must be dx.handle.fromBinding intrinsic"); - - ResourceInfo *RI = mapHandleIntrin(CI); - if (!RI) - return nullptr; - - uint32_t Space = cast(CI->getArgOperand(0))->getZExtValue(); - uint32_t LowerBound = - cast(CI->getArgOperand(1))->getZExtValue(); - uint32_t Size = cast(CI->getArgOperand(2))->getZExtValue(); +//===----------------------------------------------------------------------===// - // We use a binding ID of zero for now - these will be filled in later. - RI->bind(0U, Space, LowerBound, Size); +void DXILBindingMap::populate(Module &M, DXILResourceTypeMap &DRTM) { + SmallVector> + CIToInfos; - return RI; - } + for (Function &F : M.functions()) { + if (!F.isDeclaration()) + continue; + LLVM_DEBUG(dbgs() << "Function: " << F.getName() << "\n"); + Intrinsic::ID ID = F.getIntrinsicID(); + switch (ID) { + default: + continue; + case Intrinsic::dx_handle_fromBinding: { + auto *HandleTy = cast(F.getReturnType()); + ResourceTypeInfo RTI = DRTM[HandleTy]; - DXILResourceMap mapResources() { - for (Function &F : M.functions()) { - if (!F.isDeclaration()) - continue; - LLVM_DEBUG(dbgs() << "Function: " << F.getName() << "\n"); - Intrinsic::ID ID = F.getIntrinsicID(); - switch (ID) { - default: - // TODO: handle `dx.op` functions. - continue; - case Intrinsic::dx_handle_fromBinding: - for (User *U : F.users()) { + for (User *U : F.users()) + if (CallInst *CI = dyn_cast(U)) { LLVM_DEBUG(dbgs() << " Visiting: " << *U << "\n"); - if (CallInst *CI = dyn_cast(U)) - mapHandleFromBinding(CI); + uint32_t Space = + cast(CI->getArgOperand(0))->getZExtValue(); + uint32_t LowerBound = + cast(CI->getArgOperand(1))->getZExtValue(); + uint32_t Size = + cast(CI->getArgOperand(2))->getZExtValue(); + ResourceBindingInfo RBI = ResourceBindingInfo{ + /*RecordID=*/0, Space, LowerBound, Size, HandleTy}; + + CIToInfos.emplace_back(CI, RBI, RTI); } - break; - } - } - return DXILResourceMap(std::move(Resources)); + break; + } + } } -}; -} // namespace + llvm::stable_sort(CIToInfos, [](auto &LHS, auto &RHS) { + const auto &[LCI, LRBI, LRTI] = LHS; + const auto &[RCI, RRBI, RRTI] = RHS; + // Sort by resource class first for grouping purposes, and then by the + // binding and type so we can remove duplicates. + ResourceClass LRC = LRTI.getResourceClass(); + ResourceClass RRC = RRTI.getResourceClass(); -DXILResourceMap::DXILResourceMap( - SmallVectorImpl> &&CIToRI) { - if (CIToRI.empty()) - return; - - llvm::stable_sort(CIToRI, [](auto &LHS, auto &RHS) { - // Sort by resource class first for grouping purposes, and then by the rest - // of the fields so that we can remove duplicates. - ResourceClass LRC = LHS.second.getResourceClass(); - ResourceClass RRC = RHS.second.getResourceClass(); - return std::tie(LRC, LHS.second) < std::tie(RRC, RHS.second); + return std::tie(LRC, LRBI, LRTI) < std::tie(RRC, RRBI, RRTI); }); - for (auto [CI, RI] : CIToRI) { - if (Resources.empty() || RI != Resources.back()) - Resources.push_back(RI); - CallMap[CI] = Resources.size() - 1; + for (auto [CI, RBI, RTI] : CIToInfos) { + if (Infos.empty() || RBI != Infos.back()) + Infos.push_back(RBI); + CallMap[CI] = Infos.size() - 1; } - unsigned Size = Resources.size(); + unsigned Size = Infos.size(); // In DXC, Record ID is unique per resource type. Match that. FirstUAV = FirstCBuffer = FirstSampler = Size; uint32_t NextID = 0; for (unsigned I = 0, E = Size; I != E; ++I) { - ResourceInfo &RI = Resources[I]; - if (RI.isUAV() && FirstUAV == Size) { + ResourceBindingInfo &RBI = Infos[I]; + ResourceTypeInfo RTI = DRTM[RBI.getHandleTy()]; + if (RTI.isUAV() && FirstUAV == Size) { FirstUAV = I; NextID = 0; - } else if (RI.isCBuffer() && FirstCBuffer == Size) { + } else if (RTI.isCBuffer() && FirstCBuffer == Size) { FirstCBuffer = I; NextID = 0; - } else if (RI.isSampler() && FirstSampler == Size) { + } else if (RTI.isSampler() && FirstSampler == Size) { FirstSampler = I; NextID = 0; } // Adjust the resource binding to use the next ID. - const ResourceInfo::ResourceBinding &Binding = RI.getBinding(); - RI.bind(NextID++, Binding.Space, Binding.LowerBound, Binding.Size); + RBI.setBindingID(NextID++); } } -void DXILResourceMap::print(raw_ostream &OS) const { - for (unsigned I = 0, E = Resources.size(); I != E; ++I) { +void DXILBindingMap::print(raw_ostream &OS, DXILResourceTypeMap &DRTM, + const DataLayout &DL) const { + for (unsigned I = 0, E = Infos.size(); I != E; ++I) { OS << "Binding " << I << ":\n"; - Resources[I].print(OS); + Infos[I].print(OS, DRTM, DL); OS << "\n"; } @@ -759,61 +784,83 @@ void DXILResourceMap::print(raw_ostream &OS) const { } //===----------------------------------------------------------------------===// -// DXILResourceAnalysis and DXILResourcePrinterPass -// Provide an explicit template instantiation for the static ID. -AnalysisKey DXILResourceAnalysis::Key; +AnalysisKey DXILResourceTypeAnalysis::Key; +AnalysisKey DXILResourceBindingAnalysis::Key; -DXILResourceMap DXILResourceAnalysis::run(Module &M, - ModuleAnalysisManager &AM) { - DXILResourceMap Data = ResourceMapper(M).mapResources(); +DXILBindingMap DXILResourceBindingAnalysis::run(Module &M, + ModuleAnalysisManager &AM) { + DXILBindingMap Data; + DXILResourceTypeMap &DRTM = AM.getResult(M); + Data.populate(M, DRTM); return Data; } -PreservedAnalyses DXILResourcePrinterPass::run(Module &M, - ModuleAnalysisManager &AM) { - DXILResourceMap &DRM = AM.getResult(M); - DRM.print(OS); +PreservedAnalyses +DXILResourceBindingPrinterPass::run(Module &M, ModuleAnalysisManager &AM) { + DXILBindingMap &DBM = AM.getResult(M); + DXILResourceTypeMap &DRTM = AM.getResult(M); + + DBM.print(OS, DRTM, M.getDataLayout()); return PreservedAnalyses::all(); } -//===----------------------------------------------------------------------===// -// DXILResourceWrapperPass +void DXILResourceTypeWrapperPass::anchor() {} -DXILResourceWrapperPass::DXILResourceWrapperPass() : ModulePass(ID) { - initializeDXILResourceWrapperPassPass(*PassRegistry::getPassRegistry()); +DXILResourceTypeWrapperPass::DXILResourceTypeWrapperPass() : ImmutablePass(ID) { + initializeDXILResourceTypeWrapperPassPass(*PassRegistry::getPassRegistry()); } -DXILResourceWrapperPass::~DXILResourceWrapperPass() = default; +INITIALIZE_PASS(DXILResourceTypeWrapperPass, "dxil-resource-type", + "DXIL Resource Type Analysis", false, true) +char DXILResourceTypeWrapperPass::ID = 0; -void DXILResourceWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { +ModulePass *llvm::createDXILResourceTypeWrapperPassPass() { + return new DXILResourceTypeWrapperPass(); +} + +DXILResourceBindingWrapperPass::DXILResourceBindingWrapperPass() + : ModulePass(ID) { + initializeDXILResourceBindingWrapperPassPass( + *PassRegistry::getPassRegistry()); +} + +DXILResourceBindingWrapperPass::~DXILResourceBindingWrapperPass() = default; + +void DXILResourceBindingWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequiredTransitive(); AU.setPreservesAll(); } -bool DXILResourceWrapperPass::runOnModule(Module &M) { - ResourceMap.reset(new DXILResourceMap(ResourceMapper(M).mapResources())); +bool DXILResourceBindingWrapperPass::runOnModule(Module &M) { + Map.reset(new DXILBindingMap()); + + DRTM = &getAnalysis().getResourceTypeMap(); + Map->populate(M, *DRTM); + return false; } -void DXILResourceWrapperPass::releaseMemory() { ResourceMap.reset(); } +void DXILResourceBindingWrapperPass::releaseMemory() { Map.reset(); } -void DXILResourceWrapperPass::print(raw_ostream &OS, const Module *) const { - if (!ResourceMap) { +void DXILResourceBindingWrapperPass::print(raw_ostream &OS, + const Module *M) const { + if (!Map) { OS << "No resource map has been built!\n"; return; } - ResourceMap->print(OS); + Map->print(OS, *DRTM, M->getDataLayout()); } #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) LLVM_DUMP_METHOD -void DXILResourceWrapperPass::dump() const { print(dbgs(), nullptr); } +void DXILResourceBindingWrapperPass::dump() const { print(dbgs(), nullptr); } #endif -INITIALIZE_PASS(DXILResourceWrapperPass, DEBUG_TYPE, "DXIL Resource analysis", - false, true) -char DXILResourceWrapperPass::ID = 0; +INITIALIZE_PASS(DXILResourceBindingWrapperPass, "dxil-resource-binding", + "DXIL Resource Binding Analysis", false, true) +char DXILResourceBindingWrapperPass::ID = 0; -ModulePass *llvm::createDXILResourceWrapperPassPass() { - return new DXILResourceWrapperPass(); +ModulePass *llvm::createDXILResourceBindingWrapperPassPass() { + return new DXILResourceBindingWrapperPass(); } diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def index 825f2f7f9a494..ad7e6429a1741 100644 --- a/llvm/lib/Passes/PassRegistry.def +++ b/llvm/lib/Passes/PassRegistry.def @@ -22,7 +22,8 @@ MODULE_ANALYSIS("callgraph", CallGraphAnalysis()) MODULE_ANALYSIS("collector-metadata", CollectorMetadataAnalysis()) MODULE_ANALYSIS("ctx-prof-analysis", CtxProfAnalysis()) MODULE_ANALYSIS("dxil-metadata", DXILMetadataAnalysis()) -MODULE_ANALYSIS("dxil-resource", DXILResourceAnalysis()) +MODULE_ANALYSIS("dxil-resource-binding", DXILResourceBindingAnalysis()) +MODULE_ANALYSIS("dxil-resource-type", DXILResourceTypeAnalysis()) MODULE_ANALYSIS("inline-advisor", InlineAdvisorAnalysis()) MODULE_ANALYSIS("ir-similarity", IRSimilarityAnalysis()) MODULE_ANALYSIS("last-run-tracking", LastRunTrackingAnalysis()) @@ -127,7 +128,8 @@ MODULE_PASS("print-must-be-executed-contexts", MODULE_PASS("print-profile-summary", ProfileSummaryPrinterPass(errs())) MODULE_PASS("print-stack-safety", StackSafetyGlobalPrinterPass(errs())) MODULE_PASS("print", DXILMetadataAnalysisPrinterPass(errs())) -MODULE_PASS("print", DXILResourcePrinterPass(errs())) +MODULE_PASS("print", + DXILResourceBindingPrinterPass(errs())) MODULE_PASS("print", InlineAdvisorAnalysisPrinterPass(errs())) MODULE_PASS("print", ModuleDebugInfoPrinterPass(errs())) MODULE_PASS("print", PhysicalRegisterUsageInfoPrinterPass(errs())) diff --git a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp index aaf994b23cf3c..4c55a13b17f29 100644 --- a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp +++ b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp @@ -61,7 +61,8 @@ class DXContainerGlobals : public llvm::ModulePass { AU.setPreservesAll(); AU.addRequired(); AU.addRequired(); - AU.addRequired(); + AU.addRequired(); + AU.addRequired(); } }; @@ -144,19 +145,23 @@ void DXContainerGlobals::addSignature(Module &M, } void DXContainerGlobals::addResourcesForPSV(Module &M, PSVRuntimeInfo &PSV) { - const DXILResourceMap &ResMap = - getAnalysis().getResourceMap(); - - for (const dxil::ResourceInfo &ResInfo : ResMap) { - const dxil::ResourceInfo::ResourceBinding &Binding = ResInfo.getBinding(); + const DXILBindingMap &DBM = + getAnalysis().getBindingMap(); + DXILResourceTypeMap &DRTM = + getAnalysis().getResourceTypeMap(); + + for (const dxil::ResourceBindingInfo &RBI : DBM) { + const dxil::ResourceBindingInfo::ResourceBinding &Binding = + RBI.getBinding(); dxbc::PSV::v2::ResourceBindInfo BindInfo; BindInfo.LowerBound = Binding.LowerBound; BindInfo.UpperBound = Binding.LowerBound + Binding.Size - 1; BindInfo.Space = Binding.Space; + dxil::ResourceTypeInfo TypeInfo = DRTM[RBI.getHandleTy()]; dxbc::PSV::ResourceType ResType = dxbc::PSV::ResourceType::Invalid; - bool IsUAV = ResInfo.getResourceClass() == dxil::ResourceClass::UAV; - switch (ResInfo.getResourceKind()) { + bool IsUAV = TypeInfo.getResourceClass() == dxil::ResourceClass::UAV; + switch (TypeInfo.getResourceKind()) { case dxil::ResourceKind::Sampler: ResType = dxbc::PSV::ResourceType::Sampler; break; @@ -166,7 +171,7 @@ void DXContainerGlobals::addResourcesForPSV(Module &M, PSVRuntimeInfo &PSV) { case dxil::ResourceKind::StructuredBuffer: ResType = IsUAV ? dxbc::PSV::ResourceType::UAVStructured : dxbc::PSV::ResourceType::SRVStructured; - if (IsUAV && ResInfo.getUAV().HasCounter) + if (IsUAV && TypeInfo.getUAV().HasCounter) ResType = dxbc::PSV::ResourceType::UAVStructuredWithCounter; break; case dxil::ResourceKind::RTAccelerationStructure: @@ -184,7 +189,7 @@ void DXContainerGlobals::addResourcesForPSV(Module &M, PSVRuntimeInfo &PSV) { BindInfo.Type = ResType; BindInfo.Kind = - static_cast(ResInfo.getResourceKind()); + static_cast(TypeInfo.getResourceKind()); // TODO: Add support for dxbc::PSV::ResourceFlag::UsedByAtomic64, tracking // with https://github.com/llvm/llvm-project/issues/104392 BindInfo.Flags.Flags = 0u; @@ -240,7 +245,8 @@ INITIALIZE_PASS_BEGIN(DXContainerGlobals, "dxil-globals", "DXContainer Global Emitter", false, true) INITIALIZE_PASS_DEPENDENCY(ShaderFlagsAnalysisWrapper) INITIALIZE_PASS_DEPENDENCY(DXILMetadataAnalysisWrapperPass) -INITIALIZE_PASS_DEPENDENCY(DXILResourceWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DXILResourceTypeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DXILResourceBindingWrapperPass) INITIALIZE_PASS_END(DXContainerGlobals, "dxil-globals", "DXContainer Global Emitter", false, true) diff --git a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp index 0e6cf59e25750..1783e4a546313 100644 --- a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp +++ b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp @@ -10,7 +10,6 @@ #include "DirectX.h" #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/Analysis/DXILResource.h" #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstVisitor.h" @@ -33,7 +32,6 @@ class DXILDataScalarizationLegacy : public ModulePass { bool runOnModule(Module &M) override; DXILDataScalarizationLegacy() : ModulePass(ID) {} - void getAnalysisUsage(AnalysisUsage &AU) const override; static char ID; // Pass identification. }; @@ -276,7 +274,6 @@ PreservedAnalyses DXILDataScalarization::run(Module &M, if (!MadeChanges) return PreservedAnalyses::all(); PreservedAnalyses PA; - PA.preserve(); return PA; } @@ -284,10 +281,6 @@ bool DXILDataScalarizationLegacy::runOnModule(Module &M) { return findAndReplaceVectors(M); } -void DXILDataScalarizationLegacy::getAnalysisUsage(AnalysisUsage &AU) const { - AU.addPreserved(); -} - char DXILDataScalarizationLegacy::ID = 0; INITIALIZE_PASS_BEGIN(DXILDataScalarizationLegacy, DEBUG_TYPE, diff --git a/llvm/lib/Target/DirectX/DXILFinalizeLinkage.cpp b/llvm/lib/Target/DirectX/DXILFinalizeLinkage.cpp index 79ebbe0925e5c..91ac758150fb4 100644 --- a/llvm/lib/Target/DirectX/DXILFinalizeLinkage.cpp +++ b/llvm/lib/Target/DirectX/DXILFinalizeLinkage.cpp @@ -8,7 +8,6 @@ #include "DXILFinalizeLinkage.h" #include "DirectX.h" -#include "llvm/Analysis/DXILResource.h" #include "llvm/IR/Function.h" #include "llvm/IR/GlobalValue.h" #include "llvm/IR/Metadata.h" @@ -51,10 +50,6 @@ bool DXILFinalizeLinkageLegacy::runOnModule(Module &M) { return finalizeLinkage(M); } -void DXILFinalizeLinkageLegacy::getAnalysisUsage(AnalysisUsage &AU) const { - AU.addPreserved(); -} - char DXILFinalizeLinkageLegacy::ID = 0; INITIALIZE_PASS_BEGIN(DXILFinalizeLinkageLegacy, DEBUG_TYPE, diff --git a/llvm/lib/Target/DirectX/DXILFinalizeLinkage.h b/llvm/lib/Target/DirectX/DXILFinalizeLinkage.h index 62d3a8a27cfce..aab1bc3f7a28e 100644 --- a/llvm/lib/Target/DirectX/DXILFinalizeLinkage.h +++ b/llvm/lib/Target/DirectX/DXILFinalizeLinkage.h @@ -32,7 +32,6 @@ class DXILFinalizeLinkageLegacy : public ModulePass { DXILFinalizeLinkageLegacy() : ModulePass(ID) {} bool runOnModule(Module &M) override; - void getAnalysisUsage(AnalysisUsage &AU) const override; static char ID; // Pass identification. }; } // namespace llvm diff --git a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp index e4a3bc76eeacd..6077af997212e 100644 --- a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp +++ b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp @@ -14,7 +14,6 @@ #include "DirectX.h" #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/Analysis/DXILResource.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/IRBuilder.h" @@ -38,7 +37,6 @@ class DXILFlattenArraysLegacy : public ModulePass { bool runOnModule(Module &M) override; DXILFlattenArraysLegacy() : ModulePass(ID) {} - void getAnalysisUsage(AnalysisUsage &AU) const override; static char ID; // Pass identification. }; @@ -419,7 +417,6 @@ PreservedAnalyses DXILFlattenArrays::run(Module &M, ModuleAnalysisManager &) { if (!MadeChanges) return PreservedAnalyses::all(); PreservedAnalyses PA; - PA.preserve(); return PA; } @@ -427,10 +424,6 @@ bool DXILFlattenArraysLegacy::runOnModule(Module &M) { return flattenArrays(M); } -void DXILFlattenArraysLegacy::getAnalysisUsage(AnalysisUsage &AU) const { - AU.addPreserved(); -} - char DXILFlattenArraysLegacy::ID = 0; INITIALIZE_PASS_BEGIN(DXILFlattenArraysLegacy, DEBUG_TYPE, diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp index d2bfca1fada55..3c6ea4470fbdc 100644 --- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp +++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp @@ -14,7 +14,6 @@ #include "DirectX.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/Analysis/DXILResource.h" #include "llvm/CodeGen/Passes.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstrTypes.h" @@ -39,7 +38,6 @@ class DXILIntrinsicExpansionLegacy : public ModulePass { bool runOnModule(Module &M) override; DXILIntrinsicExpansionLegacy() : ModulePass(ID) {} - void getAnalysisUsage(AnalysisUsage &AU) const override; static char ID; // Pass identification. }; @@ -617,10 +615,6 @@ bool DXILIntrinsicExpansionLegacy::runOnModule(Module &M) { return expansionIntrinsics(M); } -void DXILIntrinsicExpansionLegacy::getAnalysisUsage(AnalysisUsage &AU) const { - AU.addPreserved(); -} - char DXILIntrinsicExpansionLegacy::ID = 0; INITIALIZE_PASS_BEGIN(DXILIntrinsicExpansionLegacy, DEBUG_TYPE, diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp index d9e70da6ed653..c66b24442d4bd 100644 --- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp +++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp @@ -77,11 +77,13 @@ namespace { class OpLowerer { Module &M; DXILOpBuilder OpBuilder; - DXILResourceMap &DRM; + DXILBindingMap &DBM; + DXILResourceTypeMap &DRTM; SmallVector CleanupCasts; public: - OpLowerer(Module &M, DXILResourceMap &DRM) : M(M), OpBuilder(M), DRM(DRM) {} + OpLowerer(Module &M, DXILBindingMap &DBM, DXILResourceTypeMap &DRTM) + : M(M), OpBuilder(M), DBM(DBM), DRTM(DRTM) {} /// Replace every call to \c F using \c ReplaceCall, and then erase \c F. If /// there is an error replacing a call, we emit a diagnostic and return true. @@ -257,10 +259,12 @@ class OpLowerer { return replaceFunction(F, [&](CallInst *CI) -> Error { IRB.SetInsertPoint(CI); - auto *It = DRM.find(CI); - assert(It != DRM.end() && "Resource not in map?"); - dxil::ResourceInfo &RI = *It; + auto *It = DBM.find(CI); + assert(It != DBM.end() && "Resource not in map?"); + dxil::ResourceBindingInfo &RI = *It; + const auto &Binding = RI.getBinding(); + dxil::ResourceClass RC = DRTM[RI.getHandleTy()].getResourceClass(); Value *IndexOp = CI->getArgOperand(3); if (Binding.LowerBound != 0) @@ -268,7 +272,7 @@ class OpLowerer { ConstantInt::get(Int32Ty, Binding.LowerBound)); std::array Args{ - ConstantInt::get(Int8Ty, llvm::to_underlying(RI.getResourceClass())), + ConstantInt::get(Int8Ty, llvm::to_underlying(RC)), ConstantInt::get(Int32Ty, Binding.RecordID), IndexOp, CI->getArgOperand(4)}; Expected OpCall = @@ -293,18 +297,20 @@ class OpLowerer { return replaceFunction(F, [&](CallInst *CI) -> Error { IRB.SetInsertPoint(CI); - auto *It = DRM.find(CI); - assert(It != DRM.end() && "Resource not in map?"); - dxil::ResourceInfo &RI = *It; + auto *It = DBM.find(CI); + assert(It != DBM.end() && "Resource not in map?"); + dxil::ResourceBindingInfo &RI = *It; const auto &Binding = RI.getBinding(); + dxil::ResourceClass RC = DRTM[RI.getHandleTy()].getResourceClass(); Value *IndexOp = CI->getArgOperand(3); if (Binding.LowerBound != 0) IndexOp = IRB.CreateAdd(IndexOp, ConstantInt::get(Int32Ty, Binding.LowerBound)); - std::pair Props = RI.getAnnotateProps(); + std::pair Props = + RI.getAnnotateProps(*F.getParent(), DRTM); // For `CreateHandleFromBinding` we need the upper bound rather than the // size, so we need to be careful about the difference for "unbounded". @@ -312,8 +318,8 @@ class OpLowerer { uint32_t UpperBound = Binding.Size == Unbounded ? Unbounded : Binding.LowerBound + Binding.Size - 1; - Constant *ResBind = OpBuilder.getResBind( - Binding.LowerBound, UpperBound, Binding.Space, RI.getResourceClass()); + Constant *ResBind = OpBuilder.getResBind(Binding.LowerBound, UpperBound, + Binding.Space, RC); std::array BindArgs{ResBind, IndexOp, CI->getArgOperand(4)}; Expected OpBind = OpBuilder.tryCreateOp( OpCode::CreateHandleFromBinding, BindArgs, CI->getName()); @@ -340,7 +346,7 @@ class OpLowerer { } /// Lower `dx.handle.fromBinding` intrinsics depending on the shader model and - /// taking into account binding information from DXILResourceAnalysis. + /// taking into account binding information from DXILResourceBindingAnalysis. bool lowerHandleFromBinding(Function &F) { Triple TT(Triple(M.getTargetTriple())); if (TT.getDXILVersion() < VersionTuple(1, 6)) @@ -737,13 +743,14 @@ class OpLowerer { } // namespace PreservedAnalyses DXILOpLowering::run(Module &M, ModuleAnalysisManager &MAM) { - DXILResourceMap &DRM = MAM.getResult(M); + DXILBindingMap &DBM = MAM.getResult(M); + DXILResourceTypeMap &DRTM = MAM.getResult(M); - bool MadeChanges = OpLowerer(M, DRM).lowerIntrinsics(); + bool MadeChanges = OpLowerer(M, DBM, DRTM).lowerIntrinsics(); if (!MadeChanges) return PreservedAnalyses::all(); PreservedAnalyses PA; - PA.preserve(); + PA.preserve(); return PA; } @@ -751,18 +758,21 @@ namespace { class DXILOpLoweringLegacy : public ModulePass { public: bool runOnModule(Module &M) override { - DXILResourceMap &DRM = - getAnalysis().getResourceMap(); + DXILBindingMap &DBM = + getAnalysis().getBindingMap(); + DXILResourceTypeMap &DRTM = + getAnalysis().getResourceTypeMap(); - return OpLowerer(M, DRM).lowerIntrinsics(); + return OpLowerer(M, DBM, DRTM).lowerIntrinsics(); } StringRef getPassName() const override { return "DXIL Op Lowering"; } DXILOpLoweringLegacy() : ModulePass(ID) {} static char ID; // Pass identification. void getAnalysisUsage(llvm::AnalysisUsage &AU) const override { - AU.addRequired(); - AU.addPreserved(); + AU.addRequired(); + AU.addRequired(); + AU.addPreserved(); } }; char DXILOpLoweringLegacy::ID = 0; @@ -770,7 +780,8 @@ char DXILOpLoweringLegacy::ID = 0; INITIALIZE_PASS_BEGIN(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering", false, false) -INITIALIZE_PASS_DEPENDENCY(DXILResourceWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DXILResourceTypeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DXILResourceBindingWrapperPass) INITIALIZE_PASS_END(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering", false, false) diff --git a/llvm/lib/Target/DirectX/DXILPrepare.cpp b/llvm/lib/Target/DirectX/DXILPrepare.cpp index 6092cfb3948f0..375e6ce712924 100644 --- a/llvm/lib/Target/DirectX/DXILPrepare.cpp +++ b/llvm/lib/Target/DirectX/DXILPrepare.cpp @@ -251,7 +251,7 @@ class DXILPrepareModule : public ModulePass { AU.addPreserved(); AU.addPreserved(); AU.addPreserved(); - AU.addPreserved(); + AU.addPreserved(); } static char ID; // Pass identification. }; diff --git a/llvm/lib/Target/DirectX/DXILPrettyPrinter.cpp b/llvm/lib/Target/DirectX/DXILPrettyPrinter.cpp index 0478dc2df988d..ff690f2abe490 100644 --- a/llvm/lib/Target/DirectX/DXILPrettyPrinter.cpp +++ b/llvm/lib/Target/DirectX/DXILPrettyPrinter.cpp @@ -48,7 +48,7 @@ static StringRef getRCPrefix(dxil::ResourceClass RC) { llvm_unreachable("covered switch"); } -static StringRef getFormatName(const dxil::ResourceInfo &RI) { +static StringRef getFormatName(const dxil::ResourceTypeInfo &RI) { if (RI.isTyped()) { switch (RI.getTyped().ElementTy) { case dxil::ElementType::I1: @@ -139,9 +139,9 @@ static StringRef getTextureDimName(dxil::ResourceKind RK) { namespace { struct FormatResourceDimension - : public llvm::FormatAdapter { - explicit FormatResourceDimension(const dxil::ResourceInfo &RI) - : llvm::FormatAdapter(RI) {} + : public llvm::FormatAdapter { + explicit FormatResourceDimension(const dxil::ResourceTypeInfo &RI) + : llvm::FormatAdapter(RI) {} void format(llvm::raw_ostream &OS, StringRef Style) override { dxil::ResourceKind RK = Item.getResourceKind(); @@ -149,7 +149,7 @@ struct FormatResourceDimension default: { OS << getTextureDimName(RK); if (Item.isMultiSample()) - OS << Item.getMultiSample().Count; + OS << Item.getMultiSampleCount(); break; } case dxil::ResourceKind::RawBuffer: @@ -172,33 +172,40 @@ struct FormatResourceDimension }; struct FormatBindingID - : public llvm::FormatAdapter { - explicit FormatBindingID(const dxil::ResourceInfo &RI) - : llvm::FormatAdapter(RI) {} + : public llvm::FormatAdapter { + dxil::ResourceClass RC; + + explicit FormatBindingID(const dxil::ResourceBindingInfo &RBI, + const dxil::ResourceTypeInfo &RTI) + : llvm::FormatAdapter(RBI), + RC(RTI.getResourceClass()) {} void format(llvm::raw_ostream &OS, StringRef Style) override { - OS << getRCPrefix(Item.getResourceClass()).upper() - << Item.getBinding().RecordID; + OS << getRCPrefix(RC).upper() << Item.getBinding().RecordID; } }; struct FormatBindingLocation - : public llvm::FormatAdapter { - explicit FormatBindingLocation(const dxil::ResourceInfo &RI) - : llvm::FormatAdapter(RI) {} + : public llvm::FormatAdapter { + dxil::ResourceClass RC; + + explicit FormatBindingLocation(const dxil::ResourceBindingInfo &RBI, + const dxil::ResourceTypeInfo &RTI) + : llvm::FormatAdapter(RBI), + RC(RTI.getResourceClass()) {} void format(llvm::raw_ostream &OS, StringRef Style) override { const auto &Binding = Item.getBinding(); - OS << getRCPrefix(Item.getResourceClass()) << Binding.LowerBound; + OS << getRCPrefix(RC) << Binding.LowerBound; if (Binding.Space) OS << ",space" << Binding.Space; } }; struct FormatBindingSize - : public llvm::FormatAdapter { - explicit FormatBindingSize(const dxil::ResourceInfo &RI) - : llvm::FormatAdapter(RI) {} + : public llvm::FormatAdapter { + explicit FormatBindingSize(const dxil::ResourceBindingInfo &RI) + : llvm::FormatAdapter(RI) {} void format(llvm::raw_ostream &OS, StringRef Style) override { uint32_t Size = Item.getBinding().Size; @@ -211,7 +218,8 @@ struct FormatBindingSize } // namespace -static void prettyPrintResources(raw_ostream &OS, const DXILResourceMap &DRM, +static void prettyPrintResources(raw_ostream &OS, const DXILBindingMap &DBM, + DXILResourceTypeMap &DRTM, const dxil::Resources &MDResources) { // Column widths are arbitrary but match the widths DXC uses. OS << ";\n; Resource Bindings:\n;\n"; @@ -222,20 +230,22 @@ static void prettyPrintResources(raw_ostream &OS, const DXILResourceMap &DRM, "", "", "", "", ""); // TODO: Do we want to sort these by binding or something like that? - for (const dxil::ResourceInfo &RI : DRM) { - dxil::ResourceClass RC = RI.getResourceClass(); + for (const dxil::ResourceBindingInfo &RBI : DBM) { + const dxil::ResourceTypeInfo &RTI = DRTM[RBI.getHandleTy()]; + + dxil::ResourceClass RC = RTI.getResourceClass(); assert((RC != dxil::ResourceClass::CBuffer || !MDResources.hasCBuffers()) && "Old and new cbuffer representations can't coexist"); assert((RC != dxil::ResourceClass::UAV || !MDResources.hasUAVs()) && "Old and new UAV representations can't coexist"); - StringRef Name(RI.getName()); + StringRef Name(RBI.getName()); StringRef Type(getRCName(RC)); - StringRef Format(getFormatName(RI)); - FormatResourceDimension Dim(RI); - FormatBindingID ID(RI); - FormatBindingLocation Bind(RI); - FormatBindingSize Count(RI); + StringRef Format(getFormatName(RTI)); + FormatResourceDimension Dim(RTI); + FormatBindingID ID(RBI, RTI); + FormatBindingLocation Bind(RBI, RTI); + FormatBindingSize Count(RBI); OS << formatv("; {0,-30} {1,10} {2,7} {3,11} {4,7} {5,14} {6,9}\n", Name, Type, Format, Dim, ID, Bind, Count); } @@ -250,9 +260,10 @@ static void prettyPrintResources(raw_ostream &OS, const DXILResourceMap &DRM, PreservedAnalyses DXILPrettyPrinterPass::run(Module &M, ModuleAnalysisManager &MAM) { - const DXILResourceMap &DRM = MAM.getResult(M); + const DXILBindingMap &DBM = MAM.getResult(M); + DXILResourceTypeMap &DRTM = MAM.getResult(M); const dxil::Resources &MDResources = MAM.getResult(M); - prettyPrintResources(OS, DRM, MDResources); + prettyPrintResources(OS, DBM, DRTM, MDResources); return PreservedAnalyses::all(); } @@ -277,7 +288,8 @@ class DXILPrettyPrinterLegacy : public llvm::ModulePass { bool runOnModule(Module &M) override; void getAnalysisUsage(AnalysisUsage &AU) const override { AU.setPreservesAll(); - AU.addRequired(); + AU.addRequired(); + AU.addRequired(); AU.addRequired(); } }; @@ -286,16 +298,19 @@ class DXILPrettyPrinterLegacy : public llvm::ModulePass { char DXILPrettyPrinterLegacy::ID = 0; INITIALIZE_PASS_BEGIN(DXILPrettyPrinterLegacy, "dxil-pretty-printer", "DXIL Metadata Pretty Printer", true, true) -INITIALIZE_PASS_DEPENDENCY(DXILResourceWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DXILResourceTypeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DXILResourceBindingWrapperPass) INITIALIZE_PASS_DEPENDENCY(DXILResourceMDWrapper) INITIALIZE_PASS_END(DXILPrettyPrinterLegacy, "dxil-pretty-printer", "DXIL Metadata Pretty Printer", true, true) bool DXILPrettyPrinterLegacy::runOnModule(Module &M) { - const DXILResourceMap &DRM = - getAnalysis().getResourceMap(); + const DXILBindingMap &DBM = + getAnalysis().getBindingMap(); + DXILResourceTypeMap &DRTM = + getAnalysis().getResourceTypeMap(); dxil::Resources &Res = getAnalysis().getDXILResource(); - prettyPrintResources(OS, DRM, Res); + prettyPrintResources(OS, DBM, DRTM, Res); return false; } diff --git a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp index 4ba10d123e8d2..9443ccd9c82a5 100644 --- a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp +++ b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp @@ -72,25 +72,30 @@ enum class EntryPropsTag { } // namespace -static NamedMDNode *emitResourceMetadata(Module &M, const DXILResourceMap &DRM, +static NamedMDNode *emitResourceMetadata(Module &M, DXILBindingMap &DBM, + DXILResourceTypeMap &DRTM, const dxil::Resources &MDResources) { LLVMContext &Context = M.getContext(); + for (ResourceBindingInfo &RI : DBM) + if (!RI.hasSymbol()) + RI.createSymbol(M, DRTM[RI.getHandleTy()].createElementStruct()); + SmallVector SRVs, UAVs, CBufs, Smps; - for (const ResourceInfo &RI : DRM.srvs()) - SRVs.push_back(RI.getAsMetadata(Context)); - for (const ResourceInfo &RI : DRM.uavs()) - UAVs.push_back(RI.getAsMetadata(Context)); - for (const ResourceInfo &RI : DRM.cbuffers()) - CBufs.push_back(RI.getAsMetadata(Context)); - for (const ResourceInfo &RI : DRM.samplers()) - Smps.push_back(RI.getAsMetadata(Context)); + for (const ResourceBindingInfo &RI : DBM.srvs()) + SRVs.push_back(RI.getAsMetadata(M, DRTM)); + for (const ResourceBindingInfo &RI : DBM.uavs()) + UAVs.push_back(RI.getAsMetadata(M, DRTM)); + for (const ResourceBindingInfo &RI : DBM.cbuffers()) + CBufs.push_back(RI.getAsMetadata(M, DRTM)); + for (const ResourceBindingInfo &RI : DBM.samplers()) + Smps.push_back(RI.getAsMetadata(M, DRTM)); Metadata *SRVMD = SRVs.empty() ? nullptr : MDNode::get(Context, SRVs); Metadata *UAVMD = UAVs.empty() ? nullptr : MDNode::get(Context, UAVs); Metadata *CBufMD = CBufs.empty() ? nullptr : MDNode::get(Context, CBufs); Metadata *SmpMD = Smps.empty() ? nullptr : MDNode::get(Context, Smps); - bool HasResources = !DRM.empty(); + bool HasResources = !DBM.empty(); if (MDResources.hasUAVs()) { assert(!UAVMD && "Old and new UAV representations can't coexist"); @@ -295,7 +300,8 @@ static MDTuple *emitTopLevelLibraryNode(Module &M, MDNode *RMD, return constructEntryMetadata(nullptr, nullptr, RMD, Properties, Ctx); } -static void translateMetadata(Module &M, const DXILResourceMap &DRM, +static void translateMetadata(Module &M, DXILBindingMap &DBM, + DXILResourceTypeMap &DRTM, const Resources &MDResources, const ModuleShaderFlags &ShaderFlags, const ModuleMetadataInfo &MMDI) { @@ -306,7 +312,8 @@ static void translateMetadata(Module &M, const DXILResourceMap &DRM, emitValidatorVersionMD(M, MMDI); emitShaderModelVersionMD(M, MMDI); emitDXILVersionTupleMD(M, MMDI); - NamedMDNode *NamedResourceMD = emitResourceMetadata(M, DRM, MDResources); + NamedMDNode *NamedResourceMD = + emitResourceMetadata(M, DBM, DRTM, MDResources); auto *ResourceMD = (NamedResourceMD != nullptr) ? NamedResourceMD->getOperand(0) : nullptr; // FIXME: Add support to construct Signatures @@ -358,12 +365,13 @@ static void translateMetadata(Module &M, const DXILResourceMap &DRM, PreservedAnalyses DXILTranslateMetadata::run(Module &M, ModuleAnalysisManager &MAM) { - const DXILResourceMap &DRM = MAM.getResult(M); + DXILBindingMap &DBM = MAM.getResult(M); + DXILResourceTypeMap &DRTM = MAM.getResult(M); const dxil::Resources &MDResources = MAM.getResult(M); const ModuleShaderFlags &ShaderFlags = MAM.getResult(M); const dxil::ModuleMetadataInfo MMDI = MAM.getResult(M); - translateMetadata(M, DRM, MDResources, ShaderFlags, MMDI); + translateMetadata(M, DBM, DRTM, MDResources, ShaderFlags, MMDI); return PreservedAnalyses::all(); } @@ -377,18 +385,21 @@ class DXILTranslateMetadataLegacy : public ModulePass { StringRef getPassName() const override { return "DXIL Translate Metadata"; } void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired(); + AU.addRequired(); + AU.addRequired(); AU.addRequired(); AU.addRequired(); AU.addRequired(); - AU.addPreserved(); + AU.addPreserved(); AU.addPreserved(); AU.addPreserved(); } bool runOnModule(Module &M) override { - const DXILResourceMap &DRM = - getAnalysis().getResourceMap(); + DXILBindingMap &DBM = + getAnalysis().getBindingMap(); + DXILResourceTypeMap &DRTM = + getAnalysis().getResourceTypeMap(); const dxil::Resources &MDResources = getAnalysis().getDXILResource(); const ModuleShaderFlags &ShaderFlags = @@ -396,7 +407,7 @@ class DXILTranslateMetadataLegacy : public ModulePass { dxil::ModuleMetadataInfo MMDI = getAnalysis().getModuleMetadata(); - translateMetadata(M, DRM, MDResources, ShaderFlags, MMDI); + translateMetadata(M, DBM, DRTM, MDResources, ShaderFlags, MMDI); return true; } }; @@ -411,7 +422,7 @@ ModulePass *llvm::createDXILTranslateMetadataLegacyPass() { INITIALIZE_PASS_BEGIN(DXILTranslateMetadataLegacy, "dxil-translate-metadata", "DXIL Translate Metadata", false, false) -INITIALIZE_PASS_DEPENDENCY(DXILResourceWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DXILResourceBindingWrapperPass) INITIALIZE_PASS_DEPENDENCY(DXILResourceMDWrapper) INITIALIZE_PASS_DEPENDENCY(ShaderFlagsAnalysisWrapper) INITIALIZE_PASS_DEPENDENCY(DXILMetadataAnalysisWrapperPass) diff --git a/llvm/test/Analysis/DXILResource/buffer-frombinding.ll b/llvm/test/Analysis/DXILResource/buffer-frombinding.ll index b26a185423597..313c8376483b9 100644 --- a/llvm/test/Analysis/DXILResource/buffer-frombinding.ll +++ b/llvm/test/Analysis/DXILResource/buffer-frombinding.ll @@ -1,15 +1,13 @@ -; RUN: opt -S -disable-output -passes="print" < %s 2>&1 | FileCheck %s +; RUN: opt -S -disable-output -passes="print" < %s 2>&1 | FileCheck %s @G = external constant <4 x float>, align 4 define void @test_typedbuffer() { ; ByteAddressBuffer Buf : register(t8, space1) - %srv0 = call target("dx.RawBuffer", i8, 0, 0) + %srv0 = call target("dx.RawBuffer", void, 0, 0) @llvm.dx.handle.fromBinding.tdx.RawBuffer_i8_0_0t( i32 1, i32 8, i32 1, i32 0, i1 false) ; CHECK: Binding [[SRV0:[0-9]+]]: - ; CHECK: Symbol: ptr undef - ; CHECK: Name: "" ; CHECK: Binding: ; CHECK: Record ID: 0 ; CHECK: Space: 1 @@ -24,8 +22,6 @@ define void @test_typedbuffer() { @llvm.dx.handle.fromBinding.tdx.RawBuffer_sl_v4f32v4i32s_0_0t( i32 4, i32 2, i32 1, i32 0, i1 false) ; CHECK: Binding [[SRV1:[0-9]+]]: - ; CHECK: Symbol: ptr undef - ; CHECK: Name: "" ; CHECK: Binding: ; CHECK: Record ID: 1 ; CHECK: Space: 4 @@ -41,8 +37,6 @@ define void @test_typedbuffer() { @llvm.dx.handle.fromBinding.tdx.TypedBuffer_i32_0_0t( i32 5, i32 3, i32 24, i32 0, i1 false) ; CHECK: Binding [[SRV2:[0-9]+]]: - ; CHECK: Symbol: ptr undef - ; CHECK: Name: "" ; CHECK: Binding: ; CHECK: Record ID: 2 ; CHECK: Space: 5 @@ -58,8 +52,6 @@ define void @test_typedbuffer() { @llvm.dx.handle.fromBinding.tdx.TypedBuffer_i32_1_0t( i32 2, i32 7, i32 1, i32 0, i1 false) ; CHECK: Binding [[UAV0:[0-9]+]]: - ; CHECK: Symbol: ptr undef - ; CHECK: Name: "" ; CHECK: Binding: ; CHECK: Record ID: 0 ; CHECK: Space: 2 @@ -78,8 +70,6 @@ define void @test_typedbuffer() { @llvm.dx.handle.fromBinding.tdx.TypedBuffer_f32_1_0( i32 3, i32 5, i32 1, i32 0, i1 false) ; CHECK: Binding [[UAV1:[0-9]+]]: - ; CHECK: Symbol: ptr undef - ; CHECK: Name: "" ; CHECK: Binding: ; CHECK: Record ID: 1 ; CHECK: Space: 3 @@ -103,8 +93,6 @@ define void @test_typedbuffer() { @llvm.dx.handle.fromBinding.tdx.TypedBuffer_f32_1_0( i32 4, i32 0, i32 10, i32 5, i1 false) ; CHECK: Binding [[UAV2:[0-9]+]]: - ; CHECK: Symbol: ptr undef - ; CHECK: Name: "" ; CHECK: Binding: ; CHECK: Record ID: 2 ; CHECK: Space: 4 diff --git a/llvm/test/CodeGen/DirectX/llc-pipeline.ll b/llvm/test/CodeGen/DirectX/llc-pipeline.ll index 40fa30778a153..147898efc716f 100644 --- a/llvm/test/CodeGen/DirectX/llc-pipeline.ll +++ b/llvm/test/CodeGen/DirectX/llc-pipeline.ll @@ -6,6 +6,7 @@ ; CHECK-LABEL: Pass Arguments: ; CHECK-NEXT: Target Library Information ; CHECK-NEXT: Target Transform Information +; CHECK-NEXT: DXIL Resource Type Analysis ; CHECK-NEXT: ModulePass Manager ; CHECK-NEXT: DXIL Finalize Linkage ; CHECK-NEXT: DXIL Intrinsic Expansion @@ -14,7 +15,7 @@ ; CHECK-NEXT: FunctionPass Manager ; CHECK-NEXT: Dominator Tree Construction ; CHECK-NEXT: Scalarize vector operations -; CHECK-NEXT: DXIL Resource analysis +; CHECK-NEXT: DXIL Resource Binding Analysis ; CHECK-NEXT: DXIL Op Lowering ; CHECK-NEXT: DXIL resource Information ; CHECK-NEXT: DXIL Shader Flag Analysis @@ -23,4 +24,3 @@ ; CHECK-NEXT: DXIL Prepare Module ; CHECK-NEXT: DXIL Metadata Pretty Printer ; CHECK-NEXT: Print Module IR - diff --git a/llvm/unittests/Analysis/DXILResourceTest.cpp b/llvm/unittests/Analysis/DXILResourceTest.cpp index e24018457dabe..4c005f817af59 100644 --- a/llvm/unittests/Analysis/DXILResourceTest.cpp +++ b/llvm/unittests/Analysis/DXILResourceTest.cpp @@ -8,6 +8,9 @@ #include "llvm/Analysis/DXILResource.h" #include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/Module.h" #include "gtest/gtest.h" using namespace llvm; @@ -99,8 +102,16 @@ testing::AssertionResult MDTupleEq(const char *LHSExpr, const char *RHSExpr, } // namespace TEST(DXILResource, AnnotationsAndMetadata) { + // TODO: How am I supposed to get this? + DataLayout DL("e-m:e-p:32:32-i1:32-i8:8-i16:16-i32:32-i64:64-f16:16-f32:32-" + "f64:64-n8:16:32:64-v96:32"); + LLVMContext Context; + Module M("AnnotationsAndMetadata", Context); + M.setDataLayout(DL); + Type *Int1Ty = Type::getInt1Ty(Context); + Type *Int8Ty = Type::getInt8Ty(Context); Type *Int32Ty = Type::getInt32Ty(Context); Type *FloatTy = Type::getFloatTy(Context); Type *DoubleTy = Type::getDoubleTy(Context); @@ -110,205 +121,327 @@ TEST(DXILResource, AnnotationsAndMetadata) { MDBuilder TestMD(Context, Int32Ty, Int1Ty); - // ByteAddressBuffer Buffer0; - Value *Symbol = UndefValue::get( - StructType::create(Context, {Int32Ty}, "struct.ByteAddressBuffer")); - ResourceInfo Resource = ResourceInfo::RawBuffer(Symbol, "Buffer0"); - Resource.bind(0, 0, 0, 1); - std::pair Props = Resource.getAnnotateProps(); + // ByteAddressBuffer Buffer; + ResourceTypeInfo RTI(llvm::TargetExtType::get( + Context, "dx.RawBuffer", Int8Ty, {/*IsWriteable=*/0, /*IsROV=*/0})); + EXPECT_EQ(RTI.getResourceClass(), ResourceClass::SRV); + EXPECT_EQ(RTI.getResourceKind(), ResourceKind::RawBuffer); + + ResourceBindingInfo RBI( + /*RecordID=*/0, /*Space=*/0, /*LowerBound=*/0, /*Size=*/1, + RTI.getHandleTy()); + GlobalVariable *GV = RBI.createSymbol(M, RTI.createElementStruct(), "Buffer"); + std::pair Props = RBI.getAnnotateProps(M, RTI); EXPECT_EQ(Props.first, 0x0000000bU); EXPECT_EQ(Props.second, 0U); - MDTuple *MD = Resource.getAsMetadata(Context); - EXPECT_MDEQ(MD, TestMD.get(0, Symbol, "Buffer0", 0, 0, 1, 11, 0, nullptr)); + MDTuple *MD = RBI.getAsMetadata(M, RTI); + EXPECT_MDEQ(MD, TestMD.get(0, GV, "Buffer", 0, 0, 1, 11, 0, nullptr)); // RWByteAddressBuffer BufferOut : register(u3, space2); - Symbol = UndefValue::get( - StructType::create(Context, {Int32Ty}, "struct.RWByteAddressBuffer")); - Resource = - ResourceInfo::RWRawBuffer(Symbol, "BufferOut", - /*GloballyCoherent=*/false, /*IsROV=*/false); - Resource.bind(1, 2, 3, 1); - Props = Resource.getAnnotateProps(); + RTI = ResourceTypeInfo(llvm::TargetExtType::get( + Context, "dx.RawBuffer", Int8Ty, {/*IsWriteable=*/1, /*IsROV=*/0})); + EXPECT_EQ(RTI.getResourceClass(), ResourceClass::UAV); + EXPECT_EQ(RTI.getUAV().GloballyCoherent, false); + EXPECT_EQ(RTI.getUAV().HasCounter, false); + EXPECT_EQ(RTI.getUAV().IsROV, false); + EXPECT_EQ(RTI.getResourceKind(), ResourceKind::RawBuffer); + + RBI = ResourceBindingInfo( + /*RecordID=*/1, /*Space=*/2, /*LowerBound=*/3, /*Size=*/1, + RTI.getHandleTy()); + GV = RBI.createSymbol(M, RTI.createElementStruct(), "BufferOut"); + Props = RBI.getAnnotateProps(M, RTI); EXPECT_EQ(Props.first, 0x0000100bU); EXPECT_EQ(Props.second, 0U); - MD = Resource.getAsMetadata(Context); - EXPECT_MDEQ(MD, TestMD.get(1, Symbol, "BufferOut", 2, 3, 1, 11, false, false, + MD = RBI.getAsMetadata(M, RTI); + EXPECT_MDEQ(MD, TestMD.get(1, GV, "BufferOut", 2, 3, 1, 11, false, false, false, nullptr)); // struct BufType0 { int i; float f; double d; }; // StructuredBuffer Buffer0 : register(t0); StructType *BufType0 = StructType::create(Context, {Int32Ty, FloatTy, DoubleTy}, "BufType0"); - Symbol = UndefValue::get(StructType::create( - Context, {BufType0}, "class.StructuredBuffer")); - Resource = ResourceInfo::StructuredBuffer(Symbol, "Buffer0", - /*Stride=*/16, Align(8)); - Resource.bind(0, 0, 0, 1); - Props = Resource.getAnnotateProps(); + RTI = ResourceTypeInfo(llvm::TargetExtType::get( + Context, "dx.RawBuffer", BufType0, {/*IsWriteable=*/0, /*IsROV=*/0})); + EXPECT_EQ(RTI.getResourceClass(), ResourceClass::SRV); + ASSERT_EQ(RTI.isStruct(), true); + EXPECT_EQ(RTI.getStruct(DL).Stride, 16u); + EXPECT_EQ(RTI.getStruct(DL).AlignLog2, Log2(Align(8))); + EXPECT_EQ(RTI.getResourceKind(), ResourceKind::StructuredBuffer); + + RBI = ResourceBindingInfo( + /*RecordID=*/0, /*Space=*/0, /*LowerBound=*/0, /*Size=*/1, + RTI.getHandleTy()); + GV = RBI.createSymbol(M, RTI.createElementStruct(), "Buffer0"); + Props = RBI.getAnnotateProps(M, RTI); EXPECT_EQ(Props.first, 0x0000030cU); EXPECT_EQ(Props.second, 0x00000010U); - MD = Resource.getAsMetadata(Context); - EXPECT_MDEQ( - MD, TestMD.get(0, Symbol, "Buffer0", 0, 0, 1, 12, 0, TestMD.get(1, 16))); + MD = RBI.getAsMetadata(M, RTI); + EXPECT_MDEQ(MD, + TestMD.get(0, GV, "Buffer0", 0, 0, 1, 12, 0, TestMD.get(1, 16))); // StructuredBuffer Buffer1 : register(t1); - Symbol = UndefValue::get(StructType::create( - Context, {Floatx3Ty}, "class.StructuredBuffer >")); - Resource = ResourceInfo::StructuredBuffer(Symbol, "Buffer1", - /*Stride=*/12, {}); - Resource.bind(1, 0, 1, 1); - Props = Resource.getAnnotateProps(); + RTI = ResourceTypeInfo(llvm::TargetExtType::get( + Context, "dx.RawBuffer", Floatx3Ty, {/*IsWriteable=*/0, /*IsROV=*/0})); + EXPECT_EQ(RTI.getResourceClass(), ResourceClass::SRV); + ASSERT_EQ(RTI.isStruct(), true); + EXPECT_EQ(RTI.getStruct(DL).Stride, 12u); + EXPECT_EQ(RTI.getStruct(DL).AlignLog2, 0u); + EXPECT_EQ(RTI.getResourceKind(), ResourceKind::StructuredBuffer); + + RBI = ResourceBindingInfo( + /*RecordID=*/1, /*Space=*/0, /*LowerBound=*/1, /*Size=*/1, + RTI.getHandleTy()); + GV = RBI.createSymbol(M, RTI.createElementStruct(), "Buffer1"); + Props = RBI.getAnnotateProps(M, RTI); EXPECT_EQ(Props.first, 0x0000000cU); EXPECT_EQ(Props.second, 0x0000000cU); - MD = Resource.getAsMetadata(Context); - EXPECT_MDEQ( - MD, TestMD.get(1, Symbol, "Buffer1", 0, 1, 1, 12, 0, TestMD.get(1, 12))); + MD = RBI.getAsMetadata(M, RTI); + EXPECT_MDEQ(MD, + TestMD.get(1, GV, "Buffer1", 0, 1, 1, 12, 0, TestMD.get(1, 12))); // Texture2D ColorMapTexture : register(t2); - Symbol = UndefValue::get(StructType::create( - Context, {Floatx4Ty}, "class.Texture2D >")); - Resource = - ResourceInfo::SRV(Symbol, "ColorMapTexture", dxil::ElementType::F32, - /*ElementCount=*/4, dxil::ResourceKind::Texture2D); - Resource.bind(2, 0, 2, 1); - Props = Resource.getAnnotateProps(); + RTI = ResourceTypeInfo( + llvm::TargetExtType::get(Context, "dx.Texture", Floatx4Ty, + {/*IsWriteable=*/0, /*IsROV=*/0, /*IsSigned=*/0, + llvm::to_underlying(ResourceKind::Texture2D)})); + EXPECT_EQ(RTI.getResourceClass(), ResourceClass::SRV); + ASSERT_EQ(RTI.isTyped(), true); + EXPECT_EQ(RTI.getTyped().ElementTy, ElementType::F32); + EXPECT_EQ(RTI.getTyped().ElementCount, 4u); + EXPECT_EQ(RTI.getResourceKind(), ResourceKind::Texture2D); + + RBI = ResourceBindingInfo( + /*RecordID=*/2, /*Space=*/0, /*LowerBound=*/2, /*Size=*/1, + RTI.getHandleTy()); + GV = RBI.createSymbol(M, RTI.createElementStruct(), "ColorMapTexture"); + Props = RBI.getAnnotateProps(M, RTI); EXPECT_EQ(Props.first, 0x00000002U); EXPECT_EQ(Props.second, 0x00000409U); - MD = Resource.getAsMetadata(Context); - EXPECT_MDEQ(MD, TestMD.get(2, Symbol, "ColorMapTexture", 0, 2, 1, 2, 0, + MD = RBI.getAsMetadata(M, RTI); + EXPECT_MDEQ(MD, TestMD.get(2, GV, "ColorMapTexture", 0, 2, 1, 2, 0, TestMD.get(0, 9))); // Texture2DMS DepthBuffer : register(t0); - Symbol = UndefValue::get( - StructType::create(Context, {FloatTy}, "class.Texture2DMS")); - Resource = - ResourceInfo::Texture2DMS(Symbol, "DepthBuffer", dxil::ElementType::F32, - /*ElementCount=*/1, /*SampleCount=*/8); - Resource.bind(0, 0, 0, 1); - Props = Resource.getAnnotateProps(); + RTI = ResourceTypeInfo(llvm::TargetExtType::get( + Context, "dx.MSTexture", FloatTy, + {/*IsWriteable=*/0, /*SampleCount=*/8, + /*IsSigned=*/0, llvm::to_underlying(ResourceKind::Texture2DMS)})); + EXPECT_EQ(RTI.getResourceClass(), ResourceClass::SRV); + ASSERT_EQ(RTI.isTyped(), true); + EXPECT_EQ(RTI.getTyped().ElementTy, ElementType::F32); + EXPECT_EQ(RTI.getTyped().ElementCount, 1u); + ASSERT_EQ(RTI.isMultiSample(), true); + EXPECT_EQ(RTI.getMultiSampleCount(), 8u); + EXPECT_EQ(RTI.getResourceKind(), ResourceKind::Texture2DMS); + + RBI = ResourceBindingInfo( + /*RecordID=*/0, /*Space=*/0, /*LowerBound=*/0, /*Size=*/1, + RTI.getHandleTy()); + GV = RBI.createSymbol(M, RTI.createElementStruct(), "DepthBuffer"); + Props = RBI.getAnnotateProps(M, RTI); EXPECT_EQ(Props.first, 0x00000003U); EXPECT_EQ(Props.second, 0x00080109U); - MD = Resource.getAsMetadata(Context); - EXPECT_MDEQ(MD, TestMD.get(0, Symbol, "DepthBuffer", 0, 0, 1, 3, 8, - TestMD.get(0, 9))); + MD = RBI.getAsMetadata(M, RTI); + EXPECT_MDEQ( + MD, TestMD.get(0, GV, "DepthBuffer", 0, 0, 1, 3, 8, TestMD.get(0, 9))); // FeedbackTexture2D feedbackMinMip; - Symbol = UndefValue::get( - StructType::create(Context, {Int32Ty}, "class.FeedbackTexture2D<0>")); - Resource = ResourceInfo::FeedbackTexture2D(Symbol, "feedbackMinMip", - SamplerFeedbackType::MinMip); - Resource.bind(0, 0, 0, 1); - Props = Resource.getAnnotateProps(); + RTI = ResourceTypeInfo(llvm::TargetExtType::get( + Context, "dx.FeedbackTexture", {}, + {llvm::to_underlying(SamplerFeedbackType::MinMip), + llvm::to_underlying(ResourceKind::FeedbackTexture2D)})); + EXPECT_EQ(RTI.getResourceClass(), ResourceClass::UAV); + ASSERT_EQ(RTI.isFeedback(), true); + EXPECT_EQ(RTI.getFeedbackType(), SamplerFeedbackType::MinMip); + EXPECT_EQ(RTI.getResourceKind(), ResourceKind::FeedbackTexture2D); + + RBI = ResourceBindingInfo( + /*RecordID=*/0, /*Space=*/0, /*LowerBound=*/0, /*Size=*/1, + RTI.getHandleTy()); + GV = RBI.createSymbol(M, RTI.createElementStruct(), "feedbackMinMip"); + Props = RBI.getAnnotateProps(M, RTI); EXPECT_EQ(Props.first, 0x00001011U); EXPECT_EQ(Props.second, 0U); - MD = Resource.getAsMetadata(Context); - EXPECT_MDEQ(MD, TestMD.get(0, Symbol, "feedbackMinMip", 0, 0, 1, 17, false, - false, false, TestMD.get(2, 0))); + MD = RBI.getAsMetadata(M, RTI); + EXPECT_MDEQ(MD, TestMD.get(0, GV, "feedbackMinMip", 0, 0, 1, 17, false, false, + false, TestMD.get(2, 0))); // FeedbackTexture2DArray feedbackMipRegion; - Symbol = UndefValue::get(StructType::create( - Context, {Int32Ty}, "class.FeedbackTexture2DArray<1>")); - Resource = ResourceInfo::FeedbackTexture2DArray( - Symbol, "feedbackMipRegion", SamplerFeedbackType::MipRegionUsed); - Resource.bind(0, 0, 0, 1); - Props = Resource.getAnnotateProps(); + RTI = ResourceTypeInfo(llvm::TargetExtType::get( + Context, "dx.FeedbackTexture", {}, + {llvm::to_underlying(SamplerFeedbackType::MipRegionUsed), + llvm::to_underlying(ResourceKind::FeedbackTexture2DArray)})); + EXPECT_EQ(RTI.getResourceClass(), ResourceClass::UAV); + ASSERT_EQ(RTI.isFeedback(), true); + EXPECT_EQ(RTI.getFeedbackType(), SamplerFeedbackType::MipRegionUsed); + EXPECT_EQ(RTI.getResourceKind(), ResourceKind::FeedbackTexture2DArray); + + RBI = ResourceBindingInfo( + /*RecordID=*/0, /*Space=*/0, /*LowerBound=*/0, /*Size=*/1, + RTI.getHandleTy()); + GV = RBI.createSymbol(M, RTI.createElementStruct(), "feedbackMipRegion"); + Props = RBI.getAnnotateProps(M, RTI); EXPECT_EQ(Props.first, 0x00001012U); EXPECT_EQ(Props.second, 0x00000001U); - MD = Resource.getAsMetadata(Context); - EXPECT_MDEQ(MD, TestMD.get(0, Symbol, "feedbackMipRegion", 0, 0, 1, 18, false, + MD = RBI.getAsMetadata(M, RTI); + EXPECT_MDEQ(MD, TestMD.get(0, GV, "feedbackMipRegion", 0, 0, 1, 18, false, false, false, TestMD.get(2, 1))); // globallycoherent RWTexture2D OutputTexture : register(u0, space2); - Symbol = UndefValue::get(StructType::create( - Context, {Int32x2Ty}, "class.RWTexture2D >")); - Resource = ResourceInfo::UAV(Symbol, "OutputTexture", dxil::ElementType::I32, - /*ElementCount=*/2, /*GloballyCoherent=*/1, - /*IsROV=*/0, dxil::ResourceKind::Texture2D); - Resource.bind(0, 2, 0, 1); - Props = Resource.getAnnotateProps(); + RTI = ResourceTypeInfo( + llvm::TargetExtType::get(Context, "dx.Texture", Int32x2Ty, + {/*IsWriteable=*/1, + /*IsROV=*/0, /*IsSigned=*/1, + llvm::to_underlying(ResourceKind::Texture2D)}), + /*GloballyCoherent=*/true, /*HasCounter=*/false); + + EXPECT_EQ(RTI.getResourceClass(), ResourceClass::UAV); + EXPECT_EQ(RTI.getUAV().GloballyCoherent, true); + EXPECT_EQ(RTI.getUAV().HasCounter, false); + EXPECT_EQ(RTI.getUAV().IsROV, false); + EXPECT_EQ(RTI.getResourceKind(), ResourceKind::Texture2D); + + RBI = ResourceBindingInfo( + /*RecordID=*/0, /*Space=*/2, /*LowerBound=*/0, /*Size=*/1, + RTI.getHandleTy()); + GV = RBI.createSymbol(M, RTI.createElementStruct(), "OutputTexture"); + Props = RBI.getAnnotateProps(M, RTI); EXPECT_EQ(Props.first, 0x00005002U); EXPECT_EQ(Props.second, 0x00000204U); - MD = Resource.getAsMetadata(Context); - EXPECT_MDEQ(MD, TestMD.get(0, Symbol, "OutputTexture", 2, 0, 1, 2, true, - false, false, TestMD.get(0, 4))); + MD = RBI.getAsMetadata(M, RTI); + EXPECT_MDEQ(MD, TestMD.get(0, GV, "OutputTexture", 2, 0, 1, 2, true, false, + false, TestMD.get(0, 4))); // RasterizerOrderedBuffer ROB; - Symbol = UndefValue::get( - StructType::create(Context, {Floatx4Ty}, - "class.RasterizerOrderedBuffer >")); - Resource = ResourceInfo::UAV(Symbol, "ROB", dxil::ElementType::F32, - /*ElementCount=*/4, /*GloballyCoherent=*/0, - /*IsROV=*/1, dxil::ResourceKind::TypedBuffer); - Resource.bind(0, 0, 0, 1); - Props = Resource.getAnnotateProps(); + RTI = ResourceTypeInfo(llvm::TargetExtType::get( + Context, "dx.TypedBuffer", Floatx4Ty, + {/*IsWriteable=*/1, /*IsROV=*/1, /*IsSigned=*/0})); + EXPECT_EQ(RTI.getResourceClass(), ResourceClass::UAV); + EXPECT_EQ(RTI.getUAV().GloballyCoherent, false); + EXPECT_EQ(RTI.getUAV().HasCounter, false); + EXPECT_EQ(RTI.getUAV().IsROV, true); + ASSERT_EQ(RTI.isTyped(), true); + EXPECT_EQ(RTI.getTyped().ElementTy, ElementType::F32); + EXPECT_EQ(RTI.getTyped().ElementCount, 4u); + EXPECT_EQ(RTI.getResourceKind(), ResourceKind::TypedBuffer); + + RBI = ResourceBindingInfo( + /*RecordID=*/0, /*Space=*/0, /*LowerBound=*/0, /*Size=*/1, + RTI.getHandleTy()); + GV = RBI.createSymbol(M, RTI.createElementStruct(), "ROB"); + Props = RBI.getAnnotateProps(M, RTI); EXPECT_EQ(Props.first, 0x0000300aU); EXPECT_EQ(Props.second, 0x00000409U); - MD = Resource.getAsMetadata(Context); - EXPECT_MDEQ(MD, TestMD.get(0, Symbol, "ROB", 0, 0, 1, 10, false, false, true, + MD = RBI.getAsMetadata(M, RTI); + EXPECT_MDEQ(MD, TestMD.get(0, GV, "ROB", 0, 0, 1, 10, false, false, true, TestMD.get(0, 9))); // RWStructuredBuffer g_OutputBuffer : register(u2); StructType *BufType1 = StructType::create( Context, {Floatx3Ty, FloatTy, Int32Ty}, "ParticleMotion"); - Symbol = UndefValue::get(StructType::create( - Context, {BufType1}, "class.StructuredBuffer")); - Resource = - ResourceInfo::RWStructuredBuffer(Symbol, "g_OutputBuffer", /*Stride=*/20, - Align(4), /*GloballyCoherent=*/false, - /*IsROV=*/false, /*HasCounter=*/true); - Resource.bind(0, 0, 2, 1); - Props = Resource.getAnnotateProps(); + RTI = ResourceTypeInfo( + llvm::TargetExtType::get(Context, "dx.RawBuffer", BufType1, + {/*IsWriteable=*/1, /*IsROV=*/0}), + /*GloballyCoherent=*/false, /*HasCounter=*/true); + EXPECT_EQ(RTI.getResourceClass(), ResourceClass::UAV); + EXPECT_EQ(RTI.getUAV().GloballyCoherent, false); + EXPECT_EQ(RTI.getUAV().HasCounter, true); + EXPECT_EQ(RTI.getUAV().IsROV, false); + ASSERT_EQ(RTI.isStruct(), true); + EXPECT_EQ(RTI.getStruct(DL).Stride, 20u); + EXPECT_EQ(RTI.getStruct(DL).AlignLog2, Log2(Align(4))); + EXPECT_EQ(RTI.getResourceKind(), ResourceKind::StructuredBuffer); + + RBI = ResourceBindingInfo( + /*RecordID=*/0, /*Space=*/0, /*LowerBound=*/2, /*Size=*/1, + RTI.getHandleTy()); + GV = RBI.createSymbol(M, RTI.createElementStruct(), "g_OutputBuffer"); + Props = RBI.getAnnotateProps(M, RTI); EXPECT_EQ(Props.first, 0x0000920cU); EXPECT_EQ(Props.second, 0x00000014U); - MD = Resource.getAsMetadata(Context); - EXPECT_MDEQ(MD, TestMD.get(0, Symbol, "g_OutputBuffer", 0, 2, 1, 12, false, - true, false, TestMD.get(1, 20))); - - // RWTexture2DMSArray g_rw_t2dmsa; - Symbol = UndefValue::get(StructType::create( - Context, {Int32Ty}, "class.RWTexture2DMSArray")); - Resource = ResourceInfo::RWTexture2DMSArray( - Symbol, "g_rw_t2dmsa", dxil::ElementType::U32, /*ElementCount=*/1, - /*SampleCount=*/8, /*GloballyCoherent=*/false); - Resource.bind(0, 0, 0, 1); - Props = Resource.getAnnotateProps(); + MD = RBI.getAsMetadata(M, RTI); + EXPECT_MDEQ(MD, TestMD.get(0, GV, "g_OutputBuffer", 0, 2, 1, 12, false, true, + false, TestMD.get(1, 20))); + + // RWTexture2DMSArray g_rw_t2dmsa; + RTI = ResourceTypeInfo(llvm::TargetExtType::get( + Context, "dx.MSTexture", Int32Ty, + {/*IsWriteable=*/1, /*SampleCount=*/8, /*IsSigned=*/0, + llvm::to_underlying(ResourceKind::Texture2DMSArray)})); + EXPECT_EQ(RTI.getResourceClass(), ResourceClass::UAV); + EXPECT_EQ(RTI.getUAV().GloballyCoherent, false); + EXPECT_EQ(RTI.getUAV().HasCounter, false); + EXPECT_EQ(RTI.getUAV().IsROV, false); + ASSERT_EQ(RTI.isTyped(), true); + EXPECT_EQ(RTI.getTyped().ElementTy, ElementType::U32); + EXPECT_EQ(RTI.getTyped().ElementCount, 1u); + ASSERT_EQ(RTI.isMultiSample(), true); + EXPECT_EQ(RTI.getMultiSampleCount(), 8u); + EXPECT_EQ(RTI.getResourceKind(), ResourceKind::Texture2DMSArray); + + RBI = ResourceBindingInfo( + /*RecordID=*/0, /*Space=*/0, /*LowerBound=*/0, /*Size=*/1, + RTI.getHandleTy()); + GV = RBI.createSymbol(M, RTI.createElementStruct(), "g_rw_t2dmsa"); + Props = RBI.getAnnotateProps(M, RTI); EXPECT_EQ(Props.first, 0x00001008U); EXPECT_EQ(Props.second, 0x00080105U); - MD = Resource.getAsMetadata(Context); - EXPECT_MDEQ(MD, TestMD.get(0, Symbol, "g_rw_t2dmsa", 0, 0, 1, 8, false, false, + MD = RBI.getAsMetadata(M, RTI); + EXPECT_MDEQ(MD, TestMD.get(0, GV, "g_rw_t2dmsa", 0, 0, 1, 8, false, false, false, TestMD.get(0, 5))); // cbuffer cb0 { float4 g_X; float4 g_Y; } - Symbol = UndefValue::get( - StructType::create(Context, {Floatx4Ty, Floatx4Ty}, "cb0")); - Resource = ResourceInfo::CBuffer(Symbol, "cb0", /*Size=*/32); - Resource.bind(0, 0, 0, 1); - Props = Resource.getAnnotateProps(); + StructType *CBufType0 = + StructType::create(Context, {Floatx4Ty, Floatx4Ty}, "cb0"); + RTI = ResourceTypeInfo( + llvm::TargetExtType::get(Context, "dx.CBuffer", CBufType0, {})); + EXPECT_EQ(RTI.getResourceClass(), ResourceClass::CBuffer); + EXPECT_EQ(RTI.getCBufferSize(DL), 32u); + EXPECT_EQ(RTI.getResourceKind(), ResourceKind::CBuffer); + + RBI = ResourceBindingInfo( + /*RecordID=*/0, /*Space=*/0, /*LowerBound=*/0, /*Size=*/1, + RTI.getHandleTy()); + GV = RBI.createSymbol(M, RTI.createElementStruct(), ""); + Props = RBI.getAnnotateProps(M, RTI); EXPECT_EQ(Props.first, 0x0000000dU); EXPECT_EQ(Props.second, 0x00000020U); - MD = Resource.getAsMetadata(Context); - EXPECT_MDEQ(MD, TestMD.get(0, Symbol, "cb0", 0, 0, 1, 32, nullptr)); + MD = RBI.getAsMetadata(M, RTI); + EXPECT_MDEQ(MD, TestMD.get(0, GV, "", 0, 0, 1, 32, nullptr)); // SamplerState ColorMapSampler : register(s0); - Symbol = UndefValue::get( - StructType::create(Context, {Int32Ty}, "struct.SamplerState")); - Resource = ResourceInfo::Sampler(Symbol, "ColorMapSampler", - dxil::SamplerType::Default); - Resource.bind(0, 0, 0, 1); - Props = Resource.getAnnotateProps(); + RTI = ResourceTypeInfo(llvm::TargetExtType::get( + Context, "dx.Sampler", {}, + {llvm::to_underlying(dxil::SamplerType::Default)})); + EXPECT_EQ(RTI.getResourceClass(), ResourceClass::Sampler); + EXPECT_EQ(RTI.getSamplerType(), dxil::SamplerType::Default); + EXPECT_EQ(RTI.getResourceKind(), ResourceKind::Sampler); + + RBI = ResourceBindingInfo( + /*RecordID=*/0, /*Space=*/0, /*LowerBound=*/0, /*Size=*/1, + RTI.getHandleTy()); + GV = RBI.createSymbol(M, RTI.createElementStruct(), "ColorMapSampler"); + Props = RBI.getAnnotateProps(M, RTI); EXPECT_EQ(Props.first, 0x0000000eU); EXPECT_EQ(Props.second, 0U); - MD = Resource.getAsMetadata(Context); - EXPECT_MDEQ(MD, - TestMD.get(0, Symbol, "ColorMapSampler", 0, 0, 1, 0, nullptr)); + MD = RBI.getAsMetadata(M, RTI); + EXPECT_MDEQ(MD, TestMD.get(0, GV, "ColorMapSampler", 0, 0, 1, 0, nullptr)); + + RTI = ResourceTypeInfo(llvm::TargetExtType::get( + Context, "dx.Sampler", {}, + {llvm::to_underlying(dxil::SamplerType::Comparison)})); + EXPECT_EQ(RTI.getResourceClass(), ResourceClass::Sampler); + EXPECT_EQ(RTI.getSamplerType(), dxil::SamplerType::Comparison); + EXPECT_EQ(RTI.getResourceKind(), ResourceKind::Sampler); - // SamplerComparisonState ShadowSampler {...}; - Resource = ResourceInfo::Sampler(Symbol, "CmpSampler", - dxil::SamplerType::Comparison); - Resource.bind(0, 0, 0, 1); - Props = Resource.getAnnotateProps(); + RBI = ResourceBindingInfo( + /*RecordID=*/0, /*Space=*/0, /*LowerBound=*/0, /*Size=*/1, + RTI.getHandleTy()); + GV = RBI.createSymbol(M, RTI.createElementStruct(), "CmpSampler"); + Props = RBI.getAnnotateProps(M, RTI); EXPECT_EQ(Props.first, 0x0000800eU); EXPECT_EQ(Props.second, 0U); - MD = Resource.getAsMetadata(Context); - EXPECT_MDEQ(MD, TestMD.get(0, Symbol, "CmpSampler", 0, 0, 1, 1, nullptr)); + MD = RBI.getAsMetadata(M, RTI); + EXPECT_MDEQ(MD, TestMD.get(0, GV, "CmpSampler", 0, 0, 1, 1, nullptr)); }