Skip to content
Open
212 changes: 198 additions & 14 deletions clang/lib/Sema/SemaSYCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6648,6 +6648,10 @@ static void PrintNSClosingBraces(raw_ostream &OS, const DeclContext *DC) {
///
/// Moral of the story: drop integration header ASAP (but that is blocked
/// by support for 3rd-party host compilers, which is important).
static std::string stringifyTypeWithEnumValues(QualType T,
PrintingPolicy &Policy,
ASTContext &Ctx);

class FreeFunctionTemplateKernelArgsPrinter
: public ConstTemplateArgumentVisitor<FreeFunctionTemplateKernelArgsPrinter,
void, ArrayRef<TemplateArgument>> {
Expand Down Expand Up @@ -6683,6 +6687,18 @@ class FreeFunctionTemplateKernelArgsPrinter
DesugarTemplateArgument(Arg).print(Policy, O, /*IncludeType=*/false);
}

void printEnumValue(QualType EnumTy, const llvm::APSInt &Val) {
O << "static_cast<";
if (const auto *ET = EnumTy->getAs<EnumType>())
ET->getOriginalDecl()->printQualifiedName(O, Policy,
/*WithGlobalNsPrefix=*/false);
else
EnumTy.print(O, Policy);
llvm::SmallString<8> Num;
Val.toString(Num, /*Radix=*/10, /*Signed=*/Val.isSigned());
O << ">(" << Num << ")";
}

public:
FreeFunctionTemplateKernelArgsPrinter(raw_ostream &O, PrintingPolicy &Policy,
ASTContext &Context)
Expand Down Expand Up @@ -6711,7 +6727,8 @@ class FreeFunctionTemplateKernelArgsPrinter
const auto *TST = dyn_cast<TemplateSpecializationType>(T.getTypePtr());
const auto *CTST = dyn_cast<TemplateSpecializationType>(CT.getTypePtr());
if (!TST || !CTST) {
O << T.getDesugaredType(Context).getAsString(Policy);
O << stringifyTypeWithEnumValues(T.getDesugaredType(Context), Policy,
Context);
return;
}

Expand Down Expand Up @@ -6814,11 +6831,17 @@ class FreeFunctionTemplateKernelArgsPrinter

void VisitIntegralTemplateArgument(const TemplateArgument &Arg,
ArrayRef<TemplateArgument>) {
QualType T = Arg.getIntegralType();
if (T->isEnumeralType())
return printEnumValue(T, Arg.getAsIntegral());
PrintDesugared(Arg);
}

void VisitStructuralValueTemplateArgument(const TemplateArgument &Arg,
ArrayRef<TemplateArgument>) {
QualType T = Arg.getIntegralType();
if (T->isEnumeralType())
return printEnumValue(T, Arg.getAsIntegral());
PrintDesugared(Arg);
}

Expand All @@ -6837,22 +6860,23 @@ class FreeFunctionTemplateKernelArgsPrinter
Expr *E = Arg.getAsExpr();
assert(E && "Failed to get an Expr for an Expression template arg?");

if (Arg.isInstantiationDependent() ||
E->getType()->isScopedEnumeralType()) {
// Scoped enumerations can't be implicitly cast from integers, so
// we don't need to evaluate them.
// If expression is instantiation-dependent, then we can't evaluate it
// either, let's fallback to default printing mechanism.
if (Arg.isInstantiationDependent()) {
PrintDesugared(Arg);
return;
}

if (E->getType()->isEnumeralType()) {
Expr::EvalResult Res;
if (E->EvaluateAsConstantExpr(Res, Context) && !Res.Val.isAbsent() &&
Res.Val.isInt())
return printEnumValue(E->getType(), Res.Val.getInt());
}

Expr::EvalResult Res;
[[maybe_unused]] bool Success =
Arg.getAsExpr()->EvaluateAsConstantExpr(Res, Context);
[[maybe_unused]] bool Success = E->EvaluateAsConstantExpr(Res, Context);
assert(Success && "invalid non-type template argument?");
assert(!Res.Val.isAbsent() && "couldn't read the evaulation result?");
Res.Val.printPretty(O, Policy, Arg.getAsExpr()->getType(), &Context);
Res.Val.printPretty(O, Policy, E->getType(), &Context);
}

void VisitPackTemplateArgument(const TemplateArgument &Arg,
Expand All @@ -6861,6 +6885,162 @@ class FreeFunctionTemplateKernelArgsPrinter
}
};

