diff --git a/clang/lib/Sema/SemaSYCL.cpp b/clang/lib/Sema/SemaSYCL.cpp index ad5a3768846c7..631f81fe4a9ab 100644 --- a/clang/lib/Sema/SemaSYCL.cpp +++ b/clang/lib/Sema/SemaSYCL.cpp @@ -1455,13 +1455,8 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler { InitExprs.push_back(ILE); } - void createSpecialMethodCall(const CXXRecordDecl *SpecialClass, Expr *Base, - const std::string &MethodName, - FieldDecl *Field) { - CXXMethodDecl *Method = getMethodByName(SpecialClass, MethodName); - assert(Method && - "The accessor/sampler/stream must have the __init method. Stream" - " must also have __finalize method"); + CXXMemberCallExpr *createSpecialMethodCall(Expr *Base, CXXMethodDecl *Method, + FieldDecl *Field) { unsigned NumParams = Method->getNumParams(); llvm::SmallVector ParamDREs(NumParams); llvm::ArrayRef KernelParameters = @@ -1485,10 +1480,7 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler { CXXMemberCallExpr *Call = CXXMemberCallExpr::Create( SemaRef.Context, MethodME, ParamStmts, ResultTy, VK, SourceLocation(), FPOptionsOverride()); - if (MethodName == FinalizeMethodName) - FinalizeStmts.push_back(Call); - else - BodyStmts.push_back(Call); + return Call; } // FIXME Avoid creation of kernel obj clone. @@ -1517,8 +1509,12 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler { ExprResult MemberInit = InitSeq.Perform(SemaRef, Entity, InitKind, None); InitExprs.push_back(MemberInit.get()); - createSpecialMethodCall(RecordDecl, MemberExprBases.back(), InitMethodName, - FD); + CXXMethodDecl *InitMethod = getMethodByName(RecordDecl, InitMethodName); + if (InitMethod) { + CXXMemberCallExpr *InitCall = + createSpecialMethodCall(MemberExprBases.back(), InitMethod, FD); + BodyStmts.push_back(InitCall); + } return true; } @@ -1535,8 +1531,12 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler { ExprResult MemberInit = InitSeq.Perform(SemaRef, Entity, InitKind, None); InitExprs.push_back(MemberInit.get()); - createSpecialMethodCall(RecordDecl, MemberExprBases.back(), InitMethodName, - nullptr); + CXXMethodDecl *InitMethod = getMethodByName(RecordDecl, InitMethodName); + if (InitMethod) { + CXXMemberCallExpr *InitCall = + createSpecialMethodCall(MemberExprBases.back(), InitMethod, nullptr); + BodyStmts.push_back(InitCall); + } return true; } @@ -1578,14 +1578,27 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler { return handleSpecialType(FD, Ty); } + bool handleSyclSpecConstantType(FieldDecl *FD, QualType Ty) final { + return handleSpecialType(FD, Ty); + } + bool handleSyclStreamType(FieldDecl *FD, QualType Ty) final { const auto *StreamDecl = Ty->getAsCXXRecordDecl(); createExprForStructOrScalar(FD); size_t NumBases = MemberExprBases.size(); - createSpecialMethodCall(StreamDecl, MemberExprBases[NumBases - 2], - InitMethodName, FD); - createSpecialMethodCall(StreamDecl, MemberExprBases[NumBases - 2], - FinalizeMethodName, FD); + CXXMethodDecl *InitMethod = getMethodByName(StreamDecl, InitMethodName); + if (InitMethod) { + CXXMemberCallExpr *InitCall = + createSpecialMethodCall(MemberExprBases.back(), InitMethod, FD); + BodyStmts.push_back(InitCall); + } + CXXMethodDecl *FinalizeMethod = + getMethodByName(StreamDecl, FinalizeMethodName); + if (FinalizeMethod) { + CXXMemberCallExpr *FinalizeCall = createSpecialMethodCall( + MemberExprBases[NumBases - 2], FinalizeMethod, FD); + FinalizeStmts.push_back(FinalizeCall); + } return true; } @@ -1796,7 +1809,7 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler { cast(FieldTy->getAsRecordDecl()) ->getTemplateInstantiationArgs(); assert(TemplateArgs.size() == 2 && - "Incorrect template args for Accessor Type"); + "Incorrect template args for spec constant type"); // Get specialization constant ID type, which is the second template // argument. QualType SpecConstIDTy = TemplateArgs.get(1).getAsType().getCanonicalType(); diff --git a/clang/test/SemaSYCL/spec-const-kernel-arg.cpp b/clang/test/SemaSYCL/spec-const-kernel-arg.cpp new file mode 100644 index 0000000000000..d40e5296949a2 --- /dev/null +++ b/clang/test/SemaSYCL/spec-const-kernel-arg.cpp @@ -0,0 +1,29 @@ +// RUN: %clang_cc1 -I %S/Inputs -fsycl -fsycl-is-device -ast-dump %s | FileCheck %s + +// This test checks that compiler generates correct initialization for spec +// constants + +#include + +struct SpecConstantsWrapper { + cl::sycl::experimental::spec_constant SC1; + cl::sycl::experimental::spec_constant SC2; +}; + +int main() { + cl::sycl::experimental::spec_constant SC; + SpecConstantsWrapper W; + cl::sycl::kernel_single_task( + [=]() { + (void)SC; + (void)W; + }); +} + +// CHECK: FunctionDecl {{.*}}kernel_sc{{.*}} 'void ()' +// CHECK: VarDecl {{.*}}'(lambda at {{.*}}' +// CHECK-NEXT: InitListExpr {{.*}}'(lambda at {{.*}}' +// CHECK-NEXT: CXXConstructExpr {{.*}}'cl::sycl::experimental::spec_constant':'cl::sycl::experimental::spec_constant' +// CHECK-NEXT: InitListExpr {{.*}} 'SpecConstantsWrapper' +// CHECK-NEXT: CXXConstructExpr {{.*}} 'cl::sycl::experimental::spec_constant':'cl::sycl::experimental::spec_constant' +// CHECK-NEXT: CXXConstructExpr {{.*}} 'cl::sycl::experimental::spec_constant':'cl::sycl::experimental::spec_constant' diff --git a/sycl/include/CL/sycl/experimental/spec_constant.hpp b/sycl/include/CL/sycl/experimental/spec_constant.hpp index 104137fdba9c5..7952d98bd481f 100644 --- a/sycl/include/CL/sycl/experimental/spec_constant.hpp +++ b/sycl/include/CL/sycl/experimental/spec_constant.hpp @@ -32,7 +32,10 @@ template class spec_constant { private: // Implementation defined constructor. #ifdef __SYCL_DEVICE_ONLY__ +public: spec_constant() {} + +private: #else spec_constant(T Cst) : Val(Cst) {} #endif diff --git a/sycl/test/spec_const/spec_const_hw.cpp b/sycl/test/spec_const/spec_const_hw.cpp index 442121353bb73..b8d161cbac204 100644 --- a/sycl/test/spec_const/spec_const_hw.cpp +++ b/sycl/test/spec_const/spec_const_hw.cpp @@ -39,6 +39,15 @@ float foo( return f32; } +struct SCWrapper { + SCWrapper(cl::sycl::program &p) + : SC1(p.set_spec_constant(4)), + SC2(p.set_spec_constant(2)) {} + + cl::sycl::experimental::spec_constant SC1; + cl::sycl::experimental::spec_constant SC2; +}; + int main(int argc, char **argv) { val = argc + 16; @@ -61,6 +70,7 @@ int main(int argc, char **argv) { std::cout << "val = " << val << "\n"; cl::sycl::program program1(q.get_context()); cl::sycl::program program2(q.get_context()); + cl::sycl::program program3(q.get_context()); int goldi = (int)get_value(); // TODO make this floating point once supported by the compiler @@ -77,11 +87,17 @@ int main(int argc, char **argv) { // SYCL RT execution path program2.build_with_kernel_type("-cl-fast-relaxed-math"); + SCWrapper W(program3); + program3.build_with_kernel_type(); + int goldw = 6; + std::vector veci(1); std::vector vecf(1); + std::vector vecw(1); try { cl::sycl::buffer bufi(veci.data(), veci.size()); cl::sycl::buffer buff(vecf.data(), vecf.size()); + cl::sycl::buffer bufw(vecw.data(), vecw.size()); q.submit([&](cl::sycl::handler &cgh) { auto acci = bufi.get_access(cgh); @@ -99,6 +115,13 @@ int main(int argc, char **argv) { accf[0] = foo(f32); }); }); + + q.submit([&](cl::sycl::handler &cgh) { + auto accw = bufw.get_access(cgh); + cgh.single_task( + program3.get_kernel(), + [=]() { accw[0] = W.SC1.get() + W.SC2.get(); }); + }); } catch (cl::sycl::exception &e) { std::cout << "*** Exception caught: " << e.what() << "\n"; return 1; @@ -116,6 +139,12 @@ int main(int argc, char **argv) { std::cout << "*** ERROR: " << valf << " != " << goldf << "(gold)\n"; passed = false; } + int valw = vecw[0]; + + if (valw != goldw) { + std::cout << "*** ERROR: " << valw << " != " << goldw << "(gold)\n"; + passed = false; + } std::cout << (passed ? "passed\n" : "FAILED\n"); return passed ? 0 : 1; }