@@ -612,6 +612,9 @@ struct RegisterBindingFlags {
612612
613613 bool ContainsNumeric = false ;
614614 bool DefaultGlobals = false ;
615+
616+ // used only when Resource == true
617+ std::optional<llvm::dxil::ResourceClass> ResourceClass;
615618};
616619
617620static bool isDeclaredWithinCOrTBuffer (const Decl *TheDecl) {
@@ -677,65 +680,38 @@ static const T *getSpecifiedHLSLAttrFromVarDecl(VarDecl *VD) {
677680 return getSpecifiedHLSLAttrFromRecordDecl<T>(TheRecordDecl);
678681}
679682
680- static void updateFlagsFromType (QualType TheQualTy ,
681- RegisterBindingFlags &Flags);
682-
683- static void updateResourceClassFlagsFromRecordDecl (RegisterBindingFlags &Flags,
684- const RecordDecl *RD) {
685- if (!RD)
686- return ;
687-
688- if (RD-> isCompleteDefinition ()) {
689- for ( auto Field : RD-> fields ()) {
690- QualType T = Field-> getType () ;
691- updateFlagsFromType (T, Flags) ;
683+ static void updateResourceClassFlagsFromRecordType (RegisterBindingFlags &Flags ,
684+ const RecordType *RT) {
685+ llvm::SmallVector< const Type *> TypesToScan;
686+ TypesToScan. emplace_back (RT);
687+
688+ while (!TypesToScan. empty ()) {
689+ const Type *T = TypesToScan. pop_back_val () ;
690+ while (T-> isArrayType ())
691+ T = T-> getArrayElementTypeNoTypeQual ();
692+ if (T-> isIntegralOrEnumerationType () || T-> isFloatingType ()) {
693+ Flags. ContainsNumeric = true ;
694+ continue ;
692695 }
693- }
694- }
695-
696- static void updateFlagsFromType (QualType TheQualTy,
697- RegisterBindingFlags &Flags) {
698- // if the member's type is a numeric type, set the ContainsNumeric flag
699- if (TheQualTy->isIntegralOrEnumerationType () || TheQualTy->isFloatingType ()) {
700- Flags.ContainsNumeric = true ;
701- return ;
702- }
703-
704- const clang::Type *TheBaseType = TheQualTy.getTypePtr ();
705- while (TheBaseType->isArrayType ())
706- TheBaseType = TheBaseType->getArrayElementTypeNoTypeQual ();
707- // otherwise, if the member's base type is not a record type, return
708- const RecordType *TheRecordTy = TheBaseType->getAs <RecordType>();
709- if (!TheRecordTy)
710- return ;
711-
712- RecordDecl *SubRecordDecl = TheRecordTy->getDecl ();
713- const HLSLResourceClassAttr *Attr =
714- getSpecifiedHLSLAttrFromRecordDecl<HLSLResourceClassAttr>(SubRecordDecl);
715- // find the attr if it's on the member, or on any of the member's fields
716- if (Attr) {
717- llvm::hlsl::ResourceClass DeclResourceClass = Attr->getResourceClass ();
718- updateResourceClassFlagsFromDeclResourceClass (Flags, DeclResourceClass);
719- }
696+ const RecordType *RT = T->getAs <RecordType>();
697+ if (!RT)
698+ continue ;
720699
721- // otherwise, dig deeper and recurse into the member
722- else {
723- updateResourceClassFlagsFromRecordDecl (Flags, SubRecordDecl);
700+ const RecordDecl *RD = RT->getDecl ();
701+ for (FieldDecl *FD : RD->fields ()) {
702+ if (HLSLResourceClassAttr *RCAttr =
703+ FD->getAttr <HLSLResourceClassAttr>()) {
704+ updateResourceClassFlagsFromDeclResourceClass (
705+ Flags, RCAttr->getResourceClass ());
706+ continue ;
707+ }
708+ TypesToScan.emplace_back (FD->getType ().getTypePtr ());
709+ }
724710 }
725711}
726712
727713static RegisterBindingFlags HLSLFillRegisterBindingFlags (Sema &S,
728714 Decl *TheDecl) {
729-
730- // Cbuffers and Tbuffers are HLSLBufferDecl types
731- HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(TheDecl);
732- // Samplers, UAVs, and SRVs are VarDecl types
733- VarDecl *TheVarDecl = dyn_cast<VarDecl>(TheDecl);
734-
735- assert (((TheVarDecl && !CBufferOrTBuffer) ||
736- (!TheVarDecl && CBufferOrTBuffer)) &&
737- " either TheVarDecl or CBufferOrTBuffer should be set" );
738-
739715 RegisterBindingFlags Flags;
740716
741717 // check if the decl type is groupshared
@@ -744,58 +720,60 @@ static RegisterBindingFlags HLSLFillRegisterBindingFlags(Sema &S,
744720 return Flags;
745721 }
746722
747- if (!isDeclaredWithinCOrTBuffer (TheDecl)) {
748- // make sure the type is a basic / numeric type
749- if (TheVarDecl) {
750- QualType TheQualTy = TheVarDecl->getType ();
751- // a numeric variable or an array of numeric variables
752- // will inevitably end up in $Globals buffer
753- const clang::Type *TheBaseType = TheQualTy.getTypePtr ();
754- while (TheBaseType->isArrayType ())
755- TheBaseType = TheBaseType->getArrayElementTypeNoTypeQual ();
756- if (TheBaseType->isIntegralType (S.getASTContext ()) ||
757- TheBaseType->isFloatingType ())
758- Flags.DefaultGlobals = true ;
759- }
760- }
761-
762- if (CBufferOrTBuffer) {
723+ // Cbuffers and Tbuffers are HLSLBufferDecl types
724+ if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(TheDecl)) {
763725 Flags.Resource = true ;
764- if (CBufferOrTBuffer->isCBuffer ())
765- Flags.CBV = true ;
766- else
767- Flags.SRV = true ;
768- } else if (TheVarDecl) {
726+ Flags.ResourceClass = CBufferOrTBuffer->isCBuffer ()
727+ ? llvm::dxil::ResourceClass::CBuffer
728+ : llvm::dxil::ResourceClass::SRV;
729+ }
730+ // Samplers, UAVs, and SRVs are VarDecl types
731+ else if (VarDecl *TheVarDecl = dyn_cast<VarDecl>(TheDecl)) {
769732 const HLSLResourceClassAttr *resClassAttr =
770733 getSpecifiedHLSLAttrFromVarDecl<HLSLResourceClassAttr>(TheVarDecl);
771-
772734 if (resClassAttr) {
773- llvm::hlsl::ResourceClass DeclResourceClass =
774- resClassAttr->getResourceClass ();
775735 Flags.Resource = true ;
776- updateResourceClassFlagsFromDeclResourceClass ( Flags, DeclResourceClass );
736+ Flags. ResourceClass = resClassAttr-> getResourceClass ( );
777737 } else {
778738 const clang::Type *TheBaseType = TheVarDecl->getType ().getTypePtr ();
779739 while (TheBaseType->isArrayType ())
780740 TheBaseType = TheBaseType->getArrayElementTypeNoTypeQual ();
781- if (TheBaseType->isArithmeticType ())
741+
742+ if (TheBaseType->isArithmeticType ()) {
782743 Flags.Basic = true ;
783- else if (TheBaseType->isRecordType ()) {
744+ if (!isDeclaredWithinCOrTBuffer (TheDecl) &&
745+ (TheBaseType->isIntegralType (S.getASTContext ()) ||
746+ TheBaseType->isFloatingType ()))
747+ Flags.DefaultGlobals = true ;
748+ } else if (TheBaseType->isRecordType ()) {
784749 Flags.UDT = true ;
785750 const RecordType *TheRecordTy = TheBaseType->getAs <RecordType>();
786- assert (TheRecordTy && " The Qual Type should be Record Type" );
787- const RecordDecl *TheRecordDecl = TheRecordTy->getDecl ();
788- // recurse through members, set appropriate resource class flags.
789- updateResourceClassFlagsFromRecordDecl (Flags, TheRecordDecl);
751+ updateResourceClassFlagsFromRecordType (Flags, TheRecordTy);
790752 } else
791753 Flags.Other = true ;
792754 }
755+ } else {
756+ llvm_unreachable (" expected be VarDecl or HLSLBufferDecl" );
793757 }
794758 return Flags;
795759}
796760
797761enum class RegisterType { SRV, UAV, CBuffer, Sampler, C, I, Invalid };
798762
763+ static RegisterType getRegisterType (llvm::dxil::ResourceClass RC) {
764+ switch (RC) {
765+ case llvm::dxil::ResourceClass::SRV:
766+ return RegisterType::SRV;
767+ case llvm::dxil::ResourceClass::UAV:
768+ return RegisterType::UAV;
769+ case llvm::dxil::ResourceClass::CBuffer:
770+ return RegisterType::CBuffer;
771+ case llvm::dxil::ResourceClass::Sampler:
772+ return RegisterType::Sampler;
773+ }
774+ llvm_unreachable (" unexpected ResourceClass value" );
775+ }
776+
799777static RegisterType getRegisterType (StringRef Slot) {
800778 switch (Slot[0 ]) {
801779 case ' t' :
@@ -865,6 +843,8 @@ static void DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc,
865843 assert (((TheVarDecl && !CBufferOrTBuffer) ||
866844 (!TheVarDecl && CBufferOrTBuffer)) &&
867845 " either TheVarDecl or CBufferOrTBuffer should be set" );
846+ (void )TheVarDecl;
847+ (void )CBufferOrTBuffer;
868848
869849 RegisterBindingFlags Flags = HLSLFillRegisterBindingFlags (S, TheDecl);
870850 assert ((int )Flags.Other + (int )Flags.Resource + (int )Flags.Basic +
@@ -886,34 +866,8 @@ static void DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc,
886866 // next, if resource is set, make sure the register type in the register
887867 // annotation is compatible with the variable's resource type.
888868 if (Flags.Resource ) {
889- const HLSLResourceClassAttr *resClassAttr = nullptr ;
890- if (CBufferOrTBuffer) {
891- resClassAttr = CBufferOrTBuffer->getAttr <HLSLResourceClassAttr>();
892- } else if (TheVarDecl) {
893- resClassAttr =
894- getSpecifiedHLSLAttrFromVarDecl<HLSLResourceClassAttr>(TheVarDecl);
895- }
896-
897- assert (resClassAttr &&
898- " any decl that set the resource flag on analysis should "
899- " have a resource class attribute attached." );
900- const llvm::hlsl::ResourceClass DeclResourceClass =
901- resClassAttr->getResourceClass ();
902-
903- // confirm that the register type is bound to its expected resource class
904- static RegisterType ExpectedRegisterTypesForResourceClass[] = {
905- RegisterType::SRV,
906- RegisterType::UAV,
907- RegisterType::CBuffer,
908- RegisterType::Sampler,
909- };
910- assert ((size_t )DeclResourceClass <
911- std::size (ExpectedRegisterTypesForResourceClass) &&
912- " DeclResourceClass has unexpected value" );
913-
914- RegisterType ExpectedRegisterType =
915- ExpectedRegisterTypesForResourceClass[(int )DeclResourceClass];
916- if (regType != ExpectedRegisterType) {
869+ RegisterType expRegType = getRegisterType (Flags.ResourceClass .value ());
870+ if (regType != expRegType) {
917871 S.Diag (TheDecl->getLocation (), diag::err_hlsl_binding_type_mismatch)
918872 << regTypeNum;
919873 }
@@ -955,7 +909,7 @@ static void DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc,
955909}
956910
957911void SemaHLSL::handleResourceBindingAttr (Decl *TheDecl, const ParsedAttr &AL) {
958- if (dyn_cast <VarDecl>(TheDecl)) {
912+ if (isa <VarDecl>(TheDecl)) {
959913 if (SemaRef.RequireCompleteType (TheDecl->getBeginLoc (),
960914 cast<ValueDecl>(TheDecl)->getType (),
961915 diag::err_incomplete_type))
0 commit comments