@@ -331,6 +331,11 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
331331
332332private:
333333 bool CheckSYCLType (QualType Ty, SourceRange Loc) {
334+ llvm::DenseSet<QualType> visited;
335+ return CheckSYCLType (Ty, Loc, visited);
336+ }
337+
338+ bool CheckSYCLType (QualType Ty, SourceRange Loc, llvm::DenseSet<QualType> &Visited) {
334339 if (Ty->isVariableArrayType ()) {
335340 SemaRef.Diag (Loc.getBegin (), diag::err_vla_unsupported);
336341 return false ;
@@ -339,6 +344,11 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
339344 while (Ty->isAnyPointerType () || Ty->isArrayType ())
340345 Ty = QualType{Ty->getPointeeOrArrayElementType (), 0 };
341346
347+ // Pointers complicate recursion. Add this type to Visited.
348+ // If already there, bail out.
349+ if (!Visited.insert (Ty).second )
350+ return true ;
351+
342352 if (const auto *CRD = Ty->getAsCXXRecordDecl ()) {
343353 if (CRD->isPolymorphic ()) {
344354 SemaRef.Diag (CRD->getLocation (), diag::err_sycl_virtual_types);
@@ -347,25 +357,25 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
347357 }
348358
349359 for (const auto &Field : CRD->fields ()) {
350- if (!CheckSYCLType (Field->getType (), Field->getSourceRange ())) {
360+ if (!CheckSYCLType (Field->getType (), Field->getSourceRange (), Visited )) {
351361 SemaRef.Diag (Loc.getBegin (), diag::note_sycl_used_here);
352362 return false ;
353363 }
354364 }
355365 } else if (const auto *RD = Ty->getAsRecordDecl ()) {
356366 for (const auto &Field : RD->fields ()) {
357- if (!CheckSYCLType (Field->getType (), Field->getSourceRange ())) {
367+ if (!CheckSYCLType (Field->getType (), Field->getSourceRange (), Visited )) {
358368 SemaRef.Diag (Loc.getBegin (), diag::note_sycl_used_here);
359369 return false ;
360370 }
361371 }
362372 } else if (const auto *FPTy = dyn_cast<FunctionProtoType>(Ty)) {
363373 for (const auto &ParamTy : FPTy->param_types ())
364- if (!CheckSYCLType (ParamTy, Loc))
374+ if (!CheckSYCLType (ParamTy, Loc, Visited ))
365375 return false ;
366- return CheckSYCLType (FPTy->getReturnType (), Loc);
376+ return CheckSYCLType (FPTy->getReturnType (), Loc, Visited );
367377 } else if (const auto *FTy = dyn_cast<FunctionType>(Ty)) {
368- return CheckSYCLType (FTy->getReturnType (), Loc);
378+ return CheckSYCLType (FTy->getReturnType (), Loc, Visited );
369379 }
370380 return true ;
371381 }
@@ -766,6 +776,16 @@ static void buildArgTys(ASTContext &Context, CXXRecordDecl *KernelObj,
766776
767777 // Create descriptors for each accessor field in the class or struct
768778 createParamDescForWrappedAccessors (Fld, ArgTy);
779+ } else if (ArgTy->isPointerType ()) {
780+ // Pointer Arguments need to be in the global address space
781+ QualType PointeeTy = ArgTy->getPointeeType ();
782+ Qualifiers Quals = PointeeTy.getQualifiers ();
783+ Quals.setAddressSpace (LangAS::opencl_global);
784+ PointeeTy = Context.getQualifiedType (PointeeTy.getUnqualifiedType (),
785+ Quals);
786+ QualType ModTy = Context.getPointerType (PointeeTy);
787+
788+ CreateAndAddPrmDsc (Fld, ModTy);
769789 } else if (ArgTy->isScalarType ()) {
770790 CreateAndAddPrmDsc (Fld, ArgTy);
771791 } else {
@@ -853,6 +873,10 @@ static void populateIntHeader(SYCLIntegrationHeader &H, const StringRef Name,
853873 uint64_t Sz = Ctx.getTypeSizeInChars (SamplerArg->getType ()).getQuantity ();
854874 H.addParamDesc (SYCLIntegrationHeader::kind_sampler,
855875 static_cast <unsigned >(Sz), static_cast <unsigned >(Offset));
876+ } else if (ArgTy->isPointerType ()) {
877+ uint64_t Sz = Ctx.getTypeSizeInChars (Fld->getType ()).getQuantity ();
878+ H.addParamDesc (SYCLIntegrationHeader::kind_pointer,
879+ static_cast <unsigned >(Sz), static_cast <unsigned >(Offset));
856880 } else if (ArgTy->isStructureOrClassType () || ArgTy->isScalarType ()) {
857881 // the parameter is an object of standard layout type or scalar;
858882 // the check for standard layout is done elsewhere
@@ -1017,6 +1041,7 @@ static const char *paramKind2Str(KernelParamKind K) {
10171041 CASE (accessor);
10181042 CASE (std_layout);
10191043 CASE (sampler);
1044+ CASE (pointer);
10201045 default :
10211046 return " <ERROR>" ;
10221047 }
0 commit comments