diff --git a/clang/lib/Sema/SemaSYCL.cpp b/clang/lib/Sema/SemaSYCL.cpp index 8cba9ce845ddc..f682be90a80b6 100644 --- a/clang/lib/Sema/SemaSYCL.cpp +++ b/clang/lib/Sema/SemaSYCL.cpp @@ -6648,6 +6648,10 @@ static void PrintNSClosingBraces(raw_ostream &OS, const DeclContext *DC) { /// /// Moral of the story: drop integration header ASAP (but that is blocked /// by support for 3rd-party host compilers, which is important). +static std::string stringifyTypeWithEnumValues(QualType T, + PrintingPolicy &Policy, + ASTContext &Ctx); + class FreeFunctionTemplateKernelArgsPrinter : public ConstTemplateArgumentVisitor> { @@ -6683,6 +6687,18 @@ class FreeFunctionTemplateKernelArgsPrinter DesugarTemplateArgument(Arg).print(Policy, O, /*IncludeType=*/false); } + void printEnumValue(QualType EnumTy, const llvm::APSInt &Val) { + O << "static_cast<"; + if (const auto *ET = EnumTy->getAs()) + ET->getOriginalDecl()->printQualifiedName(O, Policy, + /*WithGlobalNsPrefix=*/false); + else + EnumTy.print(O, Policy); + llvm::SmallString<8> Num; + Val.toString(Num, /*Radix=*/10, /*Signed=*/Val.isSigned()); + O << ">(" << Num << ")"; + } + public: FreeFunctionTemplateKernelArgsPrinter(raw_ostream &O, PrintingPolicy &Policy, ASTContext &Context) @@ -6711,7 +6727,8 @@ class FreeFunctionTemplateKernelArgsPrinter const auto *TST = dyn_cast(T.getTypePtr()); const auto *CTST = dyn_cast(CT.getTypePtr()); if (!TST || !CTST) { - O << T.getDesugaredType(Context).getAsString(Policy); + O << stringifyTypeWithEnumValues(T.getDesugaredType(Context), Policy, + Context); return; } @@ -6814,11 +6831,17 @@ class FreeFunctionTemplateKernelArgsPrinter void VisitIntegralTemplateArgument(const TemplateArgument &Arg, ArrayRef) { + QualType T = Arg.getIntegralType(); + if (T->isEnumeralType()) + return printEnumValue(T, Arg.getAsIntegral()); PrintDesugared(Arg); } void VisitStructuralValueTemplateArgument(const TemplateArgument &Arg, ArrayRef) { + QualType T = Arg.getIntegralType(); + if (T->isEnumeralType()) + return printEnumValue(T, Arg.getAsIntegral()); PrintDesugared(Arg); } @@ -6837,22 +6860,23 @@ class FreeFunctionTemplateKernelArgsPrinter Expr *E = Arg.getAsExpr(); assert(E && "Failed to get an Expr for an Expression template arg?"); - if (Arg.isInstantiationDependent() || - E->getType()->isScopedEnumeralType()) { - // Scoped enumerations can't be implicitly cast from integers, so - // we don't need to evaluate them. - // If expression is instantiation-dependent, then we can't evaluate it - // either, let's fallback to default printing mechanism. + if (Arg.isInstantiationDependent()) { PrintDesugared(Arg); return; } + if (E->getType()->isEnumeralType()) { + Expr::EvalResult Res; + if (E->EvaluateAsConstantExpr(Res, Context) && !Res.Val.isAbsent() && + Res.Val.isInt()) + return printEnumValue(E->getType(), Res.Val.getInt()); + } + Expr::EvalResult Res; - [[maybe_unused]] bool Success = - Arg.getAsExpr()->EvaluateAsConstantExpr(Res, Context); + [[maybe_unused]] bool Success = E->EvaluateAsConstantExpr(Res, Context); assert(Success && "invalid non-type template argument?"); assert(!Res.Val.isAbsent() && "couldn't read the evaulation result?"); - Res.Val.printPretty(O, Policy, Arg.getAsExpr()->getType(), &Context); + Res.Val.printPretty(O, Policy, E->getType(), &Context); } void VisitPackTemplateArgument(const TemplateArgument &Arg, @@ -6861,6 +6885,162 @@ class FreeFunctionTemplateKernelArgsPrinter } }; +class EnumValueTemplateArgPrinter + : public ConstTemplateArgumentVisitor { + raw_ostream &OS; + PrintingPolicy &Policy; + ASTContext &Ctx; + + void printEnumValue(QualType EnumTy, const llvm::APSInt &Val) { + OS << "static_cast<"; + if (const auto *ET = EnumTy->getAs()) + ET->getOriginalDecl()->printQualifiedName(OS, Policy, + /*WithGlobalNsPrefix=*/false); + else + EnumTy.print(OS, Policy); + llvm::SmallString<8> Num; + Val.toString(Num, /*Radix=*/10, /*Signed=*/Val.isSigned()); + OS << ">(" << Num << ")"; + } + + void printTemplateArgs(ArrayRef Args) { + llvm::ListSeparator LS(", "); + for (const auto &A : Args) { + // Skip empty packs without emitting separators. + if (A.getKind() == TemplateArgument::ArgKind::Pack && !A.pack_size()) + continue; + OS << LS; + Visit(A); + } + } + + void printTypeWithEnumValues(QualType T) { + QualType DT = T.getDesugaredType(Ctx); + + if (const auto *CTSD = dyn_cast_or_null( + DT->getAsCXXRecordDecl())) { + if (!Policy.SuppressTagKeyword) + OS << CTSD->getKindName() << " "; + CTSD->printQualifiedName(OS, Policy, + /*WithGlobalNsPrefix=*/false); + OS << "<"; + printTemplateArgs(CTSD->getTemplateArgs().asArray()); + OS << ">"; + return; + } + + if (const auto *TST = DT->getAs()) { + if (const auto *RT = TST->getAs()) { + if (!Policy.SuppressTagKeyword) + OS << RT->getDecl()->getKindName() << " "; + } + + if (TemplateDecl *TD = TST->getTemplateName().getAsTemplateDecl()) + TD->printQualifiedName(OS, Policy, + /*WithGlobalNsPrefix=*/false); + else + TST->getTemplateName().print(OS, Policy); + + OS << "<"; + printTemplateArgs(TST->template_arguments()); + OS << ">"; + return; + } + + T.print(OS, Policy); + } + +public: + EnumValueTemplateArgPrinter(raw_ostream &OS, PrintingPolicy &Policy, + ASTContext &Ctx) + : OS(OS), Policy(Policy), Ctx(Ctx) {} + + void PrintType(QualType T) { printTypeWithEnumValues(T); } + + void Visit(const TemplateArgument &TA) { + if (TA.isNull()) + return; + ConstTemplateArgumentVisitor::Visit(TA); + } + + void VisitTypeTemplateArgument(const TemplateArgument &Arg) { + printTypeWithEnumValues(Arg.getAsType()); + } + + void VisitIntegralTemplateArgument(const TemplateArgument &Arg) { + QualType T = Arg.getIntegralType(); + if (T->isEnumeralType()) + return printEnumValue(T, Arg.getAsIntegral()); + Arg.print(Policy, OS, /*IncludeType=*/false); + } + + void VisitExpressionTemplateArgument(const TemplateArgument &Arg) { + Expr *E = Arg.getAsExpr(); + if (!E || Arg.isInstantiationDependent()) { + Arg.print(Policy, OS, /*IncludeType=*/false); + return; + } + + if (E->getType()->isEnumeralType()) { + Expr::EvalResult Res; + if (E->EvaluateAsConstantExpr(Res, Ctx) && !Res.Val.isAbsent() && + Res.Val.isInt()) { + return printEnumValue(E->getType(), Res.Val.getInt()); + } + } + + E = E->IgnoreParenImpCasts(); + const EnumConstantDecl *ECD = nullptr; + if (const auto *DRE = dyn_cast(E)) + ECD = dyn_cast(DRE->getDecl()); + else if (const auto *ME = dyn_cast(E)) + ECD = dyn_cast(ME->getMemberDecl()); + + if (ECD) + return printEnumValue(ECD->getType(), ECD->getInitVal()); + + Arg.print(Policy, OS, /*IncludeType=*/false); + } + + void VisitTemplateTemplateArgument(const TemplateArgument &Arg) { + Arg.getAsTemplate().print(OS, Policy); + } + + void VisitTemplateExpansionTemplateArgument(const TemplateArgument &Arg) { + Arg.print(Policy, OS, /*IncludeType=*/false); + } + + void VisitStructuralValueTemplateArgument(const TemplateArgument &Arg) { + Arg.print(Policy, OS, /*IncludeType=*/false); + } + + void VisitPackTemplateArgument(const TemplateArgument &Arg) { + printTemplateArgs(Arg.getPackAsArray()); + } +}; + +static std::string +stringifyTemplateArgWithEnumValues(const TemplateArgument &Arg, + PrintingPolicy &Policy, ASTContext &Ctx) { + std::string Out; + llvm::raw_string_ostream OS(Out); + EnumValueTemplateArgPrinter Printer(OS, Policy, Ctx); + Printer.Visit(Arg); + OS.flush(); + return Out; +} + +static std::string stringifyTypeWithEnumValues(QualType T, + PrintingPolicy &Policy, + ASTContext &Ctx) { + std::string Out; + llvm::raw_string_ostream OS(Out); + EnumValueTemplateArgPrinter Printer(OS, Policy, Ctx); + Printer.PrintType(T); + OS.flush(); + return Out; +} + class FreeFunctionPrinter { raw_ostream &O; PrintingPolicy &Policy; @@ -6993,13 +7173,14 @@ class FreeFunctionPrinter { else if (X.getKind() == TemplateArgument::Pack) { for (const auto &PackArg : X.pack_elements()) { StringStream << ", "; - PackArg.print(Policy, StringStream, /*IncludeType*/ true); + StringStream << stringifyTemplateArgWithEnumValues(PackArg, Policy, + Context); } continue; } else StringStream << ", "; - X.print(Policy, StringStream, /*IncludeType*/ true); + StringStream << stringifyTemplateArgWithEnumValues(X, Policy, Context); } StringStream.flush(); if (Buffer.front() != '<') @@ -7409,12 +7590,15 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) { ParmList += "Args ..."; } else { Policy.SuppressTagKeyword = true; - Param->getType().print(ParmListWithNamesOstream, Policy); + ParmListWithNamesOstream << stringifyTypeWithEnumValues( + Param->getType(), Policy, S.getASTContext()); Policy.SuppressTagKeyword = false; ParmListWithNamesOstream << " " << Param->getNameAsString(); - ParmList += Param->getType().getCanonicalType().getAsString(Policy); + ParmList += stringifyTypeWithEnumValues( + Param->getType().getCanonicalType(), Policy, S.getASTContext()); } } + ParmListWithNamesOstream.flush(); FunctionTemplateDecl *FTD = K.SyclKernel->getPrimaryTemplate(); Policy.PrintAsCanonical = false; diff --git a/clang/test/CodeGenSYCL/free-function-kernel-expr-as-template-arg.cpp b/clang/test/CodeGenSYCL/free-function-kernel-expr-as-template-arg.cpp index 81120ef702d52..3d08b5999bd5f 100644 --- a/clang/test/CodeGenSYCL/free-function-kernel-expr-as-template-arg.cpp +++ b/clang/test/CodeGenSYCL/free-function-kernel-expr-as-template-arg.cpp @@ -87,22 +87,22 @@ void constexpr_ns2(Arg) {} [[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]] void constexpr_ns2(Arg2) {} -// CHECK: void constexpr_ns2(Arg2 ); +// CHECK: void constexpr_ns2(Arg2(0)> ); [[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]] void constexpr_ns2(Arg3) {} -// CHECK: void constexpr_ns2(Arg3 ); +// CHECK: void constexpr_ns2(Arg3(0)> ); [[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]] void constexpr_ns2(Arg4) {} -// CHECK: void constexpr_ns2(Arg4 ); +// CHECK: void constexpr_ns2(Arg4(0)> ); [[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]] void constexpr_ns2(Arg5) {} -// CHECK: void constexpr_ns2(Arg5 ); +// CHECK: void constexpr_ns2(Arg5(0)> ); [[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]] void constexpr_call(Arg) {} diff --git a/clang/test/CodeGenSYCL/free-function-kernel-templated-arg-with-enum.cpp b/clang/test/CodeGenSYCL/free-function-kernel-templated-arg-with-enum.cpp index bc45a922bd3b5..2cb6ac42c6de9 100644 --- a/clang/test/CodeGenSYCL/free-function-kernel-templated-arg-with-enum.cpp +++ b/clang/test/CodeGenSYCL/free-function-kernel-templated-arg-with-enum.cpp @@ -37,18 +37,18 @@ template void templated_on_A(ns::feature_A Arg) {} template void templated_on_A(ns::feature_A); -// CHECK: template void templated_on_A(ns::feature_A); +// CHECK: template void templated_on_A(ns::feature_A(1)>); template [[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]] void templated_on_B(ns::nested::feature_B Arg) {} template void templated_on_B(ns::nested::feature_B); -// CHECK: template void templated_on_B(ns::nested::feature_B); +// CHECK: template void templated_on_B(ns::nested::feature_B(0), static_cast(2)>); template [[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]] void templated_on_C(ns::nested2::feature_C Arg) {} template void templated_on_C(ns::nested2::feature_C<42>); -// CHECK: template void templated_on_C(ns::nested2::feature_C); +// CHECK: template void templated_on_C(ns::nested2::feature_C(1)>); diff --git a/clang/test/CodeGenSYCL/free_function_default_template_arguments.cpp b/clang/test/CodeGenSYCL/free_function_default_template_arguments.cpp index 56db32875ee51..8e0a6a3a02d3a 100644 --- a/clang/test/CodeGenSYCL/free_function_default_template_arguments.cpp +++ b/clang/test/CodeGenSYCL/free_function_default_template_arguments.cpp @@ -380,9 +380,9 @@ namespace Testing::Tests { // CHECK-NEXT: } // CHECK-NEXT: template struct Arg1; -// CHECK: void foo(Arg1 > arg); +// CHECK: void foo(Arg1> arg); // CHECK-NEXT: static constexpr auto __sycl_shim7() { -// CHECK-NEXT: return (void (*)(struct Arg1 >))foo; +// CHECK-NEXT: return (void (*)(struct Arg1>))foo; // CHECK-NEXT: } // CHECK: namespace sycl { diff --git a/clang/test/CodeGenSYCL/free_function_int_header.cpp b/clang/test/CodeGenSYCL/free_function_int_header.cpp index b4326f023df54..0e2851ac20cd4 100644 --- a/clang/test/CodeGenSYCL/free_function_int_header.cpp +++ b/clang/test/CodeGenSYCL/free_function_int_header.cpp @@ -1486,9 +1486,9 @@ void ff_28(TemplatedAccessorStruct arg1) { // CHECK-NEXT: template class accessor; // CHECK-NEXT: }} -// CHECK: void ff_20(sycl::accessor > acc); +// CHECK: void ff_20(sycl::accessor(1026), static_cast(2014), static_cast(0), sycl::ext::oneapi::accessor_property_list<>> acc); // CHECK-NEXT: static constexpr auto __sycl_shim29() { -// CHECK-NEXT: return (void (*)(class sycl::accessor >))ff_20; +// CHECK-NEXT: return (void (*)(class sycl::accessor(1026), static_cast(2014), static_cast(0), class sycl::ext::oneapi::accessor_property_list<>>))ff_20; // CHECK-NEXT: } // CHECK: namespace sycl { diff --git a/sycl/test-e2e/FreeFunctionKernels/free_function_user_enum_class.cpp b/sycl/test-e2e/FreeFunctionKernels/free_function_user_enum_class.cpp new file mode 100644 index 0000000000000..1b7ef4c8675b6 --- /dev/null +++ b/sycl/test-e2e/FreeFunctionKernels/free_function_user_enum_class.cpp @@ -0,0 +1,278 @@ +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +// This test verifies that we can use scoped enum types as arguments in free +// function kernels. + +#include "free_function_user_enum_class.hpp" +#include "helpers.hpp" +#include +#include +#include +#include +#include + +namespace syclext = sycl::ext::oneapi; +namespace syclexp = sycl::ext::oneapi::experimental; +namespace at::native::xpu { + +enum class OP_MODE_SCOPED : uint8_t { INC, DEC, MUL, DIV }; +enum OP_MODE : uint8_t { INC, DEC, MUL, DIV }; + +template struct TestStruct { + OP_MODE_SCOPED _op = op; +}; + +template +SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::nd_range_kernel<1>)) +void freefunctionkernel(TestStruct(2), int> t, + T *data, size_t size) { + auto item = syclext::this_work_item::get_nd_item<1>(); + size_t idx = item.get_global_linear_id(); + if (idx < size) { + if (t._op == OP_MODE_SCOPED::MUL) + data[idx] *= 2; + } +} + +template +SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::nd_range_kernel<1>)) +void freefunctionkernel_enum(T *data, size_t size, OP_MODE_SCOPED op) { + auto item = syclext::this_work_item::get_nd_item<1>(); + size_t idx = item.get_global_linear_id(); + if (idx < size) { + if (op == OP_MODE_SCOPED::INC) + data[idx]++; + else if (op == OP_MODE_SCOPED::DEC) + data[idx]--; + else if (op == OP_MODE_SCOPED::MUL) + data[idx] *= 2; + else if (op == OP_MODE_SCOPED::DIV) + data[idx] /= 2; + else + data[idx] = -1; + } +} + +template struct LpMaxFunctor { + void operator()(sycl::nd_item<1> item, T *data, size_t size) { + size_t idx = item.get_global_linear_id(); + if (idx < size) { + if constexpr (adam_mode == ADAM_MODE::ADAMW) + data[idx] += 1; + else + data[idx] += 2; + } + } +}; + +template struct OpFunctor { + void operator()(sycl::nd_item<1> item, T *data, size_t size) { + size_t idx = item.get_global_linear_id(); + if (idx < size) { + if constexpr (op_mode == OP_MODE::DEC) + data[idx]--; + } + } +}; + +template struct OpSFunctor { + void operator()(sycl::nd_item<1> item, T *data, size_t size) { + size_t idx = item.get_global_linear_id(); + if (idx < size) { + if constexpr (op_mode == OP_MODE_SCOPED::INC) + data[idx]++; + else if constexpr (op_mode == OP_MODE_SCOPED::DEC) + data[idx]--; + else if constexpr (op_mode == OP_MODE_SCOPED::MUL) + data[idx] *= 2; + else if constexpr (op_mode == OP_MODE_SCOPED::DIV) + data[idx] /= 2; + else + data[idx] = -1; // Invalid operation, set to -1 for testing + } + } +}; + +template +SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::nd_range_kernel<1>)) +void applyKernel(U callable, T *data, size_t size) { + auto item = syclext::this_work_item::get_nd_item<1>(); + callable(item, data, size); +} + +template +static inline void +sycl_kernel_submit(sycl::range global_range, sycl::range local_range, + sycl::queue q, int slm_sz, Kargs... args) { + sycl::context ctxt = q.get_context(); + auto exe_bndl = + syclexp::get_kernel_bundle(ctxt); + sycl::kernel ker = exe_bndl.template ext_oneapi_get_kernel(); + if (slm_sz != 0) { + syclexp::launch_config cfg{ + sycl::nd_range(sycl::range(global_range), + sycl::range(local_range)), + syclexp::properties{syclexp::work_group_scratch_size(slm_sz)}}; + syclexp::nd_launch(q, cfg, ker, args...); + } else { + syclexp::launch_config cfg{sycl::nd_range( + sycl::range(global_range), sycl::range(local_range))}; + syclexp::nd_launch(q, cfg, ker, args...); + } +} + +template +void apply_lp(sycl::queue &q, T callable, int *data, size_t size, + size_t global_size) { + constexpr auto kernel = + applyKernel, int>; + sycl_kernel_submit(sycl::range<1>(global_size), sycl::range<1>(1), q, + 0, callable, data, size); +} + +template +void apply_op_scoped(sycl::queue &q, T callable, int *data, size_t size, + size_t global_size) { + constexpr auto kernel = + applyKernel, int>; + sycl_kernel_submit(sycl::range<1>(global_size), sycl::range<1>(1), q, + 0, callable, data, size); +} + +template +void apply_op(sycl::queue &q, T callable, int *data, size_t size, + size_t global_size) { + constexpr auto kernel = applyKernel, int>; + sycl_kernel_submit(sycl::range<1>(global_size), sycl::range<1>(1), q, + 0, callable, data, size); +} + +template +void apply_op_scoped_cast(sycl::queue &q, T callable, int *data, size_t size, + size_t global_size) { + constexpr auto kernel = + applyKernel(2)>, int>; + sycl_kernel_submit(sycl::range<1>(global_size), sycl::range<1>(1), q, + 0, callable, data, size); +} + +template +void apply_op_scoped_cast1(sycl::queue &q, T callable, int *data, size_t size, + size_t global_size) { + constexpr auto kernel = + applyKernel(32)>, int>; + sycl_kernel_submit(sycl::range<1>(global_size), sycl::range<1>(1), q, + 0, callable, data, size); +} + +template +void apply_op_scoped_cast2(sycl::queue &q, int *data, size_t size, + size_t global_size) { + constexpr auto kernel = freefunctionkernel; + TestStruct(2), int> t; + sycl_kernel_submit(sycl::range<1>(global_size), sycl::range<1>(1), q, + 0, t, data, size); +} + +template +void apply_op_scoped_enum_named(sycl::queue &q, OP_MODE_SCOPED op, int *data, + size_t size, size_t global_size) { + constexpr auto kernel = freefunctionkernel_enum; + sycl_kernel_submit(sycl::range<1>(global_size), sycl::range<1>(1), q, + 0, data, size, op); +} + +} // namespace at::native::xpu + +int main() { + sycl::queue q; + constexpr size_t N = 10; + auto *data = sycl::malloc_shared(N, q); + assert(data && "USM allocation failed"); + + for (size_t i = 0; i < N; ++i) + data[i] = i + 1; + + at::native::xpu::LpMaxFunctor + subKernel; + apply_lp(q, subKernel, data, N, N); + q.wait_and_throw(); + + for (size_t i = 0; i < N; ++i) { + assert(data[i] == i + 2); + data[i] = i; + } + + at::native::xpu::OpSFunctor + subKernelOpS; + apply_op_scoped(q, subKernelOpS, data, N, N); + q.wait_and_throw(); + + for (size_t i = 0; i < N; ++i) { + assert(data[i] == i * 2); + data[i] = i; + } + + at::native::xpu::OpFunctor subKernelOp; + apply_op(q, subKernelOp, data, N, N); + q.wait_and_throw(); + + for (size_t i = 0; i < N; ++i) { + assert(data[i] == (i - 1) && "DEC operation failed"); + data[i] = i; + } + + at::native::xpu::OpSFunctor(2)> + subKernelOpSCast; + apply_op_scoped_cast(q, subKernelOpSCast, data, N, N); + q.wait_and_throw(); + + for (size_t i = 0; i < N; ++i) { + assert(data[i] == i * 2 && "MUL operation with casted enum failed"); + data[i] = i; + } + + at::native::xpu::OpSFunctor(32)> + subKernelOpSCast1; + apply_op_scoped_cast1(q, subKernelOpSCast1, data, N, N); + q.wait_and_throw(); + + for (size_t i = 0; i < N; ++i) { + assert(data[i] == -1 && "MUL operation with casted enum failed"); + data[i] = i; + } + + at::native::xpu::apply_op_scoped_cast2(q, data, N, N); + q.wait_and_throw(); + + for (size_t i = 0; i < N; ++i) { + assert(data[i] == i * 2 && "MUL operation via free function kernel failed"); + data[i] = i; + } + + at::native::xpu::apply_op_scoped_enum_named( + q, at::native::xpu::OP_MODE_SCOPED::DEC, data, N, N); + q.wait_and_throw(); + + for (size_t i = 0; i < N; ++i) { + assert(data[i] == (i - 1) && + "Named enum value via free function kernel failed"); + data[i] = i; + } + + at::native::xpu::apply_op_scoped_enum_named( + q, static_cast(2), data, N, N); + q.wait_and_throw(); + + for (size_t i = 0; i < N; ++i) { + assert(data[i] == i * 2 && + "Casted enum value via free function kernel failed"); + data[i] = i; + } + + sycl::free(data, q); + return 0; +} diff --git a/sycl/test-e2e/FreeFunctionKernels/free_function_user_enum_class.hpp b/sycl/test-e2e/FreeFunctionKernels/free_function_user_enum_class.hpp new file mode 100644 index 0000000000000..642895508656f --- /dev/null +++ b/sycl/test-e2e/FreeFunctionKernels/free_function_user_enum_class.hpp @@ -0,0 +1,7 @@ +#include + +namespace at::native::xpu { + +enum class ADAM_MODE : uint8_t { ORIGINAL = 0, ADAMW = 1 }; + +}