class EnumValueTemplateArgPrinter
: public ConstTemplateArgumentVisitor<EnumValueTemplateArgPrinter> {
raw_ostream &OS;
PrintingPolicy &Policy;
ASTContext &Ctx;

void printEnumValue(QualType EnumTy, const llvm::APSInt &Val) {
OS << "static_cast<";
if (const auto *ET = EnumTy->getAs<EnumType>())
ET->getOriginalDecl()->printQualifiedName(OS, Policy,
/*WithGlobalNsPrefix=*/false);
else
EnumTy.print(OS, Policy);
llvm::SmallString<8> Num;
Val.toString(Num, /*Radix=*/10, /*Signed=*/Val.isSigned());
OS << ">(" << Num << ")";
}

void printTemplateArgs(ArrayRef<TemplateArgument> Args) {
llvm::ListSeparator LS(", ");
for (const auto &A : Args) {
// Skip empty packs without emitting separators.
if (A.getKind() == TemplateArgument::ArgKind::Pack && !A.pack_size())
continue;
OS << LS;
Visit(A);
}
}

void printTypeWithEnumValues(QualType T) {
QualType DT = T.getDesugaredType(Ctx);

if (const auto *CTSD = dyn_cast_or_null<ClassTemplateSpecializationDecl>(
DT->getAsCXXRecordDecl())) {
if (!Policy.SuppressTagKeyword)
OS << CTSD->getKindName() << " ";
CTSD->printQualifiedName(OS, Policy,
/*WithGlobalNsPrefix=*/false);
OS << "<";
printTemplateArgs(CTSD->getTemplateArgs().asArray());
OS << ">";
return;
}

if (const auto *TST = DT->getAs<TemplateSpecializationType>()) {
if (const auto *RT = TST->getAs<RecordType>()) {
if (!Policy.SuppressTagKeyword)
OS << RT->getDecl()->getKindName() << " ";
}

if (TemplateDecl *TD = TST->getTemplateName().getAsTemplateDecl())
TD->printQualifiedName(OS, Policy,
/*WithGlobalNsPrefix=*/false);
else
TST->getTemplateName().print(OS, Policy);

OS << "<";
printTemplateArgs(TST->template_arguments());
OS << ">";
return;
}

T.print(OS, Policy);
}

public:
EnumValueTemplateArgPrinter(raw_ostream &OS, PrintingPolicy &Policy,
ASTContext &Ctx)
: OS(OS), Policy(Policy), Ctx(Ctx) {}

void PrintType(QualType T) { printTypeWithEnumValues(T); }

void Visit(const TemplateArgument &TA) {
if (TA.isNull())
return;
ConstTemplateArgumentVisitor::Visit(TA);
}

void VisitTypeTemplateArgument(const TemplateArgument &Arg) {
printTypeWithEnumValues(Arg.getAsType());
}

void VisitIntegralTemplateArgument(const TemplateArgument &Arg) {
QualType T = Arg.getIntegralType();
if (T->isEnumeralType())
return printEnumValue(T, Arg.getAsIntegral());
Arg.print(Policy, OS, /*IncludeType=*/false);
}

void VisitExpressionTemplateArgument(const TemplateArgument &Arg) {
Expr *E = Arg.getAsExpr();
if (!E || Arg.isInstantiationDependent()) {
Arg.print(Policy, OS, /*IncludeType=*/false);
return;
}

if (E->getType()->isEnumeralType()) {
Expr::EvalResult Res;
if (E->EvaluateAsConstantExpr(Res, Ctx) && !Res.Val.isAbsent() &&
Res.Val.isInt()) {
return printEnumValue(E->getType(), Res.Val.getInt());
}
}

E = E->IgnoreParenImpCasts();
const EnumConstantDecl *ECD = nullptr;
if (const auto *DRE = dyn_cast<DeclRefExpr>(E))
ECD = dyn_cast<EnumConstantDecl>(DRE->getDecl());
else if (const auto *ME = dyn_cast<MemberExpr>(E))
ECD = dyn_cast<EnumConstantDecl>(ME->getMemberDecl());

if (ECD)
return printEnumValue(ECD->getType(), ECD->getInitVal());

Arg.print(Policy, OS, /*IncludeType=*/false);
}

void VisitTemplateTemplateArgument(const TemplateArgument &Arg) {
Arg.getAsTemplate().print(OS, Policy);
}

void VisitTemplateExpansionTemplateArgument(const TemplateArgument &Arg) {
Arg.print(Policy, OS, /*IncludeType=*/false);
}

void VisitStructuralValueTemplateArgument(const TemplateArgument &Arg) {
Arg.print(Policy, OS, /*IncludeType=*/false);
}

void VisitPackTemplateArgument(const TemplateArgument &Arg) {
printTemplateArgs(Arg.getPackAsArray());
}
};

