@@ -3344,58 +3344,6 @@ static const char *paramKind2Str(KernelParamKind K) {
33443344#undef CASE
33453345}
33463346
3347- // Emits a forward declaration
3348- void SYCLIntegrationHeader::emitFwdDecl (raw_ostream &O, const Decl *D,
3349- SourceLocation KernelLocation) {
3350- // wrap the declaration into namespaces if needed
3351- unsigned NamespaceCnt = 0 ;
3352- std::string NSStr = " " ;
3353- const DeclContext *DC = D->getDeclContext ();
3354-
3355- while (DC) {
3356- auto *NS = dyn_cast_or_null<NamespaceDecl>(DC);
3357-
3358- if (!NS) {
3359- break ;
3360- }
3361-
3362- ++NamespaceCnt;
3363- const StringRef NSInlinePrefix = NS->isInline () ? " inline " : " " ;
3364- NSStr.insert (
3365- 0 , Twine (NSInlinePrefix + " namespace " + NS->getName () + " { " ).str ());
3366- DC = NS->getDeclContext ();
3367- }
3368- O << NSStr;
3369- if (NamespaceCnt > 0 )
3370- O << " \n " ;
3371- // print declaration into a string:
3372- PrintingPolicy P (D->getASTContext ().getLangOpts ());
3373- P.adjustForCPlusPlusFwdDecl ();
3374- P.SuppressTypedefs = true ;
3375- P.SuppressUnwrittenScope = true ;
3376- std::string S;
3377- llvm::raw_string_ostream SO (S);
3378- D->print (SO, P);
3379- O << SO.str ();
3380-
3381- if (const auto *ED = dyn_cast<EnumDecl>(D)) {
3382- QualType T = ED->getIntegerType ();
3383- // Backup since getIntegerType() returns null for enum forward
3384- // declaration with no fixed underlying type
3385- if (T.isNull ())
3386- T = ED->getPromotionType ();
3387- O << " : " << T.getAsString ();
3388- }
3389-
3390- O << " ;\n " ;
3391-
3392- // print closing braces for namespaces if needed
3393- for (unsigned I = 0 ; I < NamespaceCnt; ++I)
3394- O << " }" ;
3395- if (NamespaceCnt > 0 )
3396- O << " \n " ;
3397- }
3398-
33993347// Emits forward declarations of classes and template classes on which
34003348// declaration of given type depends.
34013349// For example, consider SimpleVadd
@@ -3432,126 +3380,176 @@ void SYCLIntegrationHeader::emitFwdDecl(raw_ostream &O, const Decl *D,
34323380// template <typename T> class MyTmplClass;
34333381// template <typename T1, unsigned int N, typename ...T2> class SimpleVadd;
34343382//
3435- void SYCLIntegrationHeader::emitForwardClassDecls (
3436- raw_ostream &O, QualType T, SourceLocation KernelLocation,
3437- llvm::SmallPtrSetImpl<const void *> &Printed) {
3383+ class SYCLFwdDeclEmitter
3384+ : public TypeVisitor<SYCLFwdDeclEmitter>,
3385+ public ConstTemplateArgumentVisitor<SYCLFwdDeclEmitter> {
3386+ using InnerTypeVisitor = TypeVisitor<SYCLFwdDeclEmitter>;
3387+ using InnerTemplArgVisitor = ConstTemplateArgumentVisitor<SYCLFwdDeclEmitter>;
3388+ raw_ostream &OS;
3389+ llvm::SmallPtrSet<const NamedDecl *, 4 > Printed;
3390+ PrintingPolicy Policy;
34383391
3439- // peel off the pointer types and get the class/struct type:
3440- for (; T->isPointerType (); T = T->getPointeeType ())
3441- ;
3442- const CXXRecordDecl *RD = T->getAsCXXRecordDecl ();
3392+ void printForwardDecl (NamedDecl *D) {
3393+ // wrap the declaration into namespaces if needed
3394+ unsigned NamespaceCnt = 0 ;
3395+ std::string NSStr = " " ;
3396+ const DeclContext *DC = D->getDeclContext ();
34433397
3444- if (!RD) {
3398+ while (DC) {
3399+ const auto *NS = dyn_cast_or_null<NamespaceDecl>(DC);
34453400
3446- return ;
3401+ if (!NS)
3402+ break ;
3403+
3404+ ++NamespaceCnt;
3405+ const StringRef NSInlinePrefix = NS->isInline () ? " inline " : " " ;
3406+ NSStr.insert (
3407+ 0 ,
3408+ Twine (NSInlinePrefix + " namespace " + NS->getName () + " { " ).str ());
3409+ DC = NS->getDeclContext ();
3410+ }
3411+ OS << NSStr;
3412+ if (NamespaceCnt > 0 )
3413+ OS << " \n " ;
3414+
3415+ D->print (OS, Policy);
3416+
3417+ if (const auto *ED = dyn_cast<EnumDecl>(D)) {
3418+ QualType T = ED->getIntegerType ();
3419+ // Backup since getIntegerType() returns null for enum forward
3420+ // declaration with no fixed underlying type
3421+ if (T.isNull ())
3422+ T = ED->getPromotionType ();
3423+ OS << " : " << T.getAsString ();
3424+ }
3425+
3426+ OS << " ;\n " ;
3427+
3428+ // print closing braces for namespaces if needed
3429+ for (unsigned I = 0 ; I < NamespaceCnt; ++I)
3430+ OS << " }" ;
3431+ if (NamespaceCnt > 0 )
3432+ OS << " \n " ;
34473433 }
34483434
3449- // see if this is a template specialization ...
3450- if (const auto *TSD = dyn_cast<ClassTemplateSpecializationDecl>(RD)) {
3451- // ... yes, it is template specialization:
3452- // - first, recurse into template parameters and emit needed forward
3453- // declarations
3454- const TemplateArgumentList &Args = TSD->getTemplateArgs ();
3435+ // Checks if we've already printed forward declaration and prints it if not.
3436+ void checkAndEmitForwardDecl (NamedDecl *D) {
3437+ if (Printed.insert (D).second )
3438+ printForwardDecl (D);
3439+ }
34553440
3456- for (unsigned I = 0 ; I < Args.size (); I++) {
3457- const TemplateArgument &Arg = Args[I];
3441+ void VisitTemplateArgs (ArrayRef<TemplateArgument> Args) {
3442+ for (size_t I = 0 , E = Args.size (); I < E; ++I)
3443+ Visit (Args[I]);
3444+ }
34583445
3459- switch (Arg.getKind ()) {
3460- case TemplateArgument::ArgKind::Type:
3461- case TemplateArgument::ArgKind::Integral: {
3462- QualType T = (Arg.getKind () == TemplateArgument::ArgKind::Type)
3463- ? Arg.getAsType ()
3464- : Arg.getIntegralType ();
3465-
3466- // Handle Kernel Name Type templated using enum type and value.
3467- if (const auto *ET = T->getAs <EnumType>()) {
3468- const EnumDecl *ED = ET->getDecl ();
3469- emitFwdDecl (O, ED, KernelLocation);
3470- } else if (Arg.getKind () == TemplateArgument::ArgKind::Type)
3471- emitForwardClassDecls (O, T, KernelLocation, Printed);
3472- break ;
3473- }
3474- case TemplateArgument::ArgKind::Pack: {
3475- ArrayRef<TemplateArgument> Pack = Arg.getPackAsArray ();
3446+ public:
3447+ SYCLFwdDeclEmitter (raw_ostream &OS, LangOptions LO) : OS(OS), Policy(LO) {
3448+ Policy.adjustForCPlusPlusFwdDecl ();
3449+ Policy.SuppressTypedefs = true ;
3450+ Policy.SuppressUnwrittenScope = true ;
3451+ }
34763452
3477- for (const auto &T : Pack) {
3478- if (T.getKind () == TemplateArgument::ArgKind::Type) {
3479- emitForwardClassDecls (O, T.getAsType (), KernelLocation, Printed);
3480- }
3481- }
3482- break ;
3483- }
3484- case TemplateArgument::ArgKind::Template: {
3485- // recursion is not required, since the maximum possible nesting level
3486- // equals two for template argument
3487- //
3488- // for example:
3489- // template <typename T> class Bar;
3490- // template <template <typename> class> class Baz;
3491- // template <template <template <typename> class> class T>
3492- // class Foo;
3493- //
3494- // The Baz is a template class. The Baz<Bar> is a class. The class Foo
3495- // should be specialized with template class, not a class. The correct
3496- // specialization of template class Foo is Foo<Baz>. The incorrect
3497- // specialization of template class Foo is Foo<Baz<Bar>>. In this case
3498- // template class Foo specialized by class Baz<Bar>, not a template
3499- // class template <template <typename> class> class T as it should.
3500- TemplateDecl *TD = Arg.getAsTemplate ().getAsTemplateDecl ();
3501- TemplateParameterList *TemplateParams = TD->getTemplateParameters ();
3502- for (NamedDecl *P : *TemplateParams) {
3503- // If template template paramter type has an enum value template
3504- // parameter, forward declaration of enum type is required. Only enum
3505- // values (not types) need to be handled. For example, consider the
3506- // following kernel name type:
3507- //
3508- // template <typename EnumTypeOut, template <EnumValueIn EnumValue,
3509- // typename TypeIn> class T> class Foo;
3510- //
3511- // The correct specialization for Foo (with enum type) is:
3512- // Foo<EnumTypeOut, Baz>, where Baz is a template class.
3513- //
3514- // Therefore the forward class declarations generated in the
3515- // integration header are:
3516- // template <EnumValueIn EnumValue, typename TypeIn> class Baz;
3517- // template <typename EnumTypeOut, template <EnumValueIn EnumValue,
3518- // typename EnumTypeIn> class T> class Foo;
3519- //
3520- // This requires the following enum forward declarations:
3521- // enum class EnumTypeOut : int; (Used to template Foo)
3522- // enum class EnumValueIn : int; (Used to template Baz)
3523- if (NonTypeTemplateParmDecl *TemplateParam =
3524- dyn_cast<NonTypeTemplateParmDecl>(P)) {
3525- QualType T = TemplateParam->getType ();
3526- if (const auto *ET = T->getAs <EnumType>()) {
3527- const EnumDecl *ED = ET->getDecl ();
3528- emitFwdDecl (O, ED, KernelLocation);
3529- }
3530- }
3531- }
3532- if (Printed.insert (TD).second ) {
3533- emitFwdDecl (O, TD, KernelLocation);
3534- }
3535- break ;
3536- }
3537- default :
3538- break ; // nop
3539- }
3453+ void Visit (QualType T) {
3454+ if (T.isNull ())
3455+ return ;
3456+ InnerTypeVisitor::Visit (T.getTypePtr ());
3457+ }
3458+
3459+ void Visit (const TemplateArgument &TA) {
3460+ if (TA.isNull ())
3461+ return ;
3462+ InnerTemplArgVisitor::Visit (TA);
3463+ }
3464+
3465+ void VisitPointerType (const PointerType *T) {
3466+ // Peel off the pointer types.
3467+ QualType PT = T->getPointeeType ();
3468+ while (PT->isPointerType ())
3469+ PT = PT->getPointeeType ();
3470+ Visit (PT);
3471+ }
3472+
3473+ void VisitTagType (const TagType *T) {
3474+ TagDecl *TD = T->getDecl ();
3475+ if (const auto *TSD = dyn_cast<ClassTemplateSpecializationDecl>(TD)) {
3476+ // - first, recurse into template parameters and emit needed forward
3477+ // declarations
3478+ ArrayRef<TemplateArgument> Args = TSD->getTemplateArgs ().asArray ();
3479+ VisitTemplateArgs (Args);
3480+ // - second, emit forward declaration for the template class being
3481+ // specialized
3482+ ClassTemplateDecl *CTD = TSD->getSpecializedTemplate ();
3483+ assert (CTD && " template declaration must be available" );
3484+
3485+ checkAndEmitForwardDecl (CTD);
3486+ return ;
35403487 }
3541- // - second, emit forward declaration for the template class being
3542- // specialized
3543- ClassTemplateDecl *CTD = TSD->getSpecializedTemplate ();
3544- assert (CTD && " template declaration must be available" );
3488+ checkAndEmitForwardDecl (TD);
3489+ }
3490+
3491+ void VisitTypeTemplateArgument (const TemplateArgument &TA) {
3492+ QualType T = TA.getAsType ();
3493+ Visit (T);
3494+ }
35453495
3546- if (Printed.insert (CTD).second ) {
3547- emitFwdDecl (O, CTD, KernelLocation);
3496+ void VisitIntegralTemplateArgument (const TemplateArgument &TA) {
3497+ QualType T = TA.getIntegralType ();
3498+ if (const EnumType *ET = T->getAs <EnumType>())
3499+ VisitTagType (ET);
3500+ }
3501+
3502+ void VisitTemplateTemplateArgument (const TemplateArgument &TA) {
3503+ // recursion is not required, since the maximum possible nesting level
3504+ // equals two for template argument
3505+ //
3506+ // for example:
3507+ // template <typename T> class Bar;
3508+ // template <template <typename> class> class Baz;
3509+ // template <template <template <typename> class> class T>
3510+ // class Foo;
3511+ //
3512+ // The Baz is a template class. The Baz<Bar> is a class. The class Foo
3513+ // should be specialized with template class, not a class. The correct
3514+ // specialization of template class Foo is Foo<Baz>. The incorrect
3515+ // specialization of template class Foo is Foo<Baz<Bar>>. In this case
3516+ // template class Foo specialized by class Baz<Bar>, not a template
3517+ // class template <template <typename> class> class T as it should.
3518+ TemplateDecl *TD = TA.getAsTemplate ().getAsTemplateDecl ();
3519+ TemplateParameterList *TemplateParams = TD->getTemplateParameters ();
3520+ for (NamedDecl *P : *TemplateParams) {
3521+ // If template template parameter type has an enum value template
3522+ // parameter, forward declaration of enum type is required. Only enum
3523+ // values (not types) need to be handled. For example, consider the
3524+ // following kernel name type:
3525+ //
3526+ // template <typename EnumTypeOut, template <EnumValueIn EnumValue,
3527+ // typename TypeIn> class T> class Foo;
3528+ //
3529+ // The correct specialization for Foo (with enum type) is:
3530+ // Foo<EnumTypeOut, Baz>, where Baz is a template class.
3531+ //
3532+ // Therefore the forward class declarations generated in the
3533+ // integration header are:
3534+ // template <EnumValueIn EnumValue, typename TypeIn> class Baz;
3535+ // template <typename EnumTypeOut, template <EnumValueIn EnumValue,
3536+ // typename EnumTypeIn> class T> class Foo;
3537+ //
3538+ // This requires the following enum forward declarations:
3539+ // enum class EnumTypeOut : int; (Used to template Foo)
3540+ // enum class EnumValueIn : int; (Used to template Baz)
3541+ if (NonTypeTemplateParmDecl *TemplateParam =
3542+ dyn_cast<NonTypeTemplateParmDecl>(P))
3543+ if (const EnumType *ET = TemplateParam->getType ()->getAs <EnumType>())
3544+ VisitTagType (ET);
35483545 }
3549- } else if (Printed.insert (RD).second ) {
3550- // emit forward declarations for "leaf" classes in the template parameter
3551- // tree;
3552- emitFwdDecl (O, RD, KernelLocation);
3546+ checkAndEmitForwardDecl (TD);
35533547 }
3554- }
3548+
3549+ void VisitPackTemplateArgument (const TemplateArgument &TA) {
3550+ VisitTemplateArgs (TA.getPackAsArray ());
3551+ }
3552+ };
35553553
35563554class SYCLKernelNameTypePrinter
35573555 : public TypeVisitor<SYCLKernelNameTypePrinter>,
@@ -3709,10 +3707,9 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
37093707 if (!UnnamedLambdaSupport) {
37103708 O << " // Forward declarations of templated kernel function types:\n " ;
37113709
3712- llvm::SmallPtrSet<const void *, 4 > Printed;
3713- for (const KernelDesc &K : KernelDescs) {
3714- emitForwardClassDecls (O, K.NameType , K.KernelLocation , Printed);
3715- }
3710+ SYCLFwdDeclEmitter FwdDeclEmitter (O, S.getLangOpts ());
3711+ for (const KernelDesc &K : KernelDescs)
3712+ FwdDeclEmitter.Visit (K.NameType );
37163713 }
37173714 O << " \n " ;
37183715
0 commit comments