Skip to content
367 changes: 279 additions & 88 deletions clang/lib/Sema/SemaSYCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6613,6 +6613,254 @@ static void PrintNSClosingBraces(raw_ostream &OS, const DeclContext *DC) {
[](raw_ostream &, const NamespaceDecl *) {}, OS, DC);
}

/// Dedicated visitor which helps with printing of kernel arguments in forward
/// declarations of free function kernels which are declared as function
/// templates.
///
/// Based on:
/// \code
/// template <typename T1, typename T2>
/// void foo(T1 a, int b, T2 c);
/// \endcode
///
/// It prints into the output stream "T1, int, T2".
///
/// The main complexity (which motivates addition of such visitor) comes from
/// the fact that there could be type aliases and default template arguments.
/// For example:
/// \code
/// template<typename T>
/// void kernel(sycl::accessor<T, 1>);
/// template void kernel(sycl::accessor<int, 1>);
/// \endcode
/// sycl::accessor has many template arguments which have default values. If
/// we iterate over non-canonicalized argument type, we don't get those default
/// values and we don't get necessary namespace qualifiers for all the template
/// arguments. If we iterate over canonicalized argument type, then all
/// references to T will be replaced with something like type-argument-X-Y.
/// What this visitor does is it iterates over both in sync, picking the right
/// values from one or another.
///
/// The template argument visitor functions take an additional
/// ArrayRef<TemplateArgument> argument corresponding to the template arguments
/// of the outermost template. This is used by some of these functions for
/// mapping dependent template arguments.
///
/// Moral of the story: drop integration header ASAP (but that is blocked
/// by support for 3rd-party host compilers, which is important).
class FreeFunctionTemplateKernelArgsPrinter
: public ConstTemplateArgumentVisitor<FreeFunctionTemplateKernelArgsPrinter,
void, ArrayRef<TemplateArgument>> {
raw_ostream &O;
PrintingPolicy &Policy;
ASTContext &Context;

using Base =
ConstTemplateArgumentVisitor<FreeFunctionTemplateKernelArgsPrinter, void,
ArrayRef<TemplateArgument>>;

// Desugars a template argument. This helps avoid aliases.
static TemplateArgument DesugarTemplateArgument(const TemplateArgument &Arg) {
switch (Arg.getKind()) {
case TemplateArgument::ArgKind::Type: {
QualType ArgTy = Arg.getAsType();
return {QualType(ArgTy->getUnqualifiedDesugaredType(),
ArgTy.getCVRQualifiers())};
}
case TemplateArgument::ArgKind::Template: {
TemplateName TN = Arg.getAsTemplate();
while (std::optional<TemplateName> DesugaredTN =
TN.desugar(/*IgnoreDeduced=*/false))
TN = *DesugaredTN;
return {TN};
}
default:
return Arg;
}
}

void PrintDesugared(const TemplateArgument &Arg) {
DesugarTemplateArgument(Arg).print(Policy, O, /*IncludeType=*/false);
}

public:
FreeFunctionTemplateKernelArgsPrinter(raw_ostream &O, PrintingPolicy &Policy,
ASTContext &Context)
: O(O), Policy(Policy), Context(Context) {}

void Visit(const ParmVarDecl *Param) {
// There are cases when we can't directly use neither the original
// argument type, nor its canonical version. An example would be:
// template<typename T>
// void kernel(sycl::accessor<T, 1>);
// template void kernel(sycl::accessor<int, 1>);
// Accessor has multiple non-type template arguments with default values
// and non-qualified type will not include necessary namespaces for all
// of them. Qualified type will have that information, but all references
// to T will be replaced to something like type-argument-0
// What we do instead is we iterate template arguments of both versions
// of a type in sync and take elements from one or another to get the best
// of both: proper references to template arguments of a kernel itself and
// fully-qualified names for enumerations.
//
// Moral of the story: drop integration header ASAP (but that is blocked
// by support for 3rd-party host compilers, which is important).
QualType T = Param->getType();
QualType CT = T.getCanonicalType();

const auto *TST = dyn_cast<TemplateSpecializationType>(T.getTypePtr());
const auto *CTST = dyn_cast<TemplateSpecializationType>(CT.getTypePtr());
if (!TST || !CTST) {
O << T.getDesugaredType(Context).getAsString(Policy);
return;
}

const TemplateSpecializationType *TSTAsNonAlias =
TST->getAsNonAliasTemplateSpecializationType();
if (TSTAsNonAlias)
TST = TSTAsNonAlias;

ArrayRef<TemplateArgument> SpecArgs = TST->template_arguments();
ArrayRef<TemplateArgument> DeclArgs = CTST->template_arguments();

const TemplateDecl *TD = CTST->getTemplateName().getAsTemplateDecl();
if (!TD->getIdentifier())
TD = TST->getTemplateName().getAsTemplateDecl();
assert(TD->getIdentifier() &&
"Either the type or the canonical type should have an identifier.");
TD->printQualifiedName(O);

O << "<";
for (size_t I = 0, E = std::max(DeclArgs.size(), SpecArgs.size()),
SE = SpecArgs.size();
I < E; ++I) {
if (I != 0)
O << ", ";
// If we have a specialized argument, use it. Otherwise fallback to a
// default argument.
// We pass specialized arguments in case there are references to them
// from other types.
// FIXME: passing SpecArgs here is incorrect. It refers to template
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That FIXME does seem concerning.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, though the current version of this still addresses a chunk of the current issues we are seeing with the prototype generation. I can try to add some more disabled cases to the test so we know what to fix.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After having a look at this, I cannot seem to find a case that allows template arguments of differing depth. Maybe @AlexeySachkov knows of one?

// arguments of a single function argument, but DeclArgs contain
// references (in form of depth-index) to template arguments of the
// function itself which results in incorrect integration header being
// produced.
Base::Visit(I < SE ? SpecArgs[I] : DeclArgs[I], SpecArgs);
}
O << ">";
}

// Internal version of the function above that is used when template argument
// is a template by itself
void Visit(const TemplateSpecializationType *T,
ArrayRef<TemplateArgument> SpecArgs) {
const TemplateDecl *TD = T->getTemplateName().getAsTemplateDecl();
const auto *TTPD = dyn_cast<TemplateTemplateParmDecl>(TD);
if (TTPD && !TTPD->getIdentifier())
PrintDesugared(SpecArgs[TTPD->getIndex()]);
else
TD->printQualifiedName(O);
O << "<";
ArrayRef<const TemplateArgument> DeclArgs = T->template_arguments();
for (size_t I = 0, E = DeclArgs.size(); I < E; ++I) {
if (I != 0)
O << ", ";
Base::Visit(DeclArgs[I], SpecArgs);
}
O << ">";
}

void VisitNullTemplateArgument(const TemplateArgument &,
ArrayRef<TemplateArgument>) {
llvm_unreachable("If template argument has not been deduced, then we can't "
"forward-declare it, something went wrong");
}

void VisitTypeTemplateArgument(const TemplateArgument &Arg,
ArrayRef<TemplateArgument> SpecArgs) {
TemplateArgument DesugaredArg = DesugarTemplateArgument(Arg);
// If we reference an existing template argument without a known identifier,
// print it instead.
const auto *TPT = dyn_cast<TemplateTypeParmType>(DesugaredArg.getAsType());
if (TPT && !TPT->getIdentifier()) {
PrintDesugared(SpecArgs[TPT->getIndex()]);
return;
}

const auto *TST =
dyn_cast<TemplateSpecializationType>(DesugaredArg.getAsType());
if (TST && Arg.isInstantiationDependent()) {
// This is an instantiation dependent template specialization, meaning
// that some of its arguments reference template arguments of the free
// function kernel itself.
Visit(TST, SpecArgs);
return;
}

DesugaredArg.print(Policy, O, /* IncludeType = */ false);
}

void VisitDeclarationTemplateArgument(const TemplateArgument &,
ArrayRef<TemplateArgument>) {
llvm_unreachable("Free function kernels cannot have non-type template "
"arguments which are pointers or references");
}

void VisitNullPtrTemplateArgument(const TemplateArgument &,
ArrayRef<TemplateArgument>) {
llvm_unreachable("Free function kernels cannot have non-type template "
"arguments which are pointers or references");
}

void VisitIntegralTemplateArgument(const TemplateArgument &Arg,
ArrayRef<TemplateArgument>) {
PrintDesugared(Arg);
}

void VisitStructuralValueTemplateArgument(const TemplateArgument &Arg,
ArrayRef<TemplateArgument>) {
PrintDesugared(Arg);
}

void VisitTemplateTemplateArgument(const TemplateArgument &Arg,
ArrayRef<TemplateArgument>) {
PrintDesugared(Arg);
}

void VisitTemplateExpansionTemplateArgument(const TemplateArgument &Arg,
ArrayRef<TemplateArgument>) {
PrintDesugared(Arg);
}

void VisitExpressionTemplateArgument(const TemplateArgument &Arg,
ArrayRef<TemplateArgument>) {
Expr *E = Arg.getAsExpr();
assert(E && "Failed to get an Expr for an Expression template arg?");

if (Arg.isInstantiationDependent() ||
E->getType()->isScopedEnumeralType()) {
// Scoped enumerations can't be implicitly cast from integers, so
// we don't need to evaluate them.
// If expression is instantiation-dependent, then we can't evaluate it
// either, let's fallback to default printing mechanism.
PrintDesugared(Arg);
return;
}

Expr::EvalResult Res;
[[maybe_unused]] bool Success =
Arg.getAsExpr()->EvaluateAsConstantExpr(Res, Context);
assert(Success && "invalid non-type template argument?");
assert(!Res.Val.isAbsent() && "couldn't read the evaulation result?");
Res.Val.printPretty(O, Policy, Arg.getAsExpr()->getType(), &Context);
}

void VisitPackTemplateArgument(const TemplateArgument &Arg,
ArrayRef<TemplateArgument>) {
PrintDesugared(Arg);
}
};