static std::string
stringifyTemplateArgWithEnumValues(const TemplateArgument &Arg,
PrintingPolicy &Policy, ASTContext &Ctx) {
std::string Out;
llvm::raw_string_ostream OS(Out);
EnumValueTemplateArgPrinter Printer(OS, Policy, Ctx);
Printer.Visit(Arg);
OS.flush();
return Out;
}

static std::string stringifyTypeWithEnumValues(QualType T,
PrintingPolicy &Policy,
ASTContext &Ctx) {
std::string Out;
llvm::raw_string_ostream OS(Out);
EnumValueTemplateArgPrinter Printer(OS, Policy, Ctx);
Printer.PrintType(T);
OS.flush();
return Out;
}

class FreeFunctionPrinter {
raw_ostream &O;
PrintingPolicy &Policy;
Expand Down Expand Up @@ -6993,13 +7173,14 @@ class FreeFunctionPrinter {
else if (X.getKind() == TemplateArgument::Pack) {
for (const auto &PackArg : X.pack_elements()) {
StringStream << ", ";
PackArg.print(Policy, StringStream, /*IncludeType*/ true);
StringStream << stringifyTemplateArgWithEnumValues(PackArg, Policy,
Context);
}
continue;
} else
StringStream << ", ";

X.print(Policy, StringStream, /*IncludeType*/ true);
StringStream << stringifyTemplateArgWithEnumValues(X, Policy, Context);
}
StringStream.flush();
if (Buffer.front() != '<')
Expand Down Expand Up @@ -7409,12 +7590,15 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
ParmList += "Args ...";
} else {
Policy.SuppressTagKeyword = true;
Param->getType().print(ParmListWithNamesOstream, Policy);
ParmListWithNamesOstream << stringifyTypeWithEnumValues(
Param->getType(), Policy, S.getASTContext());
Policy.SuppressTagKeyword = false;
ParmListWithNamesOstream << " " << Param->getNameAsString();
ParmList += Param->getType().getCanonicalType().getAsString(Policy);
ParmList += stringifyTypeWithEnumValues(
Param->getType().getCanonicalType(), Policy, S.getASTContext());
}
}

ParmListWithNamesOstream.flush();
FunctionTemplateDecl *FTD = K.SyclKernel->getPrimaryTemplate();
Policy.PrintAsCanonical = false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,22 +87,22 @@ void constexpr_ns2(Arg<ns::Foo::D>) {}
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
void constexpr_ns2(Arg2<ns::non_class_enum::VAL_A>) {}

// CHECK: void constexpr_ns2(Arg2<ns::VAL_A> );
// CHECK: void constexpr_ns2(Arg2<static_cast<ns::non_class_enum>(0)> );

