diff --git a/clang/lib/Sema/SemaSYCL.cpp b/clang/lib/Sema/SemaSYCL.cpp index 3dd32e5b786e6..cf34c5aaeaf31 100644 --- a/clang/lib/Sema/SemaSYCL.cpp +++ b/clang/lib/Sema/SemaSYCL.cpp @@ -6613,6 +6613,254 @@ static void PrintNSClosingBraces(raw_ostream &OS, const DeclContext *DC) { [](raw_ostream &, const NamespaceDecl *) {}, OS, DC); } +/// Dedicated visitor which helps with printing of kernel arguments in forward +/// declarations of free function kernels which are declared as function +/// templates. +/// +/// Based on: +/// \code +/// template +/// void foo(T1 a, int b, T2 c); +/// \endcode +/// +/// It prints into the output stream "T1, int, T2". +/// +/// The main complexity (which motivates addition of such visitor) comes from +/// the fact that there could be type aliases and default template arguments. +/// For example: +/// \code +/// template +/// void kernel(sycl::accessor); +/// template void kernel(sycl::accessor); +/// \endcode +/// sycl::accessor has many template arguments which have default values. If +/// we iterate over non-canonicalized argument type, we don't get those default +/// values and we don't get necessary namespace qualifiers for all the template +/// arguments. If we iterate over canonicalized argument type, then all +/// references to T will be replaced with something like type-argument-X-Y. +/// What this visitor does is it iterates over both in sync, picking the right +/// values from one or another. +/// +/// The template argument visitor functions take an additional +/// ArrayRef argument corresponding to the template arguments +/// of the outermost template. This is used by some of these functions for +/// mapping dependent template arguments. +/// +/// Moral of the story: drop integration header ASAP (but that is blocked +/// by support for 3rd-party host compilers, which is important). +class FreeFunctionTemplateKernelArgsPrinter + : public ConstTemplateArgumentVisitor> { + raw_ostream &O; + PrintingPolicy &Policy; + ASTContext &Context; + + using Base = + ConstTemplateArgumentVisitor>; + + // Desugars a template argument. This helps avoid aliases. + static TemplateArgument DesugarTemplateArgument(const TemplateArgument &Arg) { + switch (Arg.getKind()) { + case TemplateArgument::ArgKind::Type: { + QualType ArgTy = Arg.getAsType(); + return {QualType(ArgTy->getUnqualifiedDesugaredType(), + ArgTy.getCVRQualifiers())}; + } + case TemplateArgument::ArgKind::Template: { + TemplateName TN = Arg.getAsTemplate(); + while (std::optional DesugaredTN = + TN.desugar(/*IgnoreDeduced=*/false)) + TN = *DesugaredTN; + return {TN}; + } + default: + return Arg; + } + } + + void PrintDesugared(const TemplateArgument &Arg) { + DesugarTemplateArgument(Arg).print(Policy, O, /*IncludeType=*/false); + } + +public: + FreeFunctionTemplateKernelArgsPrinter(raw_ostream &O, PrintingPolicy &Policy, + ASTContext &Context) + : O(O), Policy(Policy), Context(Context) {} + + void Visit(const ParmVarDecl *Param) { + // There are cases when we can't directly use neither the original + // argument type, nor its canonical version. An example would be: + // template + // void kernel(sycl::accessor); + // template void kernel(sycl::accessor); + // Accessor has multiple non-type template arguments with default values + // and non-qualified type will not include necessary namespaces for all + // of them. Qualified type will have that information, but all references + // to T will be replaced to something like type-argument-0 + // What we do instead is we iterate template arguments of both versions + // of a type in sync and take elements from one or another to get the best + // of both: proper references to template arguments of a kernel itself and + // fully-qualified names for enumerations. + // + // Moral of the story: drop integration header ASAP (but that is blocked + // by support for 3rd-party host compilers, which is important). + QualType T = Param->getType(); + QualType CT = T.getCanonicalType(); + + const auto *TST = dyn_cast(T.getTypePtr()); + const auto *CTST = dyn_cast(CT.getTypePtr()); + if (!TST || !CTST) { + O << T.getDesugaredType(Context).getAsString(Policy); + return; + } + + const TemplateSpecializationType *TSTAsNonAlias = + TST->getAsNonAliasTemplateSpecializationType(); + if (TSTAsNonAlias) + TST = TSTAsNonAlias; + + ArrayRef SpecArgs = TST->template_arguments(); + ArrayRef DeclArgs = CTST->template_arguments(); + + const TemplateDecl *TD = CTST->getTemplateName().getAsTemplateDecl(); + if (!TD->getIdentifier()) + TD = TST->getTemplateName().getAsTemplateDecl(); + assert(TD->getIdentifier() && + "Either the type or the canonical type should have an identifier."); + TD->printQualifiedName(O); + + O << "<"; + for (size_t I = 0, E = std::max(DeclArgs.size(), SpecArgs.size()), + SE = SpecArgs.size(); + I < E; ++I) { + if (I != 0) + O << ", "; + // If we have a specialized argument, use it. Otherwise fallback to a + // default argument. + // We pass specialized arguments in case there are references to them + // from other types. + // FIXME: passing SpecArgs here is incorrect. It refers to template + // arguments of a single function argument, but DeclArgs contain + // references (in form of depth-index) to template arguments of the + // function itself which results in incorrect integration header being + // produced. + Base::Visit(I < SE ? SpecArgs[I] : DeclArgs[I], SpecArgs); + } + O << ">"; + } + + // Internal version of the function above that is used when template argument + // is a template by itself + void Visit(const TemplateSpecializationType *T, + ArrayRef SpecArgs) { + const TemplateDecl *TD = T->getTemplateName().getAsTemplateDecl(); + const auto *TTPD = dyn_cast(TD); + if (TTPD && !TTPD->getIdentifier()) + PrintDesugared(SpecArgs[TTPD->getIndex()]); + else + TD->printQualifiedName(O); + O << "<"; + ArrayRef DeclArgs = T->template_arguments(); + for (size_t I = 0, E = DeclArgs.size(); I < E; ++I) { + if (I != 0) + O << ", "; + Base::Visit(DeclArgs[I], SpecArgs); + } + O << ">"; + } + + void VisitNullTemplateArgument(const TemplateArgument &, + ArrayRef) { + llvm_unreachable("If template argument has not been deduced, then we can't " + "forward-declare it, something went wrong"); + } + + void VisitTypeTemplateArgument(const TemplateArgument &Arg, + ArrayRef SpecArgs) { + TemplateArgument DesugaredArg = DesugarTemplateArgument(Arg); + // If we reference an existing template argument without a known identifier, + // print it instead. + const auto *TPT = dyn_cast(DesugaredArg.getAsType()); + if (TPT && !TPT->getIdentifier()) { + PrintDesugared(SpecArgs[TPT->getIndex()]); + return; + } + + const auto *TST = + dyn_cast(DesugaredArg.getAsType()); + if (TST && Arg.isInstantiationDependent()) { + // This is an instantiation dependent template specialization, meaning + // that some of its arguments reference template arguments of the free + // function kernel itself. + Visit(TST, SpecArgs); + return; + } + + DesugaredArg.print(Policy, O, /* IncludeType = */ false); + } + + void VisitDeclarationTemplateArgument(const TemplateArgument &, + ArrayRef) { + llvm_unreachable("Free function kernels cannot have non-type template " + "arguments which are pointers or references"); + } + + void VisitNullPtrTemplateArgument(const TemplateArgument &, + ArrayRef) { + llvm_unreachable("Free function kernels cannot have non-type template " + "arguments which are pointers or references"); + } + + void VisitIntegralTemplateArgument(const TemplateArgument &Arg, + ArrayRef) { + PrintDesugared(Arg); + } + + void VisitStructuralValueTemplateArgument(const TemplateArgument &Arg, + ArrayRef) { + PrintDesugared(Arg); + } + + void VisitTemplateTemplateArgument(const TemplateArgument &Arg, + ArrayRef) { + PrintDesugared(Arg); + } + + void VisitTemplateExpansionTemplateArgument(const TemplateArgument &Arg, + ArrayRef) { + PrintDesugared(Arg); + } + + void VisitExpressionTemplateArgument(const TemplateArgument &Arg, + ArrayRef) { + Expr *E = Arg.getAsExpr(); + assert(E && "Failed to get an Expr for an Expression template arg?"); + + if (Arg.isInstantiationDependent() || + E->getType()->isScopedEnumeralType()) { + // Scoped enumerations can't be implicitly cast from integers, so + // we don't need to evaluate them. + // If expression is instantiation-dependent, then we can't evaluate it + // either, let's fallback to default printing mechanism. + PrintDesugared(Arg); + return; + } + + Expr::EvalResult Res; + [[maybe_unused]] bool Success = + Arg.getAsExpr()->EvaluateAsConstantExpr(Res, Context); + assert(Success && "invalid non-type template argument?"); + assert(!Res.Val.isAbsent() && "couldn't read the evaulation result?"); + Res.Val.printPretty(O, Policy, Arg.getAsExpr()->getType(), &Context); + } + + void VisitPackTemplateArgument(const TemplateArgument &Arg, + ArrayRef) { + PrintDesugared(Arg); + } +}; + class FreeFunctionPrinter { raw_ostream &O; PrintingPolicy &Policy; @@ -6776,86 +7024,16 @@ class FreeFunctionPrinter { llvm::raw_svector_ostream ParmListOstream{ParamList}; Policy.SuppressTagKeyword = true; - for (ParmVarDecl *Param : Parameters) { + FreeFunctionTemplateKernelArgsPrinter Printer(ParmListOstream, Policy, + Context); + + for (const ParmVarDecl *Param : Parameters) { if (FirstParam) FirstParam = false; else ParmListOstream << ", "; - // There are cases when we can't directly use neither the original - // argument type, nor its canonical version. An example would be: - // template - // void kernel(sycl::accessor); - // template void kernel(sycl::accessor); - // Accessor has multiple non-type template arguments with default values - // and non-qualified type will not include necessary namespaces for all - // of them. Qualified type will have that information, but all references - // to T will be replaced to something like type-argument-0 - // What we do instead is we iterate template arguments of both versions - // of a type in sync and take elements from one or another to get the best - // of both: proper references to template arguments of a kernel itself and - // fully-qualified names for enumerations. - // - // Moral of the story: drop integration header ASAP (but that is blocked - // by support for 3rd-party host compilers, which is important). - QualType T = Param->getType(); - QualType CT = T.getCanonicalType(); - - const auto *TST = dyn_cast(T.getTypePtr()); - const auto *CTST = dyn_cast(CT.getTypePtr()); - if (!TST || !CTST) { - ParmListOstream << T.getAsString(Policy); - continue; - } - - const TemplateSpecializationType *TSTAsNonAlias = - TST->getAsNonAliasTemplateSpecializationType(); - if (TSTAsNonAlias) - TST = TSTAsNonAlias; - - TemplateName CTN = CTST->getTemplateName(); - CTN.getAsTemplateDecl()->printQualifiedName(ParmListOstream); - ParmListOstream << "<"; - - ArrayRef SpecArgs = TST->template_arguments(); - ArrayRef DeclArgs = CTST->template_arguments(); - - auto TemplateArgPrinter = [&](const TemplateArgument &Arg) { - if (Arg.getKind() != TemplateArgument::ArgKind::Expression || - Arg.isInstantiationDependent()) { - Arg.print(Policy, ParmListOstream, /* IncludeType = */ false); - return; - } - - Expr *E = Arg.getAsExpr(); - assert(E && "Failed to get an Expr for an Expression template arg?"); - if (E->getType().getTypePtr()->isScopedEnumeralType()) { - // Scoped enumerations can't be implicitly cast from integers, so - // we don't need to evaluate them. - Arg.print(Policy, ParmListOstream, /* IncludeType = */ false); - return; - } - - Expr::EvalResult Res; - [[maybe_unused]] bool Success = - Arg.getAsExpr()->EvaluateAsConstantExpr(Res, Context); - assert(Success && "invalid non-type template argument?"); - assert(!Res.Val.isAbsent() && "couldn't read the evaulation result?"); - Res.Val.printPretty(ParmListOstream, Policy, Arg.getAsExpr()->getType(), - &Context); - }; - - for (size_t I = 0, E = std::max(DeclArgs.size(), SpecArgs.size()), - SE = SpecArgs.size(); - I < E; ++I) { - if (I != 0) - ParmListOstream << ", "; - // If we have a specialized argument, use it. Otherwise fallback to a - // default argument. - TemplateArgPrinter(I < SE ? SpecArgs[I] : DeclArgs[I]); - } - - ParmListOstream << ">"; + Printer.Visit(Param); } return ParamList.str().str(); } @@ -6873,26 +7051,39 @@ class FreeFunctionPrinter { std::string getTemplateParameters(const clang::TemplateParameterList *TPL) { std::string TemplateParams{"template <"}; bool FirstParam{true}; - for (NamedDecl *Param : *TPL) { + for (const NamedDecl *Param : *TPL) { if (!FirstParam) TemplateParams += ", "; FirstParam = false; - if (const auto *TemplateParam = dyn_cast(Param)) { - TemplateParams += - TemplateParam->wasDeclaredWithTypename() ? "typename " : "class "; - if (TemplateParam->isParameterPack()) - TemplateParams += "... "; - TemplateParams += TemplateParam->getNameAsString(); - } else if (const auto *NonTypeParam = - dyn_cast(Param)) { - TemplateParams += NonTypeParam->getType().getAsString(); - TemplateParams += " "; - TemplateParams += NonTypeParam->getNameAsString(); - } + TemplateParams += getTemplateParameter(Param); } TemplateParams += "> "; return TemplateParams; } + + /// Helper method to get text representation of a template parameter. + /// \param Param The template parameter. + std::string getTemplateParameter(const NamedDecl *Param) { + auto GetTypenameOrClass = [](const auto *Param) { + return Param->wasDeclaredWithTypename() ? "typename " : "class "; + }; + if (const auto *TemplateParam = dyn_cast(Param)) { + std::string TemplateParamStr = GetTypenameOrClass(TemplateParam); + if (TemplateParam->isParameterPack()) + TemplateParamStr += "... "; + TemplateParamStr += TemplateParam->getNameAsString(); + return TemplateParamStr; + } else if (const auto *NonTypeParam = + dyn_cast(Param)) { + return NonTypeParam->getType().getAsString() + " " + + NonTypeParam->getNameAsString(); + } else if (const auto *TTParam = + dyn_cast(Param)) { + return getTemplateParameters(TTParam->getTemplateParameters()) + " " + + GetTypenameOrClass(TTParam) + TTParam->getNameAsString(); + } + return ""; + } }; void SYCLIntegrationHeader::emit(raw_ostream &O) { diff --git a/clang/test/CodeGenSYCL/free-function-kernel-pack-template-arg.cpp b/clang/test/CodeGenSYCL/free-function-kernel-pack-template-arg.cpp new file mode 100644 index 0000000000000..90faccbd8807a --- /dev/null +++ b/clang/test/CodeGenSYCL/free-function-kernel-pack-template-arg.cpp @@ -0,0 +1,14 @@ +// RUN: %clang_cc1 -fsycl-is-device -internal-isystem %S/Inputs -triple spir64-unknown-unknown -sycl-std=2020 -fsycl-int-header=%t.h %s +// RUN: FileCheck -input-file=%t.h %s +// +// The purpose of this test is to ensure that forward declarations of free +// function kernels are emitted properly. +// However, this test checks a specific scenario: +// - parameter packs are emitted correctly. + +template +[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]] +void parameter_pack(T... Args) {} +template void parameter_pack(int Arg1, float Arg2); + +// CHECK: template void parameter_pack(T...); diff --git a/clang/test/CodeGenSYCL/free-function-kernel-type-alias-arg.cpp b/clang/test/CodeGenSYCL/free-function-kernel-type-alias-arg.cpp index 5d6ea216d7d38..aa7db827743c3 100644 --- a/clang/test/CodeGenSYCL/free-function-kernel-type-alias-arg.cpp +++ b/clang/test/CodeGenSYCL/free-function-kernel-type-alias-arg.cpp @@ -14,6 +14,9 @@ typedef int IntTypedef; template struct Foo {}; +template +using FooUsing = Foo; + using FooIntUsing = Foo; typedef Foo FooIntTypedef; @@ -29,15 +32,35 @@ using BarUsing2 = Bar, T1>; template using BarUsingBarUsing2 = BarUsing2; +template +using BarUsingFooIntUsing = Bar; + +template +using BarUsingBarUsingFooIntUsing = BarUsingFooIntUsing; + class Baz { public: using type = BarUsing; }; +template