class FreeFunctionPrinter {
raw_ostream &O;
PrintingPolicy &Policy;
Expand Down Expand Up @@ -6776,86 +7024,16 @@ class FreeFunctionPrinter {
llvm::raw_svector_ostream ParmListOstream{ParamList};
Policy.SuppressTagKeyword = true;

for (ParmVarDecl *Param : Parameters) {
FreeFunctionTemplateKernelArgsPrinter Printer(ParmListOstream, Policy,
Context);

for (const ParmVarDecl *Param : Parameters) {
if (FirstParam)
FirstParam = false;
else
ParmListOstream << ", ";

// There are cases when we can't directly use neither the original
// argument type, nor its canonical version. An example would be:
// template<typename T>
// void kernel(sycl::accessor<T, 1>);
// template void kernel(sycl::accessor<int, 1>);
// Accessor has multiple non-type template arguments with default values
// and non-qualified type will not include necessary namespaces for all
// of them. Qualified type will have that information, but all references
// to T will be replaced to something like type-argument-0
// What we do instead is we iterate template arguments of both versions
// of a type in sync and take elements from one or another to get the best
// of both: proper references to template arguments of a kernel itself and
// fully-qualified names for enumerations.
//
// Moral of the story: drop integration header ASAP (but that is blocked
// by support for 3rd-party host compilers, which is important).
QualType T = Param->getType();
QualType CT = T.getCanonicalType();

const auto *TST = dyn_cast<TemplateSpecializationType>(T.getTypePtr());
const auto *CTST = dyn_cast<TemplateSpecializationType>(CT.getTypePtr());
if (!TST || !CTST) {
ParmListOstream << T.getAsString(Policy);
continue;
}

const TemplateSpecializationType *TSTAsNonAlias =
TST->getAsNonAliasTemplateSpecializationType();
if (TSTAsNonAlias)
TST = TSTAsNonAlias;

TemplateName CTN = CTST->getTemplateName();
CTN.getAsTemplateDecl()->printQualifiedName(ParmListOstream);
ParmListOstream << "<";

ArrayRef<TemplateArgument> SpecArgs = TST->template_arguments();
ArrayRef<TemplateArgument> DeclArgs = CTST->template_arguments();

auto TemplateArgPrinter = [&](const TemplateArgument &Arg) {
if (Arg.getKind() != TemplateArgument::ArgKind::Expression ||
Arg.isInstantiationDependent()) {
Arg.print(Policy, ParmListOstream, /* IncludeType = */ false);
return;
}

Expr *E = Arg.getAsExpr();
assert(E && "Failed to get an Expr for an Expression template arg?");
if (E->getType().getTypePtr()->isScopedEnumeralType()) {
// Scoped enumerations can't be implicitly cast from integers, so
// we don't need to evaluate them.
Arg.print(Policy, ParmListOstream, /* IncludeType = */ false);
return;
}

Expr::EvalResult Res;
[[maybe_unused]] bool Success =
Arg.getAsExpr()->EvaluateAsConstantExpr(Res, Context);
assert(Success && "invalid non-type template argument?");
assert(!Res.Val.isAbsent() && "couldn't read the evaulation result?");
Res.Val.printPretty(ParmListOstream, Policy, Arg.getAsExpr()->getType(),
&Context);
};

for (size_t I = 0, E = std::max(DeclArgs.size(), SpecArgs.size()),
SE = SpecArgs.size();
I < E; ++I) {
if (I != 0)
ParmListOstream << ", ";
// If we have a specialized argument, use it. Otherwise fallback to a
// default argument.
TemplateArgPrinter(I < SE ? SpecArgs[I] : DeclArgs[I]);
}

ParmListOstream << ">";
Printer.Visit(Param);
}
return ParamList.str().str();
}
Expand All @@ -6873,26 +7051,39 @@ class FreeFunctionPrinter {
std::string getTemplateParameters(const clang::TemplateParameterList *TPL) {
std::string TemplateParams{"template <"};
bool FirstParam{true};
for (NamedDecl *Param : *TPL) {
for (const NamedDecl *Param : *TPL) {
if (!FirstParam)
TemplateParams += ", ";
FirstParam = false;
if (const auto *TemplateParam = dyn_cast<TemplateTypeParmDecl>(Param)) {
TemplateParams +=
TemplateParam->wasDeclaredWithTypename() ? "typename " : "class ";
if (TemplateParam->isParameterPack())
TemplateParams += "... ";
TemplateParams += TemplateParam->getNameAsString();
} else if (const auto *NonTypeParam =
dyn_cast<NonTypeTemplateParmDecl>(Param)) {
TemplateParams += NonTypeParam->getType().getAsString();
TemplateParams += " ";
TemplateParams += NonTypeParam->getNameAsString();
}
TemplateParams += getTemplateParameter(Param);
}
TemplateParams += "> ";
return TemplateParams;
}

/// Helper method to get text representation of a template parameter.
/// \param Param The template parameter.
std::string getTemplateParameter(const NamedDecl *Param) {
auto GetTypenameOrClass = [](const auto *Param) {
return Param->wasDeclaredWithTypename() ? "typename " : "class ";
};
if (const auto *TemplateParam = dyn_cast<TemplateTypeParmDecl>(Param)) {
std::string TemplateParamStr = GetTypenameOrClass(TemplateParam);
if (TemplateParam->isParameterPack())
TemplateParamStr += "... ";
TemplateParamStr += TemplateParam->getNameAsString();
return TemplateParamStr;
} else if (const auto *NonTypeParam =
dyn_cast<NonTypeTemplateParmDecl>(Param)) {
return NonTypeParam->getType().getAsString() + " " +
NonTypeParam->getNameAsString();
} else if (const auto *TTParam =
dyn_cast<TemplateTemplateParmDecl>(Param)) {
return getTemplateParameters(TTParam->getTemplateParameters()) + " " +
GetTypenameOrClass(TTParam) + TTParam->getNameAsString();
}
return "";
}
};

void SYCLIntegrationHeader::emit(raw_ostream &O) {
Expand Down
Loading
Loading