[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
void constexpr_ns2(Arg3<ns::non_class_enum_typed::VAL_C>) {}

// CHECK: void constexpr_ns2(Arg3<ns::VAL_C> );
// CHECK: void constexpr_ns2(Arg3<static_cast<ns::non_class_enum_typed>(0)> );

[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
void constexpr_ns2(Arg4<ns::class_enum::VAL_A>) {}

// CHECK: void constexpr_ns2(Arg4<ns::class_enum::VAL_A> );
// CHECK: void constexpr_ns2(Arg4<static_cast<ns::class_enum>(0)> );

[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
void constexpr_ns2(Arg5<ns::class_enum_typed::VAL_C>) {}

// CHECK: void constexpr_ns2(Arg5<ns::class_enum_typed::VAL_C> );
// CHECK: void constexpr_ns2(Arg5<static_cast<ns::class_enum_typed>(0)> );

[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
void constexpr_call(Arg<ns::bar(B)>) {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,18 @@ template<typename T>
void templated_on_A(ns::feature_A<T> Arg) {}
template void templated_on_A(ns::feature_A<int>);

// CHECK: template <typename T> void templated_on_A(ns::feature_A<T, ns::enum_A::B>);
// CHECK: template <typename T> void templated_on_A(ns::feature_A<T, static_cast<ns::enum_A>(1)>);

template<typename T, int V = 42>
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
void templated_on_B(ns::nested::feature_B<T, V> Arg) {}
template void templated_on_B(ns::nested::feature_B<int, 12>);

// CHECK: template <typename T, int V> void templated_on_B(ns::nested::feature_B<T, V, ns::nested::enum_B::A, ns::enum_A::C>);
// CHECK: template <typename T, int V> void templated_on_B(ns::nested::feature_B<T, V, static_cast<ns::nested::enum_B>(0), static_cast<ns::enum_A>(2)>);

template<int V>
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]]
void templated_on_C(ns::nested2::feature_C<V> Arg) {}
template void templated_on_C(ns::nested2::feature_C<42>);

// CHECK: template <int V> void templated_on_C(ns::nested2::feature_C<V, ns::nested2::enum_C::B>);
// CHECK: template <int V> void templated_on_C(ns::nested2::feature_C<V, static_cast<ns::nested2::enum_C>(1)>);
Original file line number Diff line number Diff line change
Expand Up @@ -380,9 +380,9 @@ namespace Testing::Tests {
// CHECK-NEXT: }
// CHECK-NEXT: template <typename T, typename> struct Arg1;

// CHECK: void foo(Arg1<int, sycl::X<sycl::detail::Y> > arg);
// CHECK: void foo(Arg1<int, sycl::X<sycl::detail::Y>> arg);
// CHECK-NEXT: static constexpr auto __sycl_shim7() {
// CHECK-NEXT: return (void (*)(struct Arg1<int, struct sycl::X<struct sycl::detail::Y> >))foo;
// CHECK-NEXT: return (void (*)(struct Arg1<int, struct sycl::X<struct sycl::detail::Y>>))foo;
// CHECK-NEXT: }

// CHECK: namespace sycl {
Expand Down
4 changes: 2 additions & 2 deletions clang/test/CodeGenSYCL/free_function_int_header.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1486,9 +1486,9 @@ void ff_28(TemplatedAccessorStruct<int> arg1) {
// CHECK-NEXT: template <typename dataT, int dimensions, sycl::access::mode accessmode, sycl::access::target accessTarget, sycl::access::placeholder isPlaceholder, typename propertyListT> class accessor;
// CHECK-NEXT: }}

// CHECK: void ff_20(sycl::accessor<int, 1, sycl::access::mode::read_write, sycl::access::target::global_buffer, sycl::access::placeholder::false_t, sycl::ext::oneapi::accessor_property_list<> > acc);
// CHECK: void ff_20(sycl::accessor<int, 1, static_cast<sycl::access::mode>(1026), static_cast<sycl::access::target>(2014), static_cast<sycl::access::placeholder>(0), sycl::ext::oneapi::accessor_property_list<>> acc);
// CHECK-NEXT: static constexpr auto __sycl_shim29() {
// CHECK-NEXT: return (void (*)(class sycl::accessor<int, 1, sycl::access::mode::read_write, sycl::access::target::global_buffer, sycl::access::placeholder::false_t, class sycl::ext::oneapi::accessor_property_list<> >))ff_20;
// CHECK-NEXT: return (void (*)(class sycl::accessor<int, 1, static_cast<sycl::access::mode>(1026), static_cast<sycl::access::target>(2014), static_cast<sycl::access::placeholder>(0), class sycl::ext::oneapi::accessor_property_list<>>))ff_20;
// CHECK-NEXT: }

// CHECK: namespace sycl {
Expand Down
Loading
Loading