@@ -38,18 +38,6 @@ enum target {
38
38
image_array
39
39
};
40
40
41
- enum RestrictKind {
42
- KernelGlobalVariable,
43
- KernelRTTI,
44
- KernelNonConstStaticDataVariable,
45
- KernelCallVirtualFunction,
46
- KernelCallRecursiveFunction,
47
- KernelCallFunctionPointer,
48
- KernelAllocateStorage,
49
- KernelUseExceptions,
50
- KernelUseAssembly
51
- };
52
-
53
41
using ParamDesc = std::tuple<QualType, IdentifierInfo *, TypeSourceInfo *>;
54
42
55
43
/// Various utilities.
@@ -95,16 +83,16 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
95
83
// definitions.
96
84
if (RecursiveSet.count(Callee)) {
97
85
SemaRef.Diag(e->getExprLoc(), diag::err_sycl_restrict)
98
- << KernelCallRecursiveFunction;
86
+ << Sema:: KernelCallRecursiveFunction;
99
87
SemaRef.Diag(Callee->getSourceRange().getBegin(),
100
88
diag::note_sycl_recursive_function_declared_here)
101
- << KernelCallRecursiveFunction;
89
+ << Sema:: KernelCallRecursiveFunction;
102
90
}
103
91
104
92
if (const CXXMethodDecl *Method = dyn_cast<CXXMethodDecl>(Callee))
105
93
if (Method->isVirtual())
106
94
SemaRef.Diag(e->getExprLoc(), diag::err_sycl_restrict)
107
- << KernelCallVirtualFunction;
95
+ << Sema:: KernelCallVirtualFunction;
108
96
109
97
CheckSYCLType(Callee->getReturnType(), Callee->getSourceRange());
110
98
@@ -116,7 +104,7 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
116
104
}
117
105
} else if (!SemaRef.getLangOpts().SYCLAllowFuncPtr)
118
106
SemaRef.Diag(e->getExprLoc(), diag::err_sycl_restrict)
119
- << KernelCallFunctionPointer;
107
+ << Sema:: KernelCallFunctionPointer;
120
108
return true;
121
109
}
122
110
@@ -144,12 +132,12 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
144
132
}
145
133
146
134
bool VisitCXXTypeidExpr(CXXTypeidExpr *E) {
147
- SemaRef.Diag(E->getExprLoc(), diag::err_sycl_restrict) << KernelRTTI;
135
+ SemaRef.Diag(E->getExprLoc(), diag::err_sycl_restrict) << Sema:: KernelRTTI;
148
136
return true;
149
137
}
150
138
151
139
bool VisitCXXDynamicCastExpr(const CXXDynamicCastExpr *E) {
152
- SemaRef.Diag(E->getExprLoc(), diag::err_sycl_restrict) << KernelRTTI;
140
+ SemaRef.Diag(E->getExprLoc(), diag::err_sycl_restrict) << Sema:: KernelRTTI;
153
141
return true;
154
142
}
155
143
@@ -178,7 +166,7 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
178
166
bool IsConst = VD->getType().getNonReferenceType().isConstQualified();
179
167
if (!IsConst && VD->isStaticDataMember())
180
168
SemaRef.Diag(E->getExprLoc(), diag::err_sycl_restrict)
181
- << KernelNonConstStaticDataVariable;
169
+ << Sema:: KernelNonConstStaticDataVariable;
182
170
}
183
171
return true;
184
172
}
@@ -189,11 +177,11 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
189
177
bool IsConst = VD->getType().getNonReferenceType().isConstQualified();
190
178
if (!IsConst && VD->isStaticDataMember())
191
179
SemaRef.Diag(E->getExprLoc(), diag::err_sycl_restrict)
192
- << KernelNonConstStaticDataVariable;
180
+ << Sema:: KernelNonConstStaticDataVariable;
193
181
else if (!IsConst && VD->hasGlobalStorage() && !VD->isStaticLocal() &&
194
182
!VD->isStaticDataMember() && !isa<ParmVarDecl>(VD))
195
183
SemaRef.Diag(E->getLocation(), diag::err_sycl_restrict)
196
- << KernelGlobalVariable;
184
+ << Sema:: KernelGlobalVariable;
197
185
if (!VD->isLocalVarDeclOrParm() && VD->hasGlobalStorage()) {
198
186
VD->addAttr(SYCLDeviceAttr::CreateImplicit(SemaRef.Context));
199
187
SemaRef.addSyclDeviceDecl(VD);
@@ -213,7 +201,7 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
213
201
if (FunctionDecl *FD = E->getOperatorNew()) {
214
202
if (FD->isReplaceableGlobalAllocationFunction()) {
215
203
SemaRef.Diag(E->getExprLoc(), diag::err_sycl_restrict)
216
- << KernelAllocateStorage;
204
+ << Sema:: KernelAllocateStorage;
217
205
} else if (FunctionDecl *Def = FD->getDefinition()) {
218
206
if (!Def->hasAttr<SYCLDeviceAttr>()) {
219
207
Def->addAttr(SYCLDeviceAttr::CreateImplicit(SemaRef.Context));
@@ -223,40 +211,16 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
223
211
}
224
212
return true;
225
213
}
226
-
227
- bool VisitCXXThrowExpr(CXXThrowExpr *E) {
228
- SemaRef.Diag(E->getExprLoc(), diag::err_sycl_restrict)
229
- << KernelUseExceptions;
230
- return true;
231
- }
232
-
233
- bool VisitCXXCatchStmt(CXXCatchStmt *S) {
234
- SemaRef.Diag(S->getBeginLoc(), diag::err_sycl_restrict)
235
- << KernelUseExceptions;
236
- return true;
237
- }
238
-
239
- bool VisitCXXTryStmt(CXXTryStmt *S) {
240
- SemaRef.Diag(S->getBeginLoc(), diag::err_sycl_restrict)
241
- << KernelUseExceptions;
242
- return true;
243
- }
244
-
245
- bool VisitSEHTryStmt(SEHTryStmt *S) {
246
- SemaRef.Diag(S->getBeginLoc(), diag::err_sycl_restrict)
247
- << KernelUseExceptions;
248
- return true;
249
- }
250
-
214
+
251
215
bool VisitGCCAsmStmt(GCCAsmStmt *S) {
252
216
SemaRef.Diag(S->getBeginLoc(), diag::err_sycl_restrict)
253
- << KernelUseAssembly;
217
+ << Sema:: KernelUseAssembly;
254
218
return true;
255
219
}
256
-
220
+
257
221
bool VisitMSAsmStmt(MSAsmStmt *S) {
258
222
SemaRef.Diag(S->getBeginLoc(), diag::err_sycl_restrict)
259
- << KernelUseAssembly;
223
+ << Sema:: KernelUseAssembly;
260
224
return true;
261
225
}
262
226
@@ -361,21 +325,31 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
361
325
return true;
362
326
363
327
if (CRD->isPolymorphic()) {
364
- SemaRef.Diag(CRD->getLocation(), diag::err_sycl_virtual_types);
365
- SemaRef.Diag(Loc.getBegin(), diag::note_sycl_used_here);
328
+ // Exceptions aren't allowed in SYCL device code.
329
+ if (SemaRef.getLangOpts().SYCLIsDevice) {
330
+ SemaRef.SYCLDiagIfDeviceCode(CRD->getLocation(),
331
+ diag::err_sycl_restrict)
332
+ << Sema::KernelHavePolymorphicClass;
333
+ SemaRef.SYCLDiagIfDeviceCode(Loc.getBegin(),
334
+ diag::note_sycl_used_here);
335
+ }
366
336
return false;
367
337
}
368
338
369
339
for (const auto &Field : CRD->fields()) {
370
340
if (!CheckSYCLType(Field->getType(), Field->getSourceRange(), Visited)) {
371
- SemaRef.Diag(Loc.getBegin(), diag::note_sycl_used_here);
341
+ if (SemaRef.getLangOpts().SYCLIsDevice)
342
+ SemaRef.SYCLDiagIfDeviceCode(Loc.getBegin(),
343
+ diag::note_sycl_used_here);
372
344
return false;
373
345
}
374
346
}
375
347
} else if (const auto *RD = Ty->getAsRecordDecl()) {
376
348
for (const auto &Field : RD->fields()) {
377
349
if (!CheckSYCLType(Field->getType(), Field->getSourceRange(), Visited)) {
378
- SemaRef.Diag(Loc.getBegin(), diag::note_sycl_used_here);
350
+ if (SemaRef.getLangOpts().SYCLIsDevice)
351
+ SemaRef.SYCLDiagIfDeviceCode(Loc.getBegin(),
352
+ diag::note_sycl_used_here);
379
353
return false;
380
354
}
381
355
}
@@ -1036,6 +1010,55 @@ void Sema::MarkDevice(void) {
1036
1010
}
1037
1011
}
1038
1012
}
1013
+ //
1014
+ // Do we know that we will eventually codegen the given function?
1015
+ static bool isKnownEmitted(Sema &S, FunctionDecl *FD) {
1016
+ if (!FD)
1017
+ return true; // Seen in LIT testing
1018
+
1019
+ if (FD->hasAttr<SYCLDeviceAttr>() ||
1020
+ FD->hasAttr<SYCLKernelAttr>())
1021
+ return true;
1022
+
1023
+ // Templates are emitted when they're instantiated.
1024
+ if (FD->isDependentContext())
1025
+ return false;
1026
+
1027
+ // Otherwise, the function is known-emitted if it's in our set of
1028
+ // known-emitted functions.
1029
+ return S.DeviceKnownEmittedFns.count(FD) > 0;
1030
+ }
1031
+
1032
+ Sema::DeviceDiagBuilder Sema::SYCLDiagIfDeviceCode(SourceLocation Loc,
1033
+ unsigned DiagID) {
1034
+ assert(getLangOpts().SYCLIsDevice &&
1035
+ "Should only be called during SYCL compilation");
1036
+ DeviceDiagBuilder::Kind DiagKind = [this] {
1037
+ if (isKnownEmitted(*this, dyn_cast<FunctionDecl>(CurContext)))
1038
+ return DeviceDiagBuilder::K_ImmediateWithCallStack;
1039
+ else
1040
+ return DeviceDiagBuilder::K_Deferred;
1041
+ }();
1042
+ return DeviceDiagBuilder(DiagKind, Loc, DiagID,
1043
+ dyn_cast<FunctionDecl>(CurContext), *this);
1044
+ }
1045
+
1046
+ bool Sema::CheckSYCLCall(SourceLocation Loc, FunctionDecl *Callee) {
1047
+
1048
+ assert(Callee && "Callee may not be null.");
1049
+ FunctionDecl *Caller = getCurFunctionDecl();
1050
+
1051
+ // If the caller is known-emitted, mark the callee as known-emitted.
1052
+ // Otherwise, mark the call in our call graph so we can traverse it later.
1053
+ if (//!isOpenMPDeviceDelayedContext(*this) ||
1054
+ (Caller && Caller->hasAttr<SYCLKernelAttr>()) ||
1055
+ (Caller && Caller->hasAttr<SYCLDeviceAttr>()) ||
1056
+ (Caller && isKnownEmitted(*this, Caller)))
1057
+ markKnownEmitted(*this, Caller, Callee, Loc, isKnownEmitted);
1058
+ else if (Caller)
1059
+ DeviceCallGraph[Caller].insert({Callee, Loc});
1060
+ return true;
1061
+ }
1039
1062
1040
1063
// -----------------------------------------------------------------------------
1041
1064
// Integration header functionality implementation
0 commit comments