-
Notifications
You must be signed in to change notification settings - Fork 801
[SYCL][clang] Fix more free-function kernel integration header cases #20877
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
549ca8f
bea8af0
14b14d5
3c7f972
abbb457
9bdcbac
b4a108d
07ff294
ae6028e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6613,6 +6613,254 @@ static void PrintNSClosingBraces(raw_ostream &OS, const DeclContext *DC) { | |
| [](raw_ostream &, const NamespaceDecl *) {}, OS, DC); | ||
| } | ||
|
|
||
| /// Dedicated visitor which helps with printing of kernel arguments in forward | ||
| /// declarations of free function kernels which are declared as function | ||
| /// templates. | ||
| /// | ||
| /// Based on: | ||
| /// \code | ||
| /// template <typename T1, typename T2> | ||
| /// void foo(T1 a, int b, T2 c); | ||
| /// \endcode | ||
| /// | ||
| /// It prints into the output stream "T1, int, T2". | ||
| /// | ||
| /// The main complexity (which motivates addition of such visitor) comes from | ||
| /// the fact that there could be type aliases and default template arguments. | ||
| /// For example: | ||
| /// \code | ||
| /// template<typename T> | ||
| /// void kernel(sycl::accessor<T, 1>); | ||
| /// template void kernel(sycl::accessor<int, 1>); | ||
| /// \endcode | ||
| /// sycl::accessor has many template arguments which have default values. If | ||
| /// we iterate over non-canonicalized argument type, we don't get those default | ||
| /// values and we don't get necessary namespace qualifiers for all the template | ||
| /// arguments. If we iterate over canonicalized argument type, then all | ||
| /// references to T will be replaced with something like type-argument-X-Y. | ||
| /// What this visitor does is it iterates over both in sync, picking the right | ||
| /// values from one or another. | ||
| /// | ||
| /// The template argument visitor functions take an additional | ||
| /// ArrayRef<TemplateArgument> argument corresponding to the template arguments | ||
| /// of the outermost template. This is used by some of these functions for | ||
| /// mapping dependent template arguments. | ||
| /// | ||
| /// Moral of the story: drop integration header ASAP (but that is blocked | ||
| /// by support for 3rd-party host compilers, which is important). | ||
| class FreeFunctionTemplateKernelArgsPrinter | ||
| : public ConstTemplateArgumentVisitor<FreeFunctionTemplateKernelArgsPrinter, | ||
| void, ArrayRef<TemplateArgument>> { | ||
| raw_ostream &O; | ||
| PrintingPolicy &Policy; | ||
| ASTContext &Context; | ||
|
|
||
| using Base = | ||
| ConstTemplateArgumentVisitor<FreeFunctionTemplateKernelArgsPrinter, void, | ||
| ArrayRef<TemplateArgument>>; | ||
|
|
||
| // Desugars a template argument. This helps avoid aliases. | ||
| static TemplateArgument DesugarTemplateArgument(const TemplateArgument &Arg) { | ||
| switch (Arg.getKind()) { | ||
| case TemplateArgument::ArgKind::Type: { | ||
| QualType ArgTy = Arg.getAsType(); | ||
| return {QualType(ArgTy->getUnqualifiedDesugaredType(), | ||
| ArgTy.getCVRQualifiers())}; | ||
| } | ||
| case TemplateArgument::ArgKind::Template: { | ||
| TemplateName TN = Arg.getAsTemplate(); | ||
| while (std::optional<TemplateName> DesugaredTN = | ||
| TN.desugar(/*IgnoreDeduced=*/false)) | ||
| TN = *DesugaredTN; | ||
| return {TN}; | ||
| } | ||
| default: | ||
| return Arg; | ||
| } | ||
| } | ||
|
|
||
| void PrintDesugared(const TemplateArgument &Arg) { | ||
| DesugarTemplateArgument(Arg).print(Policy, O, /*IncludeType=*/false); | ||
| } | ||
|
|
||
| public: | ||
| FreeFunctionTemplateKernelArgsPrinter(raw_ostream &O, PrintingPolicy &Policy, | ||
| ASTContext &Context) | ||
| : O(O), Policy(Policy), Context(Context) {} | ||
|
|
||
| void Visit(const ParmVarDecl *Param) { | ||
| // There are cases when we can't directly use neither the original | ||
| // argument type, nor its canonical version. An example would be: | ||
| // template<typename T> | ||
| // void kernel(sycl::accessor<T, 1>); | ||
| // template void kernel(sycl::accessor<int, 1>); | ||
| // Accessor has multiple non-type template arguments with default values | ||
| // and non-qualified type will not include necessary namespaces for all | ||
| // of them. Qualified type will have that information, but all references | ||
| // to T will be replaced to something like type-argument-0 | ||
| // What we do instead is we iterate template arguments of both versions | ||
| // of a type in sync and take elements from one or another to get the best | ||
| // of both: proper references to template arguments of a kernel itself and | ||
| // fully-qualified names for enumerations. | ||
| // | ||
| // Moral of the story: drop integration header ASAP (but that is blocked | ||
| // by support for 3rd-party host compilers, which is important). | ||
| QualType T = Param->getType(); | ||
| QualType CT = T.getCanonicalType(); | ||
|
|
||
| const auto *TST = dyn_cast<TemplateSpecializationType>(T.getTypePtr()); | ||
| const auto *CTST = dyn_cast<TemplateSpecializationType>(CT.getTypePtr()); | ||
| if (!TST || !CTST) { | ||
| O << T.getDesugaredType(Context).getAsString(Policy); | ||
| return; | ||
| } | ||
|
|
||
| const TemplateSpecializationType *TSTAsNonAlias = | ||
| TST->getAsNonAliasTemplateSpecializationType(); | ||
| if (TSTAsNonAlias) | ||
| TST = TSTAsNonAlias; | ||
|
|
||
| ArrayRef<TemplateArgument> SpecArgs = TST->template_arguments(); | ||
| ArrayRef<TemplateArgument> DeclArgs = CTST->template_arguments(); | ||
|
|
||
| const TemplateDecl *TD = CTST->getTemplateName().getAsTemplateDecl(); | ||
| if (!TD->getIdentifier()) | ||
| TD = TST->getTemplateName().getAsTemplateDecl(); | ||
| assert(TD->getIdentifier() && | ||
| "Either the type or the canonical type should have an identifier."); | ||
| TD->printQualifiedName(O); | ||
|
|
||
| O << "<"; | ||
| for (size_t I = 0, E = std::max(DeclArgs.size(), SpecArgs.size()), | ||
| SE = SpecArgs.size(); | ||
| I < E; ++I) { | ||
| if (I != 0) | ||
| O << ", "; | ||
| // If we have a specialized argument, use it. Otherwise fallback to a | ||
| // default argument. | ||
| // We pass specialized arguments in case there are references to them | ||
| // from other types. | ||
| // FIXME: passing SpecArgs here is incorrect. It refers to template | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That FIXME does seem concerning.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree, though the current version of this still addresses a chunk of the current issues we are seeing with the prototype generation. I can try to add some more disabled cases to the test so we know what to fix.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. After having a look at this, I cannot seem to find a case that allows template arguments of differing depth. Maybe @AlexeySachkov knows of one? |
||
| // arguments of a single function argument, but DeclArgs contain | ||
| // references (in form of depth-index) to template arguments of the | ||
| // function itself which results in incorrect integration header being | ||
| // produced. | ||
| Base::Visit(I < SE ? SpecArgs[I] : DeclArgs[I], SpecArgs); | ||
| } | ||
| O << ">"; | ||
| } | ||
|
|
||
| // Internal version of the function above that is used when template argument | ||
| // is a template by itself | ||
| void Visit(const TemplateSpecializationType *T, | ||
| ArrayRef<TemplateArgument> SpecArgs) { | ||
| const TemplateDecl *TD = T->getTemplateName().getAsTemplateDecl(); | ||
| const auto *TTPD = dyn_cast<TemplateTemplateParmDecl>(TD); | ||
| if (TTPD && !TTPD->getIdentifier()) | ||
| PrintDesugared(SpecArgs[TTPD->getIndex()]); | ||
| else | ||
| TD->printQualifiedName(O); | ||
| O << "<"; | ||
| ArrayRef<const TemplateArgument> DeclArgs = T->template_arguments(); | ||
| for (size_t I = 0, E = DeclArgs.size(); I < E; ++I) { | ||
| if (I != 0) | ||
| O << ", "; | ||
| Base::Visit(DeclArgs[I], SpecArgs); | ||
| } | ||
| O << ">"; | ||
| } | ||
|
|
||
| void VisitNullTemplateArgument(const TemplateArgument &, | ||
| ArrayRef<TemplateArgument>) { | ||
| llvm_unreachable("If template argument has not been deduced, then we can't " | ||
| "forward-declare it, something went wrong"); | ||
| } | ||
|
|
||
| void VisitTypeTemplateArgument(const TemplateArgument &Arg, | ||
| ArrayRef<TemplateArgument> SpecArgs) { | ||
| TemplateArgument DesugaredArg = DesugarTemplateArgument(Arg); | ||
| // If we reference an existing template argument without a known identifier, | ||
| // print it instead. | ||
| const auto *TPT = dyn_cast<TemplateTypeParmType>(DesugaredArg.getAsType()); | ||
| if (TPT && !TPT->getIdentifier()) { | ||
| PrintDesugared(SpecArgs[TPT->getIndex()]); | ||
| return; | ||
| } | ||
|
|
||
| const auto *TST = | ||
| dyn_cast<TemplateSpecializationType>(DesugaredArg.getAsType()); | ||
| if (TST && Arg.isInstantiationDependent()) { | ||
| // This is an instantiation dependent template specialization, meaning | ||
| // that some of its arguments reference template arguments of the free | ||
| // function kernel itself. | ||
| Visit(TST, SpecArgs); | ||
| return; | ||
| } | ||
|
|
||
| DesugaredArg.print(Policy, O, /* IncludeType = */ false); | ||
| } | ||
|
|
||
| void VisitDeclarationTemplateArgument(const TemplateArgument &, | ||
| ArrayRef<TemplateArgument>) { | ||
| llvm_unreachable("Free function kernels cannot have non-type template " | ||
Fznamznon marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| "arguments which are pointers or references"); | ||
| } | ||
|
|
||
| void VisitNullPtrTemplateArgument(const TemplateArgument &, | ||
| ArrayRef<TemplateArgument>) { | ||
| llvm_unreachable("Free function kernels cannot have non-type template " | ||
| "arguments which are pointers or references"); | ||
| } | ||
|
|
||
| void VisitIntegralTemplateArgument(const TemplateArgument &Arg, | ||
| ArrayRef<TemplateArgument>) { | ||
| PrintDesugared(Arg); | ||
| } | ||
|
|
||
| void VisitStructuralValueTemplateArgument(const TemplateArgument &Arg, | ||
| ArrayRef<TemplateArgument>) { | ||
| PrintDesugared(Arg); | ||
| } | ||
|
|
||
| void VisitTemplateTemplateArgument(const TemplateArgument &Arg, | ||
| ArrayRef<TemplateArgument>) { | ||
| PrintDesugared(Arg); | ||
| } | ||
|
|
||
| void VisitTemplateExpansionTemplateArgument(const TemplateArgument &Arg, | ||
| ArrayRef<TemplateArgument>) { | ||
| PrintDesugared(Arg); | ||
| } | ||
|
|
||
| void VisitExpressionTemplateArgument(const TemplateArgument &Arg, | ||
| ArrayRef<TemplateArgument>) { | ||
Fznamznon marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| Expr *E = Arg.getAsExpr(); | ||
| assert(E && "Failed to get an Expr for an Expression template arg?"); | ||
|
|
||
| if (Arg.isInstantiationDependent() || | ||
| E->getType()->isScopedEnumeralType()) { | ||
| // Scoped enumerations can't be implicitly cast from integers, so | ||
| // we don't need to evaluate them. | ||
| // If expression is instantiation-dependent, then we can't evaluate it | ||
| // either, let's fallback to default printing mechanism. | ||
| PrintDesugared(Arg); | ||
| return; | ||
| } | ||
|
|
||
| Expr::EvalResult Res; | ||
| [[maybe_unused]] bool Success = | ||
| Arg.getAsExpr()->EvaluateAsConstantExpr(Res, Context); | ||
| assert(Success && "invalid non-type template argument?"); | ||
| assert(!Res.Val.isAbsent() && "couldn't read the evaulation result?"); | ||
| Res.Val.printPretty(O, Policy, Arg.getAsExpr()->getType(), &Context); | ||
| } | ||
|
|
||
| void VisitPackTemplateArgument(const TemplateArgument &Arg, | ||
| ArrayRef<TemplateArgument>) { | ||
| PrintDesugared(Arg); | ||
| } | ||
| }; | ||
|
|
||
| class FreeFunctionPrinter { | ||
| raw_ostream &O; | ||
| PrintingPolicy &Policy; | ||
|
|
@@ -6776,86 +7024,16 @@ class FreeFunctionPrinter { | |
| llvm::raw_svector_ostream ParmListOstream{ParamList}; | ||
| Policy.SuppressTagKeyword = true; | ||
|
|
||
| for (ParmVarDecl *Param : Parameters) { | ||
| FreeFunctionTemplateKernelArgsPrinter Printer(ParmListOstream, Policy, | ||
| Context); | ||
|
|
||
| for (const ParmVarDecl *Param : Parameters) { | ||
| if (FirstParam) | ||
| FirstParam = false; | ||
| else | ||
| ParmListOstream << ", "; | ||
|
|
||
| // There are cases when we can't directly use neither the original | ||
| // argument type, nor its canonical version. An example would be: | ||
| // template<typename T> | ||
| // void kernel(sycl::accessor<T, 1>); | ||
| // template void kernel(sycl::accessor<int, 1>); | ||
| // Accessor has multiple non-type template arguments with default values | ||
| // and non-qualified type will not include necessary namespaces for all | ||
| // of them. Qualified type will have that information, but all references | ||
| // to T will be replaced to something like type-argument-0 | ||
| // What we do instead is we iterate template arguments of both versions | ||
| // of a type in sync and take elements from one or another to get the best | ||
| // of both: proper references to template arguments of a kernel itself and | ||
| // fully-qualified names for enumerations. | ||
| // | ||
| // Moral of the story: drop integration header ASAP (but that is blocked | ||
| // by support for 3rd-party host compilers, which is important). | ||
| QualType T = Param->getType(); | ||
| QualType CT = T.getCanonicalType(); | ||
|
|
||
| const auto *TST = dyn_cast<TemplateSpecializationType>(T.getTypePtr()); | ||
| const auto *CTST = dyn_cast<TemplateSpecializationType>(CT.getTypePtr()); | ||
| if (!TST || !CTST) { | ||
| ParmListOstream << T.getAsString(Policy); | ||
| continue; | ||
| } | ||
|
|
||
| const TemplateSpecializationType *TSTAsNonAlias = | ||
| TST->getAsNonAliasTemplateSpecializationType(); | ||
| if (TSTAsNonAlias) | ||
| TST = TSTAsNonAlias; | ||
|
|
||
| TemplateName CTN = CTST->getTemplateName(); | ||
| CTN.getAsTemplateDecl()->printQualifiedName(ParmListOstream); | ||
| ParmListOstream << "<"; | ||
|
|
||
| ArrayRef<TemplateArgument> SpecArgs = TST->template_arguments(); | ||
| ArrayRef<TemplateArgument> DeclArgs = CTST->template_arguments(); | ||
|
|
||
| auto TemplateArgPrinter = [&](const TemplateArgument &Arg) { | ||
| if (Arg.getKind() != TemplateArgument::ArgKind::Expression || | ||
| Arg.isInstantiationDependent()) { | ||
| Arg.print(Policy, ParmListOstream, /* IncludeType = */ false); | ||
| return; | ||
| } | ||
|
|
||
| Expr *E = Arg.getAsExpr(); | ||
| assert(E && "Failed to get an Expr for an Expression template arg?"); | ||
| if (E->getType().getTypePtr()->isScopedEnumeralType()) { | ||
| // Scoped enumerations can't be implicitly cast from integers, so | ||
| // we don't need to evaluate them. | ||
| Arg.print(Policy, ParmListOstream, /* IncludeType = */ false); | ||
| return; | ||
| } | ||
|
|
||
| Expr::EvalResult Res; | ||
| [[maybe_unused]] bool Success = | ||
| Arg.getAsExpr()->EvaluateAsConstantExpr(Res, Context); | ||
| assert(Success && "invalid non-type template argument?"); | ||
| assert(!Res.Val.isAbsent() && "couldn't read the evaulation result?"); | ||
| Res.Val.printPretty(ParmListOstream, Policy, Arg.getAsExpr()->getType(), | ||
| &Context); | ||
| }; | ||
|
|
||
| for (size_t I = 0, E = std::max(DeclArgs.size(), SpecArgs.size()), | ||
| SE = SpecArgs.size(); | ||
| I < E; ++I) { | ||
| if (I != 0) | ||
| ParmListOstream << ", "; | ||
| // If we have a specialized argument, use it. Otherwise fallback to a | ||
| // default argument. | ||
| TemplateArgPrinter(I < SE ? SpecArgs[I] : DeclArgs[I]); | ||
| } | ||
|
|
||
| ParmListOstream << ">"; | ||
| Printer.Visit(Param); | ||
| } | ||
| return ParamList.str().str(); | ||
| } | ||
|
|
@@ -6873,26 +7051,39 @@ class FreeFunctionPrinter { | |
| std::string getTemplateParameters(const clang::TemplateParameterList *TPL) { | ||
| std::string TemplateParams{"template <"}; | ||
| bool FirstParam{true}; | ||
| for (NamedDecl *Param : *TPL) { | ||
| for (const NamedDecl *Param : *TPL) { | ||
| if (!FirstParam) | ||
| TemplateParams += ", "; | ||
| FirstParam = false; | ||
| if (const auto *TemplateParam = dyn_cast<TemplateTypeParmDecl>(Param)) { | ||
| TemplateParams += | ||
| TemplateParam->wasDeclaredWithTypename() ? "typename " : "class "; | ||
| if (TemplateParam->isParameterPack()) | ||
| TemplateParams += "... "; | ||
| TemplateParams += TemplateParam->getNameAsString(); | ||
| } else if (const auto *NonTypeParam = | ||
| dyn_cast<NonTypeTemplateParmDecl>(Param)) { | ||
| TemplateParams += NonTypeParam->getType().getAsString(); | ||
| TemplateParams += " "; | ||
| TemplateParams += NonTypeParam->getNameAsString(); | ||
| } | ||
| TemplateParams += getTemplateParameter(Param); | ||
| } | ||
| TemplateParams += "> "; | ||
| return TemplateParams; | ||
| } | ||
|
|
||
| /// Helper method to get text representation of a template parameter. | ||
| /// \param Param The template parameter. | ||
| std::string getTemplateParameter(const NamedDecl *Param) { | ||
| auto GetTypenameOrClass = [](const auto *Param) { | ||
| return Param->wasDeclaredWithTypename() ? "typename " : "class "; | ||
| }; | ||
| if (const auto *TemplateParam = dyn_cast<TemplateTypeParmDecl>(Param)) { | ||
| std::string TemplateParamStr = GetTypenameOrClass(TemplateParam); | ||
| if (TemplateParam->isParameterPack()) | ||
| TemplateParamStr += "... "; | ||
| TemplateParamStr += TemplateParam->getNameAsString(); | ||
| return TemplateParamStr; | ||
| } else if (const auto *NonTypeParam = | ||
| dyn_cast<NonTypeTemplateParmDecl>(Param)) { | ||
| return NonTypeParam->getType().getAsString() + " " + | ||
| NonTypeParam->getNameAsString(); | ||
| } else if (const auto *TTParam = | ||
| dyn_cast<TemplateTemplateParmDecl>(Param)) { | ||
| return getTemplateParameters(TTParam->getTemplateParameters()) + " " + | ||
| GetTypenameOrClass(TTParam) + TTParam->getNameAsString(); | ||
| } | ||
| return ""; | ||
| } | ||
| }; | ||
|
|
||
| void SYCLIntegrationHeader::emit(raw_ostream &O) { | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.