diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index a1d5d65299e7f..f63f5dc31976c 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -4050,6 +4050,12 @@
     },
     "sqlState" : "42823"
   },
+  "INVALID_TEMP_OBJ_QUALIFIER" : {
+    "message" : [
+      "Temporary <objName> cannot be qualified with <qualifier>. Temporary objects can only be qualified with SESSION or SYSTEM.SESSION."
+    ],
+    "sqlState" : "42602"
+  },
   "INVALID_TEMP_OBJ_REFERENCE" : {
     "message" : [
       "Cannot create the persistent object <objName> of the type <obj> because it references to the temporary object <tempObjName> of the type <tempObj>. Please make the temporary object <tempObjName> persistent, or make the persistent object <objName> temporary."
     ],
@@ -4120,6 +4126,12 @@
     ],
     "sqlState" : "42000"
   },
+  "INVALID_USAGE_OF_STAR_WITH_TABLE_IDENTIFIER_IN_COUNT" : {
+    "message" : [
+      "count(<targetString>.*) is not allowed. Use count(*) or expand the columns manually, e.g. count(col1, col2)."
+    ],
+    "sqlState" : "42000"
+  },
   "INVALID_UTF8_STRING" : {
     "message" : [
       "Invalid UTF8 byte sequence found in string: <str>."
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 6e899e958f157..91079fbacb5e1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1829,23 +1829,33 @@ class Analyzer(
     case _ => false
   }
 
+  /**
+   * Checks whether the given function name parts refer to the expected unqualified function
+   * name, regardless of qualification.
+   * Handles: "count", "builtin.count", "system.builtin.count", "session.count", etc.
+   */
+  private def matchesFunctionName(nameParts: Seq[String], expectedName: String): Boolean = {
+    nameParts.lastOption.exists(_.equalsIgnoreCase(expectedName))
+  }
+
   /**
    * Expands the matching attribute.*'s in `child`'s output.
    */
   def expandStarExpression(expr: Expression, child: LogicalPlan): Expression = {
     expr.transformUp {
       case f0: UnresolvedFunction if !f0.isDistinct &&
-        f0.nameParts.map(_.toLowerCase(Locale.ROOT)) == Seq("count") &&
+        matchesFunctionName(f0.nameParts, "count") &&
         isCountStarExpansionAllowed(f0.arguments) =>
         // Transform COUNT(*) into COUNT(1).
-        f0.copy(nameParts = Seq("count"), arguments = Seq(Literal(1)))
+        // Preserve the original qualification (if any) in the transformed node.
+        f0.copy(arguments = Seq(Literal(1)))
       case f1: UnresolvedFunction if containsStar(f1.arguments) =>
         // SPECIAL CASE: We want to block count(tblName.*) because in spark, count(tblName.*) will
         // be expanded while count(*) will be converted to count(1). They will produce different
         // results and confuse users if there are any null values. For count(t1.*, t2.*), it is
         // still allowed, since it's well-defined in spark.
         if (!conf.allowStarWithSingleTableIdentifierInCount &&
-          f1.nameParts == Seq("count") &&
+          matchesFunctionName(f1.nameParts, "count") &&
           f1.arguments.length == 1) {
           f1.arguments.foreach {
             case u: UnresolvedStar if u.isQualifiedByTable(child.output, resolver) =>
@@ -1993,8 +2003,13 @@
    * only performs simple existence check according to the function identifier to quickly identify
    * undefined functions without triggering relation resolution, which may incur potentially
    * expensive partition/schema discovery process in some cases.
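Worth spelling out what the relaxed matcher accepts: it only inspects the last name part, so at this stage any qualification passes, and the stricter builtin-only gate lives in the single-pass resolver later in this diff. A self-contained sketch in plain Scala, mirroring the helper above rather than the actual Analyzer entry point:

```scala
// Mirrors matchesFunctionName from this hunk; case-insensitive match on the last part only.
def matchesFunctionName(nameParts: Seq[String], expectedName: String): Boolean =
  nameParts.lastOption.exists(_.equalsIgnoreCase(expectedName))

assert(matchesFunctionName(Seq("count"), "count"))                       // unqualified
assert(matchesFunctionName(Seq("builtin", "COUNT"), "count"))            // 2-part, case-insensitive
assert(matchesFunctionName(Seq("system", "builtin", "count"), "count"))  // 3-part
// Note: a persistent cat.db.count also matches here; the single-pass resolver's
// isCount (later in this diff) applies the stricter builtin-only check.
assert(matchesFunctionName(Seq("cat", "db", "count"), "count"))
```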
- * In order to avoid duplicate external functions lookup, the external function identifier will - * store in the local hash set externalFunctionNameSet. + * + * To avoid duplicate external catalog lookups, this rule maintains a per-plan cache of + * persistent function names (externalFunctionNameSet). Builtin and temporary functions are + * validated on every occurrence since they're fast in-memory lookups, but persistent functions + * are cached after the first validation to avoid repeated external catalog calls for the same + * function within a single plan. + * * @see [[ResolveFunctions]] * @see https://issues.apache.org/jira/browse/SPARK-19737 */ @@ -2004,24 +2019,55 @@ class Analyzer( plan.resolveExpressionsWithPruning(_.containsAnyPattern(UNRESOLVED_FUNCTION)) { case f @ UnresolvedFunction(nameParts, _, _, _, _, _, _) => - if (functionResolution.lookupBuiltinOrTempFunction(nameParts, Some(f)).isDefined) { + // For builtin/temp functions, we can do a quick check without catalog lookup + val quickCheck = if (nameParts.size == 1) { + functionResolution.lookupBuiltinOrTempFunction(nameParts, Some(f)) + } else if (FunctionResolution.maybeBuiltinFunctionName(nameParts) || + FunctionResolution.maybeTempFunctionName(nameParts)) { + functionResolution.lookupBuiltinOrTempFunction(nameParts, Some(f)) + } else { + None + } + + if (quickCheck.isDefined) { + // It's a builtin or temp function - no need for catalog lookup or caching f } else { + // Might be a persistent function - compute full name and check cache first val CatalogAndIdentifier(catalog, ident) = relationResolution.expandIdentifier(nameParts) - val fullName = - normalizeFuncName((catalog.name +: ident.namespace :+ ident.name).toImmutableArraySeq) + val fullName = normalizeFuncName( + (catalog.name +: ident.namespace :+ ident.name).toImmutableArraySeq) + if (externalFunctionNameSet.contains(fullName)) { - f - } else if (catalog.asFunctionCatalog.functionExists(ident)) { - externalFunctionNameSet.add(fullName) + // Already validated this function exists - skip lookup f } else { - val catalogPath = (catalog.name() +: catalogManager.currentNamespace).mkString(".") - throw QueryCompilationErrors.unresolvedRoutineError( - nameParts, - Seq("system.builtin", "system.session", catalogPath), - f.origin) + // Not in cache - do full lookup to determine type + val functionType = functionResolution.lookupFunctionType(nameParts, Some(f)) + + functionType match { + case FunctionType.Builtin | FunctionType.Temporary => + // This shouldn't happen since we checked above, but handle it + f + + case FunctionType.Persistent => + // Cache it to avoid repeated external catalog lookups + externalFunctionNameSet.add(fullName) + f + + case FunctionType.TableOnly => + // Function exists ONLY in table registry - cannot be used in scalar context + throw QueryCompilationErrors.notAScalarFunctionError(nameParts.mkString("."), f) + + case FunctionType.NotFound => + // Function doesn't exist anywhere - throw UNRESOLVED_ROUTINE error + val catalogPath = (catalog.name +: catalogManager.currentNamespace).mkString(".") + throw QueryCompilationErrors.unresolvedRoutineError( + nameParts, + Seq("system.builtin", "system.session", catalogPath), + f.origin) + } } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 26353bb8d46f9..5727d818cddcd 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.expressions.variant._ import org.apache.spark.sql.catalyst.expressions.xml._ import org.apache.spark.sql.catalyst.plans.logical.{FunctionBuilderBase, Generate, LogicalPlan, OneRowRelation, PythonWorkerLogs, Range} import org.apache.spark.sql.catalyst.trees.TreeNodeTag +import org.apache.spark.sql.connector.catalog.CatalogManager import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -78,10 +79,17 @@ trait FunctionRegistryBase[T] { /* Create or replace a temporary function. */ final def createOrReplaceTempFunction( name: String, builder: FunctionBuilder, source: String): Unit = { - registerFunction( - FunctionIdentifier(name), - builder, - source) + // Internal functions (source="internal") are NOT qualified with + // CatalogManager.SESSION_NAMESPACE database because they use a separate + // internal registry and are resolved differently + val identifier = if (source == "internal") { + FunctionIdentifier(name) + } else { + // Regular temporary functions are qualified with CatalogManager.SESSION_NAMESPACE + // to enable coexistence with builtin functions of the same name + FunctionIdentifier(name, Some(CatalogManager.SESSION_NAMESPACE)) + } + registerFunction(identifier, builder, source) } @throws[AnalysisException]("If function does not exist") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionResolution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionResolution.scala index 29f4db65def01..57bf1af0f691e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionResolution.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionResolution.scala @@ -28,6 +28,23 @@ import org.apache.spark.sql.connector.catalog.{ CatalogManager, LookupCatalog } + +/** + * Represents the type/location of a function. + */ +sealed trait FunctionType +object FunctionType { + /** Function is a built-in function in the builtin registry. */ + case object Builtin extends FunctionType + /** Function is a temporary function in the session registry. */ + case object Temporary extends FunctionType + /** Function is a persistent function in the external catalog. */ + case object Persistent extends FunctionType + /** Function exists only as a table function (cannot be used in scalar context). */ + case object TableOnly extends FunctionType + /** Function does not exist anywhere. 
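A quick sketch of the key choice made in createOrReplaceTempFunction above, in plain Scala: the pair stands in for FunctionIdentifier, and the "internal"/"session" strings mirror the source marker and CatalogManager.SESSION_NAMESPACE.

```scala
// Internal functions stay unqualified; regular temp functions get the session qualifier.
def tempFunctionKey(name: String, source: String): (String, Option[String]) =
  if (source == "internal") (name, None) else (name, Some("session"))

assert(tempFunctionKey("f", "scala_udf") == ("f", Some("session")))
assert(tempFunctionKey("f", "internal") == ("f", None))
```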
*/ + case object NotFound extends FunctionType +} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import org.apache.spark.sql.connector.catalog.functions.{ AggregateFunction => V2AggregateFunction, @@ -77,7 +94,14 @@ class FunctionResolution( u: Option[UnresolvedFunction]): Option[ExpressionInfo] = { if (name.size == 1 && u.exists(_.isInternal)) { FunctionRegistry.internal.lookupFunction(FunctionIdentifier(name.head)) + } else if (maybeBuiltinFunctionName(name)) { + // Explicitly qualified as builtin - lookup only builtin + v1SessionCatalog.lookupBuiltinFunction(name.last) + } else if (maybeTempFunctionName(name)) { + // Explicitly qualified as temp - lookup only temp + v1SessionCatalog.lookupTempFunction(name.last) } else if (name.size == 1) { + // Unqualified - check temp first (shadowing), then builtin v1SessionCatalog.lookupBuiltinOrTempFunction(name.head) } else { None @@ -85,24 +109,152 @@ class FunctionResolution( } def lookupBuiltinOrTempTableFunction(name: Seq[String]): Option[ExpressionInfo] = { - if (name.length == 1) { + if (maybeExtensionFunctionName(name)) { + // Explicitly qualified as extension - lookup only extension + v1SessionCatalog.lookupExtensionTableFunction(name.last) + } else if (maybeBuiltinFunctionName(name)) { + // Explicitly qualified as builtin - lookup only builtin + v1SessionCatalog.lookupBuiltinTableFunction(name.last) + } else if (maybeTempFunctionName(name)) { + // Explicitly qualified as temp - lookup only temp + v1SessionCatalog.lookupTempTableFunction(name.last) + } else if (name.length == 1) { + // For unqualified names, use the PATH resolution order: extension -> builtin -> session v1SessionCatalog.lookupBuiltinOrTempTableFunction(name.head) } else { None } } + /** + * Checks if a function is a builtin or temporary function (scalar or table). + * This is a convenience method that uses lookupFunctionType internally. + * + * @param nameParts The function name parts. + * @param u Optional UnresolvedFunction for internal function detection. + * @return true if the function is a builtin or temporary function, false otherwise. + */ + def isBuiltinOrTemporaryFunction( + nameParts: Seq[String], + u: Option[UnresolvedFunction]): Boolean = { + lookupFunctionType(nameParts, u) match { + case FunctionType.Builtin | FunctionType.Temporary | FunctionType.TableOnly => true + case _ => false + } + } + + /** + * Determines the type/location of a function (builtin, temporary, persistent, etc.). + * This is used by the LookupFunctions analyzer rule for early validation and optimization. + * This method only performs the lookup and classification - it does not throw errors. + * + * @param nameParts The function name parts. + * @param node Optional UnresolvedFunction node for lookups that may need it. + * @return The type of the function (Builtin, Temporary, Persistent, TableOnly, or NotFound). 
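For readers skimming the new classification, here is how a caller is expected to branch on it — a self-contained mirror of the sealed trait above, not the production lookup path:

```scala
sealed trait FunctionType
object FunctionType {
  case object Builtin extends FunctionType
  case object Temporary extends FunctionType
  case object Persistent extends FunctionType
  case object TableOnly extends FunctionType
  case object NotFound extends FunctionType
}

// LookupFunctions-style handling: cheap in-memory types pass through, persistent
// results are cacheable, TableOnly/NotFound turn into analysis errors.
def handle(tpe: FunctionType): String = tpe match {
  case FunctionType.Builtin | FunctionType.Temporary => "ok: in-memory, no catalog call"
  case FunctionType.Persistent => "ok: cache the normalized name"
  case FunctionType.TableOnly  => "error: not a scalar function"
  case FunctionType.NotFound   => "error: UNRESOLVED_ROUTINE"
}

assert(handle(FunctionType.Persistent).startsWith("ok"))
```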
+ */ + def lookupFunctionType( + nameParts: Seq[String], + node: Option[UnresolvedFunction] = None): FunctionType = { + + // Check if it's explicitly qualified as extension, builtin, or temp + if (maybeExtensionFunctionName(nameParts)) { + // Explicitly qualified as extension (e.g., extension.func or system.extension.func) + if (lookupBuiltinOrTempFunction(nameParts, node).isDefined) { + return FunctionType.Builtin // Extensions are treated as builtin for resolution purposes + } + } else if (maybeBuiltinFunctionName(nameParts)) { + // Explicitly qualified as builtin (e.g., builtin.abs or system.builtin.abs) + if (lookupBuiltinOrTempFunction(nameParts, node).isDefined) { + return FunctionType.Builtin + } + } else if (maybeTempFunctionName(nameParts)) { + // Explicitly qualified as temp (e.g., session.func or system.session.func) + if (lookupBuiltinOrTempFunction(nameParts, node).isDefined) { + return FunctionType.Temporary + } + } else { + // Unqualified or qualified with a catalog + // Use lookupBuiltinOrTempFunction which handles internal functions correctly + val funcInfoOpt = lookupBuiltinOrTempFunction(nameParts, node) + funcInfoOpt match { + case Some(info) => + // Determine if it's extension, temp, or builtin from the ExpressionInfo + if (info.getDb == CatalogManager.EXTENSION_NAMESPACE) { + // Extensions are treated as builtin for resolution purposes + return FunctionType.Builtin + } else if (info.getDb == CatalogManager.SESSION_NAMESPACE) { + // Could be temp or internal - check if it's in the internal registry + if (nameParts.size == 1 && node.exists(_.isInternal)) { + return FunctionType.Builtin // Internal functions are treated as builtins + } else { + return FunctionType.Temporary + } + } else { + return FunctionType.Builtin + } + case None => + // Not found as scalar, continue checking + } + } + + // Check if function exists as table function only + if (lookupBuiltinOrTempTableFunction(nameParts).isDefined) { + return FunctionType.TableOnly + } + + // Check external catalog for persistent functions + val CatalogAndIdentifier(catalog, ident) = relationResolution.expandIdentifier(nameParts) + if (catalog.asFunctionCatalog.functionExists(ident)) { + return FunctionType.Persistent + } + + // Function doesn't exist anywhere + FunctionType.NotFound + } + def resolveBuiltinOrTempFunction( name: Seq[String], arguments: Seq[Expression], u: UnresolvedFunction): Option[Expression] = { + + // Step 1: Try to resolve as scalar function val expression = if (name.size == 1 && u.isInternal) { Option(FunctionRegistry.internal.lookupFunction(FunctionIdentifier(name.head), arguments)) + } else if (maybeExtensionFunctionName(name)) { + // Explicitly qualified as extension - resolve only extension + v1SessionCatalog.resolveExtensionFunction(name.last, arguments) + } else if (maybeBuiltinFunctionName(name)) { + // Explicitly qualified as builtin - resolve only builtin + v1SessionCatalog.resolveBuiltinFunction(name.last, arguments) + } else if (maybeTempFunctionName(name)) { + // Explicitly qualified as temp - resolve only temp + v1SessionCatalog.resolveTempFunction(name.last, arguments) } else if (name.size == 1) { - v1SessionCatalog.resolveBuiltinOrTempFunction(name.head, arguments) + // For unqualified names, use the PATH resolution order: extension -> builtin -> session + // This ensures built-in functions take precedence over temp functions (security fix) + // Cross-type checking: If only a temp table function exists (no scalar version), + // throw error when used in scalar context + val 
funcName = name.head + val scalarResult = v1SessionCatalog.resolveBuiltinOrTempFunction(funcName, arguments) + + if (scalarResult.isEmpty && v1SessionCatalog.lookupTempTableFunction(funcName).isDefined) { + // No scalar function found (neither builtin nor temp), but temp table function exists + throw QueryCompilationErrors.notAScalarFunctionError(name.mkString("."), u) + } else { + scalarResult + } } else { None } + + // Step 2: Check for table-only functions (cross-type error detection) + // If not found as scalar, check if it exists as a table-only function + if (expression.isEmpty && name.size == 1) { + if (v1SessionCatalog.lookupBuiltinOrTempTableFunction(name.head).isDefined) { + throw QueryCompilationErrors.notAScalarFunctionError(name.mkString("."), u) + } + } + expression.map { func => validateFunction(func, arguments.length, u) } @@ -111,11 +263,74 @@ class FunctionResolution( def resolveBuiltinOrTempTableFunction( name: Seq[String], arguments: Seq[Expression]): Option[LogicalPlan] = { - if (name.length == 1) { - v1SessionCatalog.resolveBuiltinOrTempTableFunction(name.head, arguments) + + // Step 1: Try to resolve as table function + val tableFunctionResult = if (maybeExtensionFunctionName(name)) { + // Explicitly qualified as extension - resolve only extension + v1SessionCatalog.resolveExtensionTableFunction(name.last, arguments) + } else if (maybeBuiltinFunctionName(name)) { + // Explicitly qualified as builtin - resolve only builtin + v1SessionCatalog.resolveBuiltinTableFunction(name.last, arguments) + } else if (maybeTempFunctionName(name)) { + // Explicitly qualified as temp - resolve only temp + v1SessionCatalog.resolveTempTableFunction(name.last, arguments) + } else if (name.length == 1) { + // For unqualified names, use the PATH resolution order: extension -> builtin -> session + // This ensures built-in table functions take precedence over temp functions (security fix) + // Cross-type checking: If only a temp scalar function exists (no table version), + // throw error when used in table context (checked below in Step 2) + val funcName = name.head + v1SessionCatalog.resolveBuiltinOrTempTableFunction(funcName, arguments) } else { None } + + // Step 2: Fallback to scalar registry for type mismatch detection + // If no table function was found (neither builtin nor temp), check if a scalar function exists. + // If yes, this is a cross-type error - scalar function used in table context. + if (tableFunctionResult.isEmpty && name.length == 1) { + if (v1SessionCatalog.lookupBuiltinOrTempFunction(name.head).isDefined) { + throw QueryCompilationErrors.notATableFunctionError(name.mkString(".")) + } + } + + tableFunctionResult + } + + /** + * Check if a function name is qualified as an extension function. + * Valid forms: extension.func or system.extension.func + */ + private def maybeExtensionFunctionName(nameParts: Seq[String]): Boolean = { + FunctionResolution.maybeExtensionFunctionName(nameParts) + } + + /** + * Check if a function name is qualified as a builtin function. + * Valid forms: builtin.func or system.builtin.func + */ + private def maybeBuiltinFunctionName(nameParts: Seq[String]): Boolean = { + FunctionResolution.maybeBuiltinFunctionName(nameParts) + } + + /** + * Check if a function name is qualified as a session temporary function. 
+ * Valid forms: session.func or system.session.func + */ + private def maybeTempFunctionName(nameParts: Seq[String]): Boolean = { + FunctionResolution.maybeTempFunctionName(nameParts) + } + + /** + * Checks if a multi-part name is qualified with a specific namespace. + * Supports both 2-part (namespace.name) and 3-part (system.namespace.name) qualifications. + * + * @param nameParts The multi-part name to check + * @param namespace The namespace to check for (e.g., "builtin", "session") + * @return true if qualified with the given namespace + */ + private def isQualifiedWithNamespace(nameParts: Seq[String], namespace: String): Boolean = { + FunctionResolution.isQualifiedWithNamespace(nameParts, namespace) } private def validateFunction( @@ -342,3 +557,50 @@ class FunctionResolution( messageParameters = messageParameters) } } + +/** + * Companion object with shared utility methods for function name qualification checks. + */ +object FunctionResolution { + /** + * Check if a function name is qualified as an extension function. + * Valid forms: extension.func or system.extension.func + */ + def maybeExtensionFunctionName(nameParts: Seq[String]): Boolean = { + isQualifiedWithNamespace(nameParts, CatalogManager.EXTENSION_NAMESPACE) + } + + /** + * Check if a function name is qualified as a builtin function. + * Valid forms: builtin.func or system.builtin.func + */ + def maybeBuiltinFunctionName(nameParts: Seq[String]): Boolean = { + isQualifiedWithNamespace(nameParts, CatalogManager.BUILTIN_NAMESPACE) + } + + /** + * Check if a function name is qualified as a session temporary function. + * Valid forms: session.func or system.session.func + */ + def maybeTempFunctionName(nameParts: Seq[String]): Boolean = { + isQualifiedWithNamespace(nameParts, CatalogManager.SESSION_NAMESPACE) + } + + /** + * Checks if a multi-part name is qualified with a specific namespace. + * Supports both 2-part (namespace.name) and 3-part (system.namespace.name) qualifications. + * + * @param nameParts The multi-part name to check + * @param namespace The namespace to check for (e.g., "extension", "builtin", "session") + * @return true if qualified with the given namespace + */ + def isQualifiedWithNamespace(nameParts: Seq[String], namespace: String): Boolean = { + nameParts.length match { + case 2 => nameParts.head.equalsIgnoreCase(namespace) + case 3 => + nameParts(0).equalsIgnoreCase(CatalogManager.SYSTEM_CATALOG_NAME) && + nameParts(1).equalsIgnoreCase(namespace) + case _ => false + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala index e433401511d3a..f99e9761c1263 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala @@ -136,7 +136,9 @@ class ResolveCatalogs(val catalogManager: CatalogManager) /** * Resolves a function identifier, checking for builtin and temp functions first. - * Builtin and temp functions are only registered with unqualified names. + * Builtin and temp functions are only registered with unqualified names, but can be + * referenced with qualified names like builtin.abs, system.builtin.abs, session.func, + * or system.session.func. 
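Since several call sites hang off this predicate, a few concrete cases help. This is the companion-object logic inlined verbatim, with the literal "system" standing in for CatalogManager.SYSTEM_CATALOG_NAME:

```scala
def isQualifiedWithNamespace(nameParts: Seq[String], namespace: String): Boolean =
  nameParts.length match {
    case 2 => nameParts.head.equalsIgnoreCase(namespace)
    case 3 =>
      nameParts(0).equalsIgnoreCase("system") && nameParts(1).equalsIgnoreCase(namespace)
    case _ => false
  }

assert(isQualifiedWithNamespace(Seq("builtin", "abs"), "builtin"))          // builtin.abs
assert(isQualifiedWithNamespace(Seq("SYSTEM", "session", "f"), "session"))  // system.session.f
assert(!isQualifiedWithNamespace(Seq("abs"), "builtin"))                    // unqualified: no
assert(!isQualifiedWithNamespace(Seq("cat", "db", "f"), "session"))         // ordinary 3-part name
```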
*/ private def resolveFunctionIdentifier( nameParts: Seq[String], @@ -154,6 +156,14 @@ class ResolveCatalogs(val catalogManager: CatalogManager) val CatalogAndIdentifier(catalog, ident) = nameParts ResolvedIdentifier(catalog, ident) } + } else if (FunctionResolution.maybeBuiltinFunctionName(nameParts)) { + // Explicitly qualified as builtin (e.g., builtin.abs or system.builtin.abs) + val ident = Identifier.of(Array(CatalogManager.BUILTIN_NAMESPACE), nameParts.last) + ResolvedIdentifier(FakeSystemCatalog, ident) + } else if (FunctionResolution.maybeTempFunctionName(nameParts)) { + // Explicitly qualified as temp (e.g., session.func or system.session.func) + val ident = Identifier.of(Array(CatalogManager.SESSION_NAMESPACE), nameParts.last) + ResolvedIdentifier(FakeSystemCatalog, ident) } else { val CatalogAndIdentifier(catalog, ident) = nameParts ResolvedIdentifier(catalog, ident) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/FunctionResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/FunctionResolver.scala index fe4c06aff199b..80db9b5e1bc21 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/FunctionResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/FunctionResolver.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.analysis.resolver -import java.util.Locale - import scala.util.Random import org.apache.spark.sql.AnalysisException @@ -38,6 +36,7 @@ import org.apache.spark.sql.catalyst.expressions.{ TryEval } import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.connector.catalog.CatalogManager /** * A resolver for [[UnresolvedFunction]]s that resolves functions to concrete [[Expression]]s. @@ -172,11 +171,36 @@ class FunctionResolver( /** * Method used to determine whether the given function should be replaced with another one. + * Only accepts unqualified or properly qualified builtin count function. + * Rejects catalog.db.count (persistent function) to avoid incorrect normalization. 
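The guard documented here is easiest to see as a truth table. A standalone mirror of the check the next hunk implements (pure Scala, hard-coding the "system"/"builtin" namespace strings):

```scala
def isBuiltinCountName(nameParts: Seq[String]): Boolean = nameParts match {
  case Seq(n) => n.equalsIgnoreCase("count")
  case Seq(ns, n) => ns.equalsIgnoreCase("builtin") && n.equalsIgnoreCase("count")
  case Seq(cat, ns, n) =>
    cat.equalsIgnoreCase("system") && ns.equalsIgnoreCase("builtin") &&
      n.equalsIgnoreCase("count")
  case _ => false
}

assert(isBuiltinCountName(Seq("count")))
assert(isBuiltinCountName(Seq("builtin", "count")))
assert(isBuiltinCountName(Seq("system", "builtin", "count")))
assert(!isBuiltinCountName(Seq("session", "count")))   // temp count: not normalized
assert(!isBuiltinCountName(Seq("cat", "db", "count"))) // persistent count: not normalized
```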
*/ private def isCount(unresolvedFunction: UnresolvedFunction): Boolean = { - !unresolvedFunction.isDistinct && - unresolvedFunction.nameParts.length == 1 && - unresolvedFunction.nameParts.head.toLowerCase(Locale.ROOT) == "count" + if (unresolvedFunction.isDistinct) { + return false + } + + val nameParts = unresolvedFunction.nameParts + + // Validate that this is actually the builtin count function, not a persistent one + val isBuiltinCount = nameParts.length match { + case 1 => + // Unqualified: "count" + nameParts.head.equalsIgnoreCase("count") + case 2 => + // Two parts: must be "builtin.count" + nameParts.head.equalsIgnoreCase(CatalogManager.BUILTIN_NAMESPACE) && + nameParts.last.equalsIgnoreCase("count") + case 3 => + // Three parts: must be "system.builtin.count" + nameParts(0).equalsIgnoreCase(CatalogManager.SYSTEM_CATALOG_NAME) && + nameParts(1).equalsIgnoreCase(CatalogManager.BUILTIN_NAMESPACE) && + nameParts.last.equalsIgnoreCase("count") + case _ => + // More than 3 parts or other patterns are not builtin count + false + } + + isBuiltinCount } /** @@ -188,7 +212,6 @@ class FunctionResolver( private def normalizeCountExpression( unresolvedFunction: UnresolvedFunction): UnresolvedFunction = { unresolvedFunction.copy( - nameParts = Seq("count"), arguments = Seq(Literal(1)), filter = unresolvedFunction.filter ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverGuard.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverGuard.scala index 8c26003b733b8..f87758835114b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverGuard.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverGuard.scala @@ -345,14 +345,42 @@ class ResolverGuard(catalogManager: CatalogManager) extends SQLConfHelper { createNamedStruct.children.forall(checkExpression) } - private def checkUnresolvedFunction(unresolvedFunction: UnresolvedFunction) = - unresolvedFunction.nameParts.size == 1 && - !ResolverGuard.UNSUPPORTED_FUNCTION_NAMES.contains(unresolvedFunction.nameParts.head) && - // UDFs are not supported - FunctionRegistry.functionSet.contains( - FunctionIdentifier(unresolvedFunction.nameParts.head.toLowerCase(Locale.ROOT)) - ) && - unresolvedFunction.children.forall(checkExpression) + private def checkUnresolvedFunction(unresolvedFunction: UnresolvedFunction) = { + val nameParts = unresolvedFunction.nameParts + + // Only accept unqualified names or names explicitly qualified as builtin. + // Session/temporary functions are UDFs and not supported by single-pass analyzer. + // Persistent functions from external catalogs are also not supported. 
+ val isBuiltinOrUnqualified = nameParts.length match { + case 1 => + // Unqualified: "count" - check if it's a builtin + true + case 2 => + // Two parts: must be "builtin.count" + nameParts.head.equalsIgnoreCase(CatalogManager.BUILTIN_NAMESPACE) + case 3 => + // Three parts: must be "system.builtin.count" + nameParts(0).equalsIgnoreCase(CatalogManager.SYSTEM_CATALOG_NAME) && + nameParts(1).equalsIgnoreCase(CatalogManager.BUILTIN_NAMESPACE) + case _ => + // More than 3 parts is not valid + false + } + + if (!isBuiltinOrUnqualified) { + // This is session.func, catalog.db.function, or invalid - not supported + false + } else { + // Extract the unqualified function name (last part) to check against builtin set + val functionName = nameParts.last + !ResolverGuard.UNSUPPORTED_FUNCTION_NAMES.contains(functionName) && + // UDFs are not supported - only built-in functions + FunctionRegistry.functionSet.contains( + FunctionIdentifier(functionName.toLowerCase(Locale.ROOT)) + ) && + unresolvedFunction.children.forall(checkExpression) + } + } private def checkLiteral(literal: Literal) = true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 130ccc1bc6e15..f3b307a28c003 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -84,6 +84,220 @@ class SessionCatalog( import SessionCatalog._ import CatalogTypes.TablePartitionSpec + // Marker database name for temporary functions to distinguish from builtin functions + /** + * Database qualifier used to store temporary functions in the function registry. + * Temporary functions use composite keys to coexist with builtin functions of the same name: + * - Builtin functions: FunctionIdentifier(name, None) + * - Extension functions: FunctionIdentifier(name, Some(CatalogManager.EXTENSION_NAMESPACE)) + * - Temp functions: FunctionIdentifier(name, Some(CatalogManager.SESSION_NAMESPACE)) + * This allows all three to exist in the same registry without conflicts. + */ + private val TEMP_FUNCTION_DB = CatalogManager.SESSION_NAMESPACE + private val EXTENSION_FUNCTION_DB = CatalogManager.EXTENSION_NAMESPACE + + /** + * Creates a FunctionIdentifier for an extension function with the EXTENSION_FUNCTION_DB database + * qualifier. This enables extension functions to coexist with builtin and temp functions. + * + * @param name The function name (unqualified) + * @return FunctionIdentifier with database = EXTENSION_FUNCTION_DB + */ + private def extensionFunctionIdentifier(name: String): FunctionIdentifier = + FunctionIdentifier(format(name), Some(EXTENSION_FUNCTION_DB)) + + /** + * Creates a FunctionIdentifier for a temporary function with the TEMP_FUNCTION_DB database + * qualifier. This enables temporary functions to coexist with builtin functions of the same name. + * + * @param name The function name (unqualified) + * @return FunctionIdentifier with database = TEMP_FUNCTION_DB + */ + private def tempFunctionIdentifier(name: String): FunctionIdentifier = + FunctionIdentifier(format(name), Some(TEMP_FUNCTION_DB)) + + /** + * Checks if a FunctionIdentifier represents an extension function by checking for the + * EXTENSION_FUNCTION_DB database qualifier. 
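The coexistence story behind these identifier helpers is easiest to see with the registry modeled as a map. A minimal sketch in plain Scala (the case class stands in for Spark's FunctionIdentifier):

```scala
case class Ident(funcName: String, database: Option[String] = None)

val registry = scala.collection.mutable.Map.empty[Ident, String]
registry(Ident("abs")) = "builtin abs"                       // builtin: no database
registry(Ident("abs", Some("extension"))) = "extension abs"  // extension-qualified
registry(Ident("abs", Some("session"))) = "temp abs (UDF)"   // session-qualified

// All three coexist under distinct keys; PATH order, not key collision, picks the winner.
assert(registry.size == 3)
```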
+ * + * @param identifier The FunctionIdentifier to check + * @return true if this is an extension function identifier + */ + private def isExtensionFunctionIdentifier(identifier: FunctionIdentifier): Boolean = + identifier.database.contains(EXTENSION_FUNCTION_DB) + + /** + * Checks if a FunctionIdentifier represents a temporary function by checking for the + * TEMP_FUNCTION_DB database qualifier. + * + * @param identifier The FunctionIdentifier to check + * @return true if this is a temporary function identifier + */ + private def isTempFunctionIdentifier(identifier: FunctionIdentifier): Boolean = + identifier.database.contains(TEMP_FUNCTION_DB) + + // -------------------------------- + // | PATH-Based Resolution | + // -------------------------------- + // Functions are resolved by searching through an ordered PATH of namespaces. + // This provides a unified, data-driven resolution mechanism instead of hardcoded checks. + + /** + * Function resolution PATH - ordered list of namespaces to search for unqualified functions. + * Each entry is a FunctionIdentifier representing a namespace (catalog + database). + * The funcName field is unused (empty string) as these represent namespace templates. + * + * Resolution order (CRITICAL FOR SECURITY): + * 1. system.extension (extension functions - can shadow built-ins) + * 2. system.builtin (built-in functions - protected from user shadowing) + * 3. system.session (temporary functions - CANNOT shadow built-ins) + * + * This order ensures that: + * - Extensions (admin-installed, trusted) can shadow built-ins if needed + * - Built-ins are resolved BEFORE user temp functions, preventing security exploits + * - Users cannot shadow security-critical functions like current_user() + * + * When resolving a view, the system.session namespace is included in the path, but + * handleViewContext filters to only return temporary functions that were referred to + * when the view was created. + * + * These identifiers are cached for performance since they're accessed frequently. + */ + private val EXTENSION_NAMESPACE_TEMPLATE = FunctionIdentifier( + funcName = "", + database = Some(CatalogManager.EXTENSION_NAMESPACE), + catalog = Some(CatalogManager.SYSTEM_CATALOG_NAME)) + + private val SESSION_NAMESPACE_TEMPLATE = FunctionIdentifier( + funcName = "", + database = Some(CatalogManager.SESSION_NAMESPACE), + catalog = Some(CatalogManager.SYSTEM_CATALOG_NAME)) + + private val BUILTIN_NAMESPACE_TEMPLATE = FunctionIdentifier( + funcName = "", + database = Some(CatalogManager.BUILTIN_NAMESPACE), + catalog = Some(CatalogManager.SYSTEM_CATALOG_NAME)) + + // The resolution path: extension -> builtin -> session (security-focused order) + private val RESOLUTION_PATH = Seq( + EXTENSION_NAMESPACE_TEMPLATE, + BUILTIN_NAMESPACE_TEMPLATE, + SESSION_NAMESPACE_TEMPLATE + ) + + /** + * Returns the resolution path for function lookup. + * @return Ordered sequence of namespace identifiers. + */ + private def resolutionPath(): Seq[FunctionIdentifier] = RESOLUTION_PATH + + /** + * Maps a namespace template to an actual storage identifier for a specific function. + * This handles the asymmetry between how builtins, extensions, and temp functions are stored. 
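To make the security ordering concrete, here is the PATH walk reduced to plain Maps — the same iterator/short-circuit shape as resolveFunctionWithFallback below; the names and contents are invented:

```scala
val extension = Map.empty[String, String]
val builtin   = Map("current_user" -> "builtin current_user()")
val session   = Map("current_user" -> "temp shadow (ignored)", "my_udf" -> "temp my_udf")

val resolutionPath = Seq(extension, builtin, session) // extension -> builtin -> session

def resolve(name: String): Option[String] =
  resolutionPath.iterator.flatMap(_.get(name)).nextOption() // stops at first match

assert(resolve("current_user").contains("builtin current_user()")) // builtin wins over temp
assert(resolve("my_udf").contains("temp my_udf"))                  // temp-only names still work
assert(resolve("no_such_fn").isEmpty)
```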
+ * + * Storage conventions: + * - Builtin functions: FunctionIdentifier(name, None, None) + * - Extension functions: FunctionIdentifier(name, Some("extension"), None) + * - Temp functions: FunctionIdentifier(name, Some("session"), None) + * - Other: FunctionIdentifier(name, namespace.database, namespace.catalog) + * + * @param namespace The namespace template + * @param name The function name + * @return The actual identifier to use for registry lookup + */ + private def namespaceToIdentifier( + namespace: FunctionIdentifier, + name: String): FunctionIdentifier = { + namespace.database match { + case Some(CatalogManager.EXTENSION_NAMESPACE) => + // Extension functions: stored with database="extension", no catalog + extensionFunctionIdentifier(name) + + case Some(CatalogManager.SESSION_NAMESPACE) => + // Temp functions: stored with database="session", no catalog + tempFunctionIdentifier(name) + + case Some(CatalogManager.BUILTIN_NAMESPACE) => + // Builtin functions: stored with no database or catalog + FunctionIdentifier(format(name)) + + case other => + // Other namespaces: use full qualification + // Note: This branch is for future extensions (e.g., persistent functions in PATH) + if (other.isDefined) { + logDebug(s"Function lookup in non-standard namespace: $other for function: $name") + } + FunctionIdentifier(name, namespace.database, namespace.catalog) + } + } + + /** + * Checks if a namespace represents extension functions. + */ + private def isExtensionNamespace(namespace: FunctionIdentifier): Boolean = { + namespace.database.contains(CatalogManager.EXTENSION_NAMESPACE) + } + + /** + * Checks if a namespace represents temporary functions. + */ + private def isSessionNamespace(namespace: FunctionIdentifier): Boolean = { + namespace.database.contains(CatalogManager.SESSION_NAMESPACE) + } + + /** + * Lookup a function in a specific namespace. + * + * @param namespace Namespace identifier (catalog + database) + * @param name Function name (unqualified) + * @param registry The registry to search (FunctionRegistry or TableFunctionRegistry) + * @tparam T The registry's type parameter (Expression or LogicalPlan) + * @return ExpressionInfo if function found in this namespace + */ + private def lookupInNamespace[T]( + namespace: FunctionIdentifier, + name: String, + registry: FunctionRegistryBase[T]): Option[ExpressionInfo] = { + + val identifier = namespaceToIdentifier(namespace, name) + val result = registry.lookupFunction(identifier) + + // Apply view context filtering for temp functions + if (isSessionNamespace(namespace)) { + handleViewContext(name, result) + } else { + result + } + } + + /** + * Resolve a function in a specific namespace by building it with arguments. 
+ * + * @param namespace Namespace identifier (catalog + database) + * @param name Function name (unqualified) + * @param arguments Arguments to pass to the function builder + * @param registry The registry to search + * @tparam T The registry's type parameter + * @return Built function instance if found + */ + private def resolveInNamespace[T]( + namespace: FunctionIdentifier, + name: String, + arguments: Seq[Expression], + registry: FunctionRegistryBase[T]): Option[T] = { + + val identifier = namespaceToIdentifier(namespace, name) + + if (!registry.functionExists(identifier)) { + None + } else if (isSessionNamespace(namespace)) { + // For temp functions, apply view context handling + handleViewContext(name, Option(registry.lookupFunction(identifier, arguments))) + } else { + Some(registry.lookupFunction(identifier, arguments)) + } + } + // For testing only. def this( externalCatalog: ExternalCatalog, @@ -1935,20 +2149,35 @@ class SessionCatalog( overrideIfExists: Boolean, functionBuilder: Option[FunctionBuilder] = None): Unit = { val builder = functionBuilder.getOrElse(makeFunctionBuilder(funcDefinition)) - registerFunction(funcDefinition, overrideIfExists, functionRegistry, builder) + // Use composite keys for temporary functions, but not for persistent functions + val useComposite = funcDefinition.identifier.database.isEmpty + registerFunction(funcDefinition, overrideIfExists, functionRegistry, builder, useComposite) } private def registerFunction[T]( funcDefinition: CatalogFunction, overrideIfExists: Boolean, registry: FunctionRegistryBase[T], - functionBuilder: FunctionRegistryBase[T]#FunctionBuilder): Unit = { + functionBuilder: FunctionRegistryBase[T]#FunctionBuilder, + useCompositeKey: Boolean): Unit = { val func = funcDefinition.identifier - if (registry.functionExists(func) && !overrideIfExists) { + + // Determine the key to use for registration: + // - Temporary functions (unqualified): use composite key with TEMP_FUNCTION_DB database + // - Persistent functions (qualified): keep qualification to avoid conflicts + val identToRegister = if (func.database.isEmpty && useCompositeKey) { + // Temporary function: use TEMP_FUNCTION_DB.funcName + tempFunctionIdentifier(func.funcName) + } else { + // Persistent function: keep original qualified identifier + func + } + + if (registry.functionExists(identToRegister) && !overrideIfExists) { throw QueryCompilationErrors.functionAlreadyExistsError(func) } val info = makeExprInfoForHiveFunction(funcDefinition) - registry.registerFunction(func, info, functionBuilder) + registry.registerFunction(identToRegister, info, functionBuilder) } private def makeExprInfoForHiveFunction(func: CatalogFunction): ExpressionInfo = { @@ -1977,7 +2206,8 @@ class SessionCatalog( function, overrideIfExists, functionRegistry, - makeSQLFunctionBuilder(function)) + makeSQLFunctionBuilder(function), + isTableFunction = false) } /** @@ -1991,7 +2221,8 @@ class SessionCatalog( function, overrideIfExists, tableFunctionRegistry, - makeSQLTableFunctionBuilder(function)) + makeSQLTableFunctionBuilder(function), + isTableFunction = true) } /** @@ -2028,17 +2259,54 @@ class SessionCatalog( /** * Registers a temporary or permanent SQL function into a session-specific function registry. + * For temporary functions, validates that the function name doesn't exist as a different type + * (scalar vs. table) to prevent ambiguous DROP operations. + * For persistent functions, the metastore already enforces this constraint. 
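The cross-registry rule described above, reduced to sets. This is a sketch of the constraint only, not the real registration API; the production code keys on the session-qualified FunctionIdentifier:

```scala
val scalarRegistry = scala.collection.mutable.Set.empty[String]
val tableRegistry  = scala.collection.mutable.Set.empty[String]

def registerTemp(name: String, isTableFunction: Boolean, overrideIfExists: Boolean): Unit = {
  val (target, other) =
    if (isTableFunction) (tableRegistry, scalarRegistry) else (scalarRegistry, tableRegistry)
  // Reject if the name exists as EITHER type, unless OR REPLACE was specified.
  if (!overrideIfExists && (target.contains(name) || other.contains(name))) {
    throw new IllegalStateException(s"function $name already exists")
  }
  if (overrideIfExists) other.remove(name) // OR REPLACE also drops the other-typed function
  target.add(name)
}

registerTemp("f", isTableFunction = false, overrideIfExists = false)
registerTemp("f", isTableFunction = true, overrideIfExists = true) // replaces the scalar f
assert(!scalarRegistry.contains("f") && tableRegistry.contains("f"))
```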
*/ private def registerUserDefinedFunction[T]( function: UserDefinedFunction, overrideIfExists: Boolean, registry: FunctionRegistryBase[T], - functionBuilder: Seq[Expression] => T): Unit = { + functionBuilder: Seq[Expression] => T, + isTableFunction: Boolean = false): Unit = { + + val isTemporary = function.name.database.isEmpty + + if (isTemporary) { + // Use FunctionIdentifier with TEMP_FUNCTION_DB for temporary functions + val tempIdentifier = tempFunctionIdentifier(function.name.funcName) + + // Check if this temp function already exists in the target registry + if (registry.functionExists(tempIdentifier) && !overrideIfExists) { + throw QueryCompilationErrors.functionAlreadyExistsError(function.name) + } + + // Check if function exists in the OTHER registry as a different type. + // This prevents having both scalar and table temporary functions with the same name, + // which would make DROP TEMPORARY FUNCTION ambiguous (it would drop both). + val otherRegistry: FunctionRegistryBase[_] = + if (isTableFunction) functionRegistry else tableFunctionRegistry + if (otherRegistry.functionExists(tempIdentifier) && !overrideIfExists) { + throw QueryCompilationErrors.functionAlreadyExistsError(function.name) + } + + // With OR REPLACE, drop from the other registry first if it exists there + if (overrideIfExists) { + otherRegistry.dropFunction(tempIdentifier) + } + + val info = function.toExpressionInfo + registry.registerFunction(tempIdentifier, info, functionBuilder) + } else { + // Persistent function - the metastore already enforces cross-type uniqueness, + // so we only check the target registry here. if (registry.functionExists(function.name) && !overrideIfExists) { throw QueryCompilationErrors.functionAlreadyExistsError(function.name) } + val info = function.toExpressionInfo registry.registerFunction(function.name, info, functionBuilder) + } } /** @@ -2046,15 +2314,18 @@ class SessionCatalog( * or [[TableFunctionRegistry]]. Return true if function exists. */ def unregisterFunction(name: FunctionIdentifier): Boolean = { - functionRegistry.dropFunction(name) || tableFunctionRegistry.dropFunction(name) + // If it's an unqualified name, it's a temp function stored with TEMP_FUNCTION_DB database + val tempIdent = if (name.database.isEmpty) tempFunctionIdentifier(name.funcName) else name + functionRegistry.dropFunction(tempIdent) || tableFunctionRegistry.dropFunction(tempIdent) } /** * Drop a temporary function. */ def dropTempFunction(name: String, ignoreIfNotExists: Boolean): Unit = { - if (!functionRegistry.dropFunction(FunctionIdentifier(name)) && - !tableFunctionRegistry.dropFunction(FunctionIdentifier(name)) && + val tempIdent = tempFunctionIdentifier(name) + if (!functionRegistry.dropFunction(tempIdent) && + !tableFunctionRegistry.dropFunction(tempIdent) && !ignoreIfNotExists) { throw new NoSuchTempFunctionException(name) } @@ -2064,9 +2335,14 @@ class SessionCatalog( * Returns whether it is a temporary function. If not existed, returns false. 
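And the matching removal path: DROP TEMPORARY FUNCTION now probes both registries under the session-qualified key. A sketch with string keys standing in for FunctionIdentifier:

```scala
def dropTempFunction(
    name: String,
    scalar: scala.collection.mutable.Set[String],
    table: scala.collection.mutable.Set[String],
    ignoreIfNotExists: Boolean): Unit = {
  val key = s"session.$name" // mirrors tempFunctionIdentifier(name)
  // Short-circuits like the real code: the table registry is only probed if needed.
  val dropped = scalar.remove(key) || table.remove(key)
  if (!dropped && !ignoreIfNotExists) {
    throw new NoSuchElementException(s"temporary function $name not found")
  }
}

val scalar = scala.collection.mutable.Set("session.f")
dropTempFunction("f", scalar, scala.collection.mutable.Set.empty[String], ignoreIfNotExists = false)
assert(scalar.isEmpty)
```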
   */
  def isTemporaryFunction(name: FunctionIdentifier): Boolean = {
-    // A temporary function is a function that has been registered in functionRegistry
-    // without a database name, and is neither a built-in function nor a Hive function
-    name.database.isEmpty && isRegisteredFunction(name) && !isBuiltinFunction(name)
+    // A temporary function is stored with database = TEMP_FUNCTION_DB.
+    if (name.database.isEmpty) {
+      val tempIdent = tempFunctionIdentifier(name.funcName)
+      functionRegistry.functionExists(tempIdent) ||
+        tableFunctionRegistry.functionExists(tempIdent)
+    } else {
+      isTempFunctionIdentifier(name)
+    }
   }
 
   /**
@@ -2074,7 +2350,25 @@
    * session. If not existed, return false.
    */
   def isRegisteredFunction(name: FunctionIdentifier): Boolean = {
+    // An unqualified name may refer to a temp function (keyed by TEMP_FUNCTION_DB) or to a
+    // builtin (keyed without a database); qualified names are checked as-is.
+    if (name.database.isEmpty) {
+      val tempIdent = tempFunctionIdentifier(name.funcName)
+      val builtinIdent = FunctionIdentifier(format(name.funcName))
+
+      // Check if a temp function exists.
+      val hasTemp = functionRegistry.functionExists(tempIdent) ||
+        tableFunctionRegistry.functionExists(tempIdent)
+
+      // Check if a builtin exists - but ONLY if it's actually a builtin, not a cached persistent.
+      val hasBuiltin = (FunctionRegistry.functionSet.contains(builtinIdent) ||
+        TableFunctionRegistry.functionSet.contains(builtinIdent)) &&
+        (functionRegistry.functionExists(builtinIdent) ||
+          tableFunctionRegistry.functionExists(builtinIdent))
+
+      hasTemp || hasBuiltin
+    } else {
     functionRegistry.functionExists(name) || tableFunctionRegistry.functionExists(name)
+    }
   }
 
   /**
@@ -2101,85 +2395,313 @@
   }
 
   /**
-   * Look up the `ExpressionInfo` of the given function by name if it's a built-in or temp function.
+   * Handles view resolution context for temporary functions (both scalar and table-valued).
+   * When resolving a view, only returns the result if the function is explicitly referred
+   * to by that view. Otherwise, tracks the function reference for future view creation.
+   *
+   * This generic helper works for both scalar functions (FunctionRegistry) and table-valued
+   * functions (TableFunctionRegistry) due to the type parameter T.
+   *
+   * @param name The function name (unqualified)
+   * @param result The result to wrap with view context handling
+   * @tparam T The result type (ExpressionInfo, Expression, or LogicalPlan)
+   * @return The result if visible in the current context, None otherwise
+   */
+  private def handleViewContext[T](name: String, result: Option[T]): Option[T] =
+    result.filter { _ =>
+      val isResolvingView = AnalysisContext.get.catalogAndNamespace.nonEmpty
+      val referredTempFunctionNames = AnalysisContext.get.referredTempFunctionNames
+
+      if (isResolvingView) {
+        // When resolving a view, only return a temp function if it's referred to by this view.
+        referredTempFunctionNames.contains(name)
+      } else {
+        // We are not resolving a view and the function is a temp one; add it to
+        // AnalysisContext so it can be checked if a view is being created.
+        AnalysisContext.get.referredTempFunctionNames.add(name)
+        true
+      }
+    }
+
+  /**
+   * Looks up a function using PATH-based resolution, searching the resolution path
+   * (extension -> builtin -> session) and applying view-context handling to temporary
+   * (session-namespace) results.
+   *
+   * @param name The function name (unqualified).
+   * @param registry The registry to search (FunctionRegistry or TableFunctionRegistry).
+   * @param checkBuiltinOperators Whether to check built-in operators first (scalar functions only).
+   * @tparam T The registry's type parameter (Expression for FunctionRegistry,
+   *           LogicalPlan for TableFunctionRegistry).
+   * @return ExpressionInfo if the function is found, None otherwise.
+   */
+  private def lookupFunctionWithShadowing[T](
+      name: String,
+      registry: FunctionRegistryBase[T],
+      checkBuiltinOperators: Boolean): Option[ExpressionInfo] = {
+
+    // Check built-in operators first (only for scalar functions).
+    val operatorResult = if (checkBuiltinOperators) {
+      FunctionRegistry.builtinOperators.get(name.toLowerCase(Locale.ROOT))
+    } else {
+      None
+    }
+
+    operatorResult.orElse {
+      // Use PATH-based resolution: iterate through namespaces until a match is found.
+      val path = resolutionPath()
+
+      // Use iterator for short-circuit evaluation (stops at first match).
+      path.iterator.flatMap { namespace =>
+        lookupInNamespace(namespace, name, registry)
+      }.nextOption()
+    }
+  }
+
+  /**
+   * Look up the `ExpressionInfo` of the given function by name.
+   * Searches through extension, built-in, and temp functions in that order.
+   * This only supports scalar functions.
+   *
+   * Resolution order: extension -> builtin -> session (temp)
+   */
   def lookupBuiltinOrTempFunction(name: String): Option[ExpressionInfo] = {
-    FunctionRegistry.builtinOperators.get(name.toLowerCase(Locale.ROOT)).orElse {
-      synchronized(lookupTempFuncWithViewContext(
-        name, FunctionRegistry.builtin.functionExists, functionRegistry.lookupFunction))
+    lookupFunctionWithShadowing(name, functionRegistry, checkBuiltinOperators = true)
+  }
+
+  /**
+   * Look up the `ExpressionInfo` of the given table function by name.
+   * Searches through extension, built-in, and temp table functions in that order.
+   *
+   * Resolution order: extension -> builtin -> session (temp)
+   */
+  def lookupBuiltinOrTempTableFunction(name: String): Option[ExpressionInfo] = {
+    lookupFunctionWithShadowing(name, tableFunctionRegistry, checkBuiltinOperators = false)
+  }
+
+  /**
+   * Look up only a builtin function (no temp).
+   */
+  def lookupBuiltinFunction(name: String): Option[ExpressionInfo] = {
+    val builtinIdentifier = FunctionIdentifier(format(name))
+    functionRegistry.lookupFunction(builtinIdentifier)
+  }
+
+  /**
+   * Look up only a builtin table function (no temp).
+   */
+  def lookupBuiltinTableFunction(name: String): Option[ExpressionInfo] = {
+    val builtinIdentifier = FunctionIdentifier(format(name))
+    tableFunctionRegistry.lookupFunction(builtinIdentifier)
+  }
+
+  /**
+   * Look up only a temp function (no builtin).
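Since handleViewContext now backs every session-namespace lookup above, here is its contract in isolation — a self-contained mirror in which the mutable set stands in for AnalysisContext.referredTempFunctionNames:

```scala
def filterForViewContext[T](
    name: String,
    result: Option[T],
    isResolvingView: Boolean,
    referred: scala.collection.mutable.Set[String]): Option[T] =
  result.filter { _ =>
    if (isResolvingView) {
      referred.contains(name) // a view only sees temp functions it captured at creation
    } else {
      referred.add(name)      // outside views: record the reference for later view creation
      true
    }
  }

val referred = scala.collection.mutable.Set.empty[String]
assert(filterForViewContext("f", Some(1), isResolvingView = false, referred).isDefined)
assert(referred.contains("f")) // reference recorded
assert(filterForViewContext("g", Some(1), isResolvingView = true, referred).isEmpty)
```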
+ */ + def lookupTempFunction(name: String): Option[ExpressionInfo] = { + val tempIdentifier = tempFunctionIdentifier(name) + synchronized(lookupTempFuncWithViewContext( + name, + // Return false if temp exists (not builtin) + ident => !functionRegistry.functionExists(tempIdentifier), + _ => functionRegistry.lookupFunction(tempIdentifier))) + } + + /** + * Look up only temp table function (no builtin). + */ + def lookupTempTableFunction(name: String): Option[ExpressionInfo] = { + val tempIdentifier = tempFunctionIdentifier(name) + if (tableFunctionRegistry.functionExists(tempIdentifier)) { + tableFunctionRegistry.lookupFunction(tempIdentifier) + } else { + None } } /** - * Look up the `ExpressionInfo` of the given function by name if it's a built-in or - * temp table function. + * Look up only extension function (no builtin or temp). */ - def lookupBuiltinOrTempTableFunction(name: String): Option[ExpressionInfo] = synchronized { - lookupTempFuncWithViewContext( - name, TableFunctionRegistry.builtin.functionExists, tableFunctionRegistry.lookupFunction) + def lookupExtensionFunction(name: String): Option[ExpressionInfo] = { + val extensionIdentifier = extensionFunctionIdentifier(name) + functionRegistry.lookupFunction(extensionIdentifier) } /** - * Look up a built-in or temp scalar function by name and resolves it to an Expression if such - * a function exists. + * Look up only extension table function (no builtin or temp). */ - def resolveBuiltinOrTempFunction(name: String, arguments: Seq[Expression]): Option[Expression] = { - resolveBuiltinOrTempFunctionInternal( - name, arguments, FunctionRegistry.builtin.functionExists, functionRegistry) + def lookupExtensionTableFunction(name: String): Option[ExpressionInfo] = { + val extensionIdentifier = extensionFunctionIdentifier(name) + tableFunctionRegistry.lookupFunction(extensionIdentifier) } /** - * Look up a built-in or temp table function by name and resolves it to a LogicalPlan if such - * a function exists. + * Resolve only builtin function. */ - def resolveBuiltinOrTempTableFunction( - name: String, arguments: Seq[Expression]): Option[LogicalPlan] = { - resolveBuiltinOrTempFunctionInternal( - name, arguments, TableFunctionRegistry.builtin.functionExists, tableFunctionRegistry) + def resolveBuiltinFunction(name: String, arguments: Seq[Expression]): Option[Expression] = { + val builtinIdentifier = FunctionIdentifier(format(name)) + if (functionRegistry.functionExists(builtinIdentifier)) { + Option(functionRegistry.lookupFunction(builtinIdentifier, arguments)) + } else { + None + } } - private def resolveBuiltinOrTempFunctionInternal[T]( + /** + * Resolve only temp function. + */ + def resolveTempFunction(name: String, arguments: Seq[Expression]): Option[Expression] = { + val tempIdentifier = tempFunctionIdentifier(name) + synchronized { + if (functionRegistry.functionExists(tempIdentifier)) { + lookupTempFuncWithViewContext( + name, + // Return false if temp exists (not builtin) + ident => !functionRegistry.functionExists(tempIdentifier), + _ => Option(functionRegistry.lookupFunction(tempIdentifier, arguments))) + } else { + None + } + } + } + + /** + * Resolve only extension function. 
+ */ + def resolveExtensionFunction(name: String, arguments: Seq[Expression]): Option[Expression] = { + val extensionIdentifier = extensionFunctionIdentifier(name) + if (functionRegistry.functionExists(extensionIdentifier)) { + Option(functionRegistry.lookupFunction(extensionIdentifier, arguments)) + } else { + None + } + } + + /** + * Look up a scalar function by name and resolve it to an Expression. + * Searches through extension, built-in, and temp functions in that order. + * + * Resolution order: extension -> builtin -> session (temp) + */ + def resolveBuiltinOrTempFunction(name: String, arguments: Seq[Expression]): Option[Expression] = + resolveFunctionWithFallback(name, arguments, functionRegistry) + + /** + * Resolve only builtin table function. + */ + def resolveBuiltinTableFunction( name: String, - arguments: Seq[Expression], - isBuiltin: FunctionIdentifier => Boolean, - registry: FunctionRegistryBase[T]): Option[T] = synchronized { - val funcIdent = FunctionIdentifier(name) - if (!registry.functionExists(funcIdent)) { + arguments: Seq[Expression]): Option[LogicalPlan] = { + val builtinIdentifier = FunctionIdentifier(format(name)) + if (tableFunctionRegistry.functionExists(builtinIdentifier)) { + Option(tableFunctionRegistry.lookupFunction(builtinIdentifier, arguments)) + } else { None + } + } + + /** + * Resolve only temp table function. + */ + def resolveTempTableFunction( + name: String, + arguments: Seq[Expression]): Option[LogicalPlan] = { + val tempIdentifier = tempFunctionIdentifier(name) + synchronized { + if (tableFunctionRegistry.functionExists(tempIdentifier)) { + lookupTempFuncWithViewContext( + name, + // Return false if temp exists (not builtin) + ident => !tableFunctionRegistry.functionExists(tempIdentifier), + _ => Option(tableFunctionRegistry.lookupFunction(tempIdentifier, arguments))) + } else { + None + } + } + } + + /** + * Resolve only extension table function. + */ + def resolveExtensionTableFunction( + name: String, + arguments: Seq[Expression]): Option[LogicalPlan] = { + val extensionIdentifier = extensionFunctionIdentifier(name) + if (tableFunctionRegistry.functionExists(extensionIdentifier)) { + Option(tableFunctionRegistry.lookupFunction(extensionIdentifier, arguments)) } else { - lookupTempFuncWithViewContext( - name, isBuiltin, ident => Option(registry.lookupFunction(ident, arguments))) + None } } + /** + * Look up a table function by name and resolve it to a LogicalPlan. + * Searches through extension, built-in, and temp functions in that order. + * + * Resolution order: extension -> builtin -> session (temp) + */ + def resolveBuiltinOrTempTableFunction( + name: String, + arguments: Seq[Expression]): Option[LogicalPlan] = + resolveFunctionWithFallback(name, arguments, tableFunctionRegistry) + + /** + * Resolves functions using PATH-based resolution. + * Searches through the resolution path, returning the first function found. + * + * @param name The function name (unqualified). + * @param arguments The arguments to pass to the function. + * @param registry The registry to search (FunctionRegistry or TableFunctionRegistry). + * @tparam T The registry's type parameter (Expression for FunctionRegistry, + * LogicalPlan for TableFunctionRegistry). + * @return Resolved function if found, None otherwise. + */ + private def resolveFunctionWithFallback[T]( + name: String, + arguments: Seq[Expression], + registry: FunctionRegistryBase[T]): Option[T] = { + + // Use PATH-based resolution: iterate through namespaces until a match is found. 
+ val path = resolutionPath() + + // Use iterator for short-circuit evaluation (stops at first match). + path.iterator.flatMap { namespace => + resolveInNamespace(namespace, name, arguments, registry) + }.nextOption() + } + + /** + * Looks up a temporary function with view context handling. + * Used by legacy code paths that need explicit control over the isBuiltin check. + * + * @param name The function name. + * @param isBuiltin Function to check if identifier is builtin (skip view context if true). + * @param lookupFunc Function to perform the actual lookup. + * @tparam T The result type. + * @return The lookup result with view context applied. + */ private def lookupTempFuncWithViewContext[T]( name: String, isBuiltin: FunctionIdentifier => Boolean, lookupFunc: FunctionIdentifier => Option[T]): Option[T] = { val funcIdent = FunctionIdentifier(name) if (isBuiltin(funcIdent)) { + // Builtin functions are not subject to view context restrictions lookupFunc(funcIdent) } else { - val isResolvingView = AnalysisContext.get.catalogAndNamespace.nonEmpty - val referredTempFunctionNames = AnalysisContext.get.referredTempFunctionNames - if (isResolvingView) { - // When resolving a view, only return a temp function if it's referred by this view. - if (referredTempFunctionNames.contains(name)) { - lookupFunc(funcIdent) - } else { - None - } - } else { - val result = lookupFunc(funcIdent) - if (result.isDefined) { - // We are not resolving a view and the function is a temp one, add it to - // `AnalysisContext`, so during the view creation, we can save all referred temp - // functions to view metadata. - AnalysisContext.get.referredTempFunctionNames.add(name) - } - result - } + // Temp functions must respect view context + handleViewContext(name, lookupFunc(funcIdent)) } } @@ -2196,7 +2718,7 @@ class SessionCatalog( val funcMetadata = fetchCatalogFunction(qualifiedIdent) if (funcMetadata.isUserDefinedFunction) { UserDefinedFunction.fromCatalogFunction(funcMetadata, parser).toExpressionInfo - } else { + } else { makeExprInfoForHiveFunction(funcMetadata) } } @@ -2237,15 +2759,16 @@ class SessionCatalog( registerUserDefinedFunction[Expression]( udf, overrideIfExists = false, - functionRegistry, + functionRegistry, makeUserDefinedScalarFuncBuilder(udf)) } else { loadFunctionResources(funcMetadata.resources) - registerFunction( + registerFunction( funcMetadata, - overrideIfExists = false, + overrideIfExists = false, functionRegistry, - makeFunctionBuilder(funcMetadata)) + makeFunctionBuilder(funcMetadata), + useCompositeKey = false) // Persistent functions don't use composite keys } functionRegistry.lookupFunctionBuilder(qualifiedIdent).get } @@ -2273,9 +2796,9 @@ class SessionCatalog( failFunctionLookup(qualifiedIdent) } val udf = UserDefinedFunction.fromCatalogFunction(funcMetadata, parser) - registerUserDefinedFunction[LogicalPlan]( + registerUserDefinedFunction[LogicalPlan]( udf, - overrideIfExists = false, + overrideIfExists = false, tableFunctionRegistry, makeUserDefinedTableFuncBuilder(udf)) tableFunctionRegistry.lookupFunction(qualifiedIdent, arguments) @@ -2286,8 +2809,8 @@ class SessionCatalog( * Fetch a catalog function from the external catalog. */ private def fetchCatalogFunction(qualifiedIdent: FunctionIdentifier): CatalogFunction = { - val db = qualifiedIdent.database.get - val funcName = qualifiedIdent.funcName + val db = qualifiedIdent.database.get + val funcName = qualifiedIdent.funcName requireDbExists(db) try { // Please note that qualifiedIdent is provided by the user. 
However, @@ -2296,7 +2819,7 @@ class SessionCatalog( // catalogFunction.identifier (difference is on case-sensitivity). // At here, we preserve the input from the user. externalCatalog.getFunction(db, funcName).copy(identifier = qualifiedIdent) - } catch { + } catch { case _: NoSuchPermanentFunctionException | _: NoSuchFunctionException => failFunctionLookup(qualifiedIdent) } @@ -2336,11 +2859,20 @@ class SessionCatalog( } /** - * List all built-in and temporary functions with the given pattern. + * List all built-in, extension, and temporary functions with the given pattern. */ private def listBuiltinAndTempFunctions(pattern: String): Seq[FunctionIdentifier] = { val functions = (functionRegistry.listFunction() ++ tableFunctionRegistry.listFunction()) - .filter(_.database.isEmpty) + .filter(f => + f.database.isEmpty || + isTempFunctionIdentifier(f) || + isExtensionFunctionIdentifier(f)) + .map(f => if (isTempFunctionIdentifier(f) || isExtensionFunctionIdentifier(f)) { + // Strip namespace qualifier for temp and extension functions + FunctionIdentifier(f.funcName) + } else { + f + }) StringUtils.filterPattern(functions.map(_.unquotedString), pattern).map { f => // In functionRegistry, function names are stored as an unquoted format. Try(parser.parseFunctionIdentifier(f)) match { @@ -2386,7 +2918,9 @@ class SessionCatalog( */ def listTemporaryFunctions(): Seq[FunctionIdentifier] = { (functionRegistry.listFunction() ++ tableFunctionRegistry.listFunction()) - .filter(isTemporaryFunction) + .filter(isTempFunctionIdentifier) + // Strip the TEMP_FUNCTION_DB database qualifier + .map(ident => FunctionIdentifier(ident.funcName)) } // ----------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index f4b9b06e471cb..1f5333c5deba1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -875,12 +875,9 @@ class AstBuilder extends DataTypeAstBuilder /** * Add an * {{{ - * INSERT [WITH SCHEMA EVOLUTION] OVERWRITE - * TABLE tableIdentifier [partitionSpec [IF NOT EXISTS]]? [identifierList] - * INSERT [WITH SCHEMA EVOLUTION] INTO - * [TABLE] tableIdentifier [partitionSpec] ([BY NAME] | [identifierList]) - * INSERT [WITH SCHEMA EVOLUTION] INTO - * [TABLE] tableIdentifier REPLACE whereClause + * INSERT OVERWRITE TABLE tableIdentifier [partitionSpec [IF NOT EXISTS]]? 
[identifierList] + * INSERT INTO [TABLE] tableIdentifier [partitionSpec] ([BY NAME] | [identifierList]) + * INSERT INTO [TABLE] tableIdentifier REPLACE whereClause * INSERT OVERWRITE [LOCAL] DIRECTORY STRING [rowFormat] [createFileFormat] * INSERT OVERWRITE [LOCAL] DIRECTORY [STRING] tableProvider [OPTIONS tablePropertyList] * }}} @@ -909,8 +906,7 @@ class AstBuilder extends DataTypeAstBuilder query = otherPlans.head, overwrite = false, ifPartitionNotExists = insertParams.ifPartitionNotExists, - byName = insertParams.byName, - withSchemaEvolution = table.EVOLUTION() != null) + byName = insertParams.byName) }) case table: InsertOverwriteTableContext => val insertParams = visitInsertOverwriteTable(table) @@ -927,8 +923,7 @@ class AstBuilder extends DataTypeAstBuilder query = otherPlans.head, overwrite = true, ifPartitionNotExists = insertParams.ifPartitionNotExists, - byName = insertParams.byName, - withSchemaEvolution = table.EVOLUTION() != null) + byName = insertParams.byName) }) case ctx: InsertIntoReplaceWhereContext => val options = Option(ctx.optionsClause()) @@ -937,20 +932,10 @@ class AstBuilder extends DataTypeAstBuilder Seq(TableWritePrivilege.INSERT, TableWritePrivilege.DELETE), isStreaming = false) val deleteExpr = expression(ctx.whereClause().booleanExpression()) val isByName = ctx.NAME() != null - val schemaEvolutionWriteOption: Map[String, String] = - if (ctx.EVOLUTION() != null) Map("mergeSchema" -> "true") else Map.empty if (isByName) { - OverwriteByExpression.byName( - table, - df = otherPlans.head, - deleteExpr, - writeOptions = schemaEvolutionWriteOption) + OverwriteByExpression.byName(table, otherPlans.head, deleteExpr) } else { - OverwriteByExpression.byPosition( - table, - query = otherPlans.head, - deleteExpr, - writeOptions = schemaEvolutionWriteOption) + OverwriteByExpression.byPosition(table, otherPlans.head, deleteExpr) } }) case dir: InsertOverwriteDirContext => @@ -2447,12 +2432,8 @@ class AstBuilder extends DataTypeAstBuilder funcCallCtx.funcName, Nil, (ident, _) => { - if (ident.length > 1) { - throw new ParseException( - errorClass = "INVALID_SQL_SYNTAX.INVALID_TABLE_VALUED_FUNC_NAME", - messageParameters = Map("funcName" -> toSQLId(ident)), - ctx = funcCallCtx) - } + // Allow qualified table-valued function names (e.g., builtin.range, system.session.my_tvf) + // Removed artificial restriction on multi-part identifiers val funcName = funcCallCtx.funcName.getText val args = funcCallCtx.functionTableArgument.asScala.map { e => Option(e.functionArgument).map(extractNamedArgument(_, funcName)) @@ -2494,62 +2475,19 @@ class AstBuilder extends DataTypeAstBuilder buildTvfFromTableFunctionCall(ctx.tableFunctionCall, ctx.tableAlias, ctx.watermarkClause) } - /** - * Extract the source name from an identifiedByClause context. 
- */ - private def extractSourceName(ctx: IdentifiedByClauseContext): Option[String] = { - Option(ctx).map(c => c.sourceName.identifier.getText) - } - override def visitStreamTableName(ctx: StreamTableNameContext): LogicalPlan = { val ident = visitMultipartIdentifier(ctx.multipartIdentifier) - val relation = createUnresolvedRelation( + val tableStreamingRelation = createUnresolvedRelation( ctx = ctx, ident = ident, optionsClause = Option(ctx.optionsClause), writePrivileges = Seq.empty, isStreaming = true) - val table = mayApplyAliasPlan(ctx.tableAlias, relation) - val tableWithWatermark = table.optionalMap(ctx.watermarkClause)(withWatermark) - val sourceNameOpt = extractSourceName(ctx.identifiedByClause) - tableWithWatermark.transformUp { - case r: UnresolvedRelation => - NamedStreamingRelation.withUserProvidedName( - r.copy(isStreaming = true), sourceNameOpt) - } + val tableWithWatermark = tableStreamingRelation.optionalMap(ctx.watermarkClause)(withWatermark) + mayApplyAliasPlan(ctx.tableAlias, tableWithWatermark) } - /** - * Create a logical plan for a stream TVF. - * Handles two forms: - * 1. STREAM tableFunctionCallWithTrailingClauses - clauses are inside - * 2. STREAM(tableFunctionCall) clauses - clauses are outside STREAM() for consistency with - * table names - */ - override def visitStreamTableValuedFunction(ctx: StreamTableValuedFunctionContext): LogicalPlan = - withOrigin(ctx) { - Option(ctx.tableFunctionCallWithTrailingClauses).map { funcTable => - // Form: STREAM tableFunctionCallWithTrailingClauses - val sourceName = extractSourceName(funcTable.identifiedByClause) - val tvfPlan = buildTvfFromTableFunctionCall( - funcTable.tableFunctionCall, funcTable.tableAlias, funcTable.watermarkClause) - tvfPlan.transformUp { - case tvf: UnresolvedTableValuedFunction => - NamedStreamingRelation.withUserProvidedName(tvf.copy(isStreaming = true), sourceName) - } - }.getOrElse { - // Form: STREAM(tableFunctionCall) identifiedByClause? watermarkClause? tableAlias - val sourceName = extractSourceName(ctx.identifiedByClause) - val tvfPlan = buildTvfFromTableFunctionCall( - ctx.tableFunctionCall, ctx.tableAlias, ctx.watermarkClause) - tvfPlan.transformUp { - case tvf: UnresolvedTableValuedFunction => - NamedStreamingRelation.withUserProvidedName(tvf.copy(isStreaming = true), sourceName) - } - } - } - /** * Create an inline table (a virtual table in Hive parlance). 
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala index d59ef5875cab9..d53a4d1c9767b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala @@ -162,5 +162,6 @@ private[sql] object CatalogManager { val SESSION_CATALOG_NAME: String = "spark_catalog" val SYSTEM_CATALOG_NAME = "system" val SESSION_NAMESPACE = "session" + val EXTENSION_NAMESPACE = "extension" val BUILTIN_NAMESPACE = "builtin" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index c0e6782c563b8..490e2873d347c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -685,8 +685,8 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat def singleTableStarInCountNotAllowedError(targetString: String): Throwable = { new AnalysisException( - errorClass = "_LEGACY_ERROR_TEMP_1021", - messageParameters = Map("targetString" -> targetString)) + errorClass = "INVALID_USAGE_OF_STAR_WITH_TABLE_IDENTIFIER_IN_COUNT", + messageParameters = Map("tableName" -> targetString)) } def orderByPositionRangeError(index: Int, size: Int, t: TreeNode[_]): Throwable = { @@ -942,6 +942,21 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat origin = context) } + def notAScalarFunctionError( + functionName: String, + u: TreeNode[_]): Throwable = { + new AnalysisException( + errorClass = "NOT_A_SCALAR_FUNCTION", + messageParameters = Map("functionName" -> toSQLId(functionName)), + origin = u.origin) + } + + def notATableFunctionError(functionName: String): Throwable = { + new AnalysisException( + errorClass = "NOT_A_TABLE_FUNCTION", + messageParameters = Map("functionName" -> toSQLId(functionName))) + } + def wrongNumArgsError( name: String, validParametersCount: Seq[Any], @@ -3058,6 +3073,18 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat messageParameters = Map("functionName" -> functionName)) } + def invalidTempObjQualifierError( + objectType: String, + objectName: String, + qualifier: String): Throwable = { + new AnalysisException( + errorClass = "INVALID_TEMP_OBJ_QUALIFIER", + messageParameters = Map( + "objectType" -> objectType, + "objectName" -> toSQLId(objectName), + "qualifier" -> toSQLId(qualifier))) + } + def cannotRefreshBuiltInFuncError(functionName: String, t: TreeNode[_]): Throwable = { new AnalysisException( errorClass = "_LEGACY_ERROR_TEMP_1256", diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala index 39524e60862d1..09a2035b6d620 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.classic.Strategy +import org.apache.spark.sql.connector.catalog.CatalogManager import 
org.apache.spark.sql.execution.{ColumnarRule, SparkPlan} /** @@ -365,14 +366,31 @@ class SparkSessionExtensions { private[sql] def registerFunctions(functionRegistry: FunctionRegistry) = { for ((name, expressionInfo, function) <- injectedFunctions) { - functionRegistry.registerFunction(name, expressionInfo, function) + // Extension functions use EXTENSION_NAMESPACE to: + // 1. Allow trusted extensions (Sedona, Delta) to shadow/augment built-ins + // 2. Ensure extension functions resolve BEFORE built-ins (extension → builtin → session) + // 3. Prevent untrusted user temp functions from shadowing security-critical built-ins + // (builtin resolves before session, blocking attacks like shadowing current_user()) + val extensionQualifiedName = if (name.database.isEmpty) { + FunctionIdentifier(name.funcName, Some(CatalogManager.EXTENSION_NAMESPACE)) + } else { + name + } + functionRegistry.registerFunction(extensionQualifiedName, expressionInfo, function) } functionRegistry } private[sql] def registerTableFunctions(tableFunctionRegistry: TableFunctionRegistry) = { for ((name, expressionInfo, function) <- injectedTableFunctions) { - tableFunctionRegistry.registerFunction(name, expressionInfo, function) + // Extension table functions use EXTENSION_NAMESPACE for the same reasons as scalar functions: + // resolution order extension → builtin → session protects security-critical built-ins + val extensionQualifiedName = if (name.database.isEmpty) { + FunctionIdentifier(name.funcName, Some(CatalogManager.EXTENSION_NAMESPACE)) + } else { + name + } + tableFunctionRegistry.registerFunction(extensionQualifiedName, expressionInfo, function) } tableFunctionRegistry } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 2dd88eeeb1c32..412b90c6c7a36 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} import org.apache.spark.sql.catalyst.util.DateTimeConstants +import org.apache.spark.sql.connector.catalog.CatalogManager import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryParsingErrors} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources._ @@ -177,6 +178,51 @@ class SparkSqlAstBuilder extends AstBuilder { private val configValueDef = """([^;]*);*""".r private val strLiteralDef = """(".*?[^\\]"|'.*?[^\\]'|[^ \n\r\t"']+)""".r + /** + * Extract the actual function name from a potentially qualified temporary function name. + * Supports: funcName, session.funcName, system.session.funcName + * Throws an error for any other qualification pattern. 
+ */ + private def extractTempFunctionName( + functionIdentifier: Seq[String], + ctx: ParserRuleContext, + forDrop: Boolean = false): String = { + functionIdentifier.length match { + case 1 => + // Simple unqualified name + functionIdentifier.head + case 2 => + // Check if it's session.funcName + if (functionIdentifier.head.equalsIgnoreCase(CatalogManager.SESSION_NAMESPACE)) { + functionIdentifier.last + } else { + // Otherwise it's an invalid qualifier (e.g., database name) + val funcName = functionIdentifier.last + val qualifier = functionIdentifier.head + throw QueryCompilationErrors.invalidTempObjQualifierError( + "FUNCTION", funcName, qualifier) + } + case 3 => + // Check if it's system.session.funcName + if (functionIdentifier(0).equalsIgnoreCase(CatalogManager.SYSTEM_CATALOG_NAME) && + functionIdentifier(1).equalsIgnoreCase(CatalogManager.SESSION_NAMESPACE)) { + functionIdentifier.last + } else { + // Invalid three-part qualifier + val funcName = functionIdentifier.last + val qualifier = functionIdentifier.init.mkString(".") + throw QueryCompilationErrors.invalidTempObjQualifierError( + "FUNCTION", funcName, qualifier) + } + case _ => + // More than 3 parts - invalid + val funcName = functionIdentifier.last + val qualifier = functionIdentifier.init.mkString(".") + throw QueryCompilationErrors.invalidTempObjQualifierError( + "FUNCTION", funcName, qualifier) + } + } + private def withCatalogIdentClause( ctx: CatalogIdentifierReferenceContext, builder: Seq[String] => LogicalPlan): LogicalPlan = { @@ -813,14 +859,10 @@ class SparkSqlAstBuilder extends AstBuilder { throw QueryParsingErrors.defineTempFuncWithIfNotExistsError(ctx) } - if (functionIdentifier.length > 2) { - throw QueryParsingErrors.unsupportedFunctionNameError(functionIdentifier, ctx) - } else if (functionIdentifier.length == 2) { - // Temporary function names should not contain database prefix like "database.function" - throw QueryParsingErrors.specifyingDBInCreateTempFuncError(functionIdentifier.head, ctx) - } + // Extract the actual function name, handling session qualification + val funcName = extractTempFunctionName(functionIdentifier, ctx) CreateFunctionCommand( - FunctionIdentifier(functionIdentifier.last), + FunctionIdentifier(funcName), string(visitStringLit(ctx.className)), resources.toSeq, true, @@ -907,15 +949,10 @@ class SparkSqlAstBuilder extends AstBuilder { throw QueryParsingErrors.defineTempFuncWithIfNotExistsError(ctx) } - if (functionIdentifier.length > 2) { - throw QueryParsingErrors.unsupportedFunctionNameError(functionIdentifier, ctx) - } else if (functionIdentifier.length == 2) { - // Temporary function names should not contain database prefix like "database.function" - throw QueryParsingErrors.specifyingDBInCreateTempFuncError(functionIdentifier.head, ctx) - } - + // Extract the actual function name, handling session qualification + val funcName = extractTempFunctionName(functionIdentifier, ctx) CreateUserDefinedFunctionCommand( - functionIdentifier.asFunctionIdentifier, + FunctionIdentifier(funcName), inputParamText, returnTypeText, exprText, @@ -1035,11 +1072,10 @@ class SparkSqlAstBuilder extends AstBuilder { val identCtx = ctx.identifierReference() if (isTemp) { withIdentClause(identCtx, functionName => { - if (functionName.length > 1) { - throw QueryParsingErrors.invalidNameForDropTempFunc(functionName, ctx) - } + // Extract the actual function name, handling session qualification + val funcName = extractTempFunctionName(functionName, ctx, forDrop = true) DropFunctionCommand( - identifier = 
FunctionIdentifier(functionName.head), + identifier = FunctionIdentifier(funcName), ifExists = ctx.EXISTS != null, isTemp = true) }) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala index b190d91df588b..21970fd6ba270 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala @@ -24,7 +24,8 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, FunctionResource} import org.apache.spark.sql.catalyst.expressions.{Attribute, ExpressionInfo} import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.StringUtils -import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.connector.catalog.CatalogManager +import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.types.{StringType, StructField, StructType} @@ -139,11 +140,27 @@ case class DropFunctionCommand( override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog if (isTemp) { - assert(identifier.database.isEmpty) - if (FunctionRegistry.builtin.functionExists(identifier)) { - throw QueryCompilationErrors.cannotDropBuiltinFuncError(identifier.funcName) + // Extract the function name, handling qualified names like "system.session.func" + val funcName = if (identifier.database.isDefined) { + // Qualified name - validate it's a valid temporary function namespace (case-insensitive) + val db = identifier.database.get + if (!db.equalsIgnoreCase(CatalogManager.SESSION_NAMESPACE)) { + throw QueryExecutionErrors.invalidNamespaceNameError( + Array(CatalogManager.SYSTEM_CATALOG_NAME, db)) + } + identifier.funcName + } else { + identifier.funcName + } + + // Check if temp function exists first - if it does, allow dropping it even if a builtin + // with the same name exists (shadowing case) + val unqualifiedIdent = FunctionIdentifier(funcName) + if (!catalog.isTemporaryFunction(unqualifiedIdent) && + catalog.isBuiltinFunction(unqualifiedIdent)) { + throw QueryCompilationErrors.cannotDropBuiltinFuncError(funcName) } - catalog.dropTempFunction(identifier.funcName, ifExists) + catalog.dropTempFunction(funcName, ifExists) } else { // We are dropping a permanent function. 
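One subtlety worth spelling out in the temp branch above: the temp-existence check deliberately runs before the built-in check, so a temporary function that shadows a built-in name remains droppable. The resulting behavior, restated as a decision table derived from that code:

    // DROP TEMPORARY FUNCTION <name>, temp branch:
    //   temp exists? | builtin exists? | outcome
    //   -------------+-----------------+--------------------------------------
    //   yes          | yes             | drop the temp (shadowing case)
    //   yes          | no              | drop the temp
    //   no           | yes             | cannotDropBuiltinFuncError
    //   no           | no              | dropTempFunction(name, ifExists)
    //                |                 | decides whether missing is an error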
catalog.dropFunction(identifier, ignoreIfNotExists = ifExists) diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/count.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/count.sql.out index 732b714615792..d142fef725901 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/count.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/count.sql.out @@ -223,8 +223,9 @@ SELECT count(testData.*) FROM testData -- !query analysis org.apache.spark.sql.AnalysisException { - "errorClass" : "_LEGACY_ERROR_TEMP_1021", + "errorClass" : "INVALID_USAGE_OF_STAR_WITH_TABLE_IDENTIFIER_IN_COUNT", + "sqlState" : "42000", "messageParameters" : { - "targetString" : "testData" + "tableName" : "testData" } } diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause-legacy.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause-legacy.sql.out index f0a7722886ed8..d062b95c8f2e8 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause-legacy.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause-legacy.sql.out @@ -991,41 +991,30 @@ org.apache.spark.sql.AnalysisException -- !query CREATE TEMPORARY FUNCTION IDENTIFIER('default.my' || 'DoubleAvg') AS 'test.org.apache.spark.sql.MyDoubleAvg' -- !query analysis -org.apache.spark.sql.catalyst.parser.ParseException +org.apache.spark.sql.AnalysisException { - "errorClass" : "INVALID_SQL_SYNTAX.CREATE_TEMP_FUNC_WITH_DATABASE", - "sqlState" : "42000", + "errorClass" : "INVALID_TEMP_OBJ_QUALIFIER", + "sqlState" : "42602", "messageParameters" : { - "database" : "`default`" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 1, - "stopIndex" : 108, - "fragment" : "CREATE TEMPORARY FUNCTION IDENTIFIER('default.my' || 'DoubleAvg') AS 'test.org.apache.spark.sql.MyDoubleAvg'" - } ] + "objectName" : "`myDoubleAvg`", + "objectType" : "FUNCTION", + "qualifier" : "`default`" + } } -- !query DROP TEMPORARY FUNCTION IDENTIFIER('default.my' || 'DoubleAvg') -- !query analysis -org.apache.spark.sql.catalyst.parser.ParseException +org.apache.spark.sql.AnalysisException { - "errorClass" : "INVALID_SQL_SYNTAX.MULTI_PART_NAME", - "sqlState" : "42000", + "errorClass" : "INVALID_TEMP_OBJ_QUALIFIER", + "sqlState" : "42602", "messageParameters" : { - "name" : "`default`.`myDoubleAvg`", - "statement" : "DROP TEMPORARY FUNCTION" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 1, - "stopIndex" : 63, - "fragment" : "DROP TEMPORARY FUNCTION IDENTIFIER('default.my' || 'DoubleAvg')" - } ] + "objectName" : "`myDoubleAvg`", + "objectType" : "FUNCTION", + "qualifier" : "`default`" + } } diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause.sql.out index 00740529b8a87..7e403dbf31b6c 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause.sql.out @@ -991,41 +991,30 @@ org.apache.spark.sql.AnalysisException -- !query CREATE TEMPORARY FUNCTION IDENTIFIER('default.my' || 'DoubleAvg') AS 'test.org.apache.spark.sql.MyDoubleAvg' -- !query analysis -org.apache.spark.sql.catalyst.parser.ParseException +org.apache.spark.sql.AnalysisException { - "errorClass" : "INVALID_SQL_SYNTAX.CREATE_TEMP_FUNC_WITH_DATABASE", - "sqlState" : "42000", + 
"errorClass" : "INVALID_TEMP_OBJ_QUALIFIER", + "sqlState" : "42602", "messageParameters" : { - "database" : "`default`" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 1, - "stopIndex" : 108, - "fragment" : "CREATE TEMPORARY FUNCTION IDENTIFIER('default.my' || 'DoubleAvg') AS 'test.org.apache.spark.sql.MyDoubleAvg'" - } ] + "objectName" : "`myDoubleAvg`", + "objectType" : "FUNCTION", + "qualifier" : "`default`" + } } -- !query DROP TEMPORARY FUNCTION IDENTIFIER('default.my' || 'DoubleAvg') -- !query analysis -org.apache.spark.sql.catalyst.parser.ParseException +org.apache.spark.sql.AnalysisException { - "errorClass" : "INVALID_SQL_SYNTAX.MULTI_PART_NAME", - "sqlState" : "42000", + "errorClass" : "INVALID_TEMP_OBJ_QUALIFIER", + "sqlState" : "42602", "messageParameters" : { - "name" : "`default`.`myDoubleAvg`", - "statement" : "DROP TEMPORARY FUNCTION" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 1, - "stopIndex" : 63, - "fragment" : "DROP TEMPORARY FUNCTION IDENTIFIER('default.my' || 'DoubleAvg')" - } ] + "objectName" : "`myDoubleAvg`", + "objectType" : "FUNCTION", + "qualifier" : "`default`" + } } diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/window_part3.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/window_part3.sql.out index bc6dc828ad857..629058686a33d 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/window_part3.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/window_part3.sql.out @@ -410,11 +410,10 @@ SELECT range(1, 100) OVER () FROM empsalary -- !query analysis org.apache.spark.sql.AnalysisException { - "errorClass" : "UNRESOLVED_ROUTINE", - "sqlState" : "42883", + "errorClass" : "NOT_A_SCALAR_FUNCTION", + "sqlState" : "42887", "messageParameters" : { - "routineName" : "`range`", - "searchPath" : "[`system`.`builtin`, `system`.`session`, `spark_catalog`.`default`]" + "functionName" : "`range`" }, "queryContext" : [ { "objectType" : "", diff --git a/sql/core/src/test/resources/sql-tests/results/count.sql.out b/sql/core/src/test/resources/sql-tests/results/count.sql.out index 0420922799299..f2ad839cd51fd 100644 --- a/sql/core/src/test/resources/sql-tests/results/count.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/count.sql.out @@ -191,8 +191,9 @@ struct<> -- !query output org.apache.spark.sql.AnalysisException { - "errorClass" : "_LEGACY_ERROR_TEMP_1021", + "errorClass" : "INVALID_USAGE_OF_STAR_WITH_TABLE_IDENTIFIER_IN_COUNT", + "sqlState" : "42000", "messageParameters" : { - "targetString" : "testData" + "tableName" : "testData" } } diff --git a/sql/core/src/test/resources/sql-tests/results/identifier-clause-legacy.sql.out b/sql/core/src/test/resources/sql-tests/results/identifier-clause-legacy.sql.out index 13a4b43fd0589..7119e0916a3d4 100644 --- a/sql/core/src/test/resources/sql-tests/results/identifier-clause-legacy.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/identifier-clause-legacy.sql.out @@ -1133,20 +1133,15 @@ CREATE TEMPORARY FUNCTION IDENTIFIER('default.my' || 'DoubleAvg') AS 'test.org.a -- !query schema struct<> -- !query output -org.apache.spark.sql.catalyst.parser.ParseException +org.apache.spark.sql.AnalysisException { - "errorClass" : "INVALID_SQL_SYNTAX.CREATE_TEMP_FUNC_WITH_DATABASE", - "sqlState" : "42000", + "errorClass" : "INVALID_TEMP_OBJ_QUALIFIER", + "sqlState" : "42602", "messageParameters" : { - "database" : "`default`" - }, - 
"queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 1, - "stopIndex" : 108, - "fragment" : "CREATE TEMPORARY FUNCTION IDENTIFIER('default.my' || 'DoubleAvg') AS 'test.org.apache.spark.sql.MyDoubleAvg'" - } ] + "objectName" : "`myDoubleAvg`", + "objectType" : "FUNCTION", + "qualifier" : "`default`" + } } @@ -1155,21 +1150,15 @@ DROP TEMPORARY FUNCTION IDENTIFIER('default.my' || 'DoubleAvg') -- !query schema struct<> -- !query output -org.apache.spark.sql.catalyst.parser.ParseException +org.apache.spark.sql.AnalysisException { - "errorClass" : "INVALID_SQL_SYNTAX.MULTI_PART_NAME", - "sqlState" : "42000", + "errorClass" : "INVALID_TEMP_OBJ_QUALIFIER", + "sqlState" : "42602", "messageParameters" : { - "name" : "`default`.`myDoubleAvg`", - "statement" : "DROP TEMPORARY FUNCTION" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 1, - "stopIndex" : 63, - "fragment" : "DROP TEMPORARY FUNCTION IDENTIFIER('default.my' || 'DoubleAvg')" - } ] + "objectName" : "`myDoubleAvg`", + "objectType" : "FUNCTION", + "qualifier" : "`default`" + } } diff --git a/sql/core/src/test/resources/sql-tests/results/identifier-clause.sql.out b/sql/core/src/test/resources/sql-tests/results/identifier-clause.sql.out index beeb3b13fe1ee..da7b6e3a31abb 100644 --- a/sql/core/src/test/resources/sql-tests/results/identifier-clause.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/identifier-clause.sql.out @@ -1133,20 +1133,15 @@ CREATE TEMPORARY FUNCTION IDENTIFIER('default.my' || 'DoubleAvg') AS 'test.org.a -- !query schema struct<> -- !query output -org.apache.spark.sql.catalyst.parser.ParseException +org.apache.spark.sql.AnalysisException { - "errorClass" : "INVALID_SQL_SYNTAX.CREATE_TEMP_FUNC_WITH_DATABASE", - "sqlState" : "42000", + "errorClass" : "INVALID_TEMP_OBJ_QUALIFIER", + "sqlState" : "42602", "messageParameters" : { - "database" : "`default`" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 1, - "stopIndex" : 108, - "fragment" : "CREATE TEMPORARY FUNCTION IDENTIFIER('default.my' || 'DoubleAvg') AS 'test.org.apache.spark.sql.MyDoubleAvg'" - } ] + "objectName" : "`myDoubleAvg`", + "objectType" : "FUNCTION", + "qualifier" : "`default`" + } } @@ -1155,21 +1150,15 @@ DROP TEMPORARY FUNCTION IDENTIFIER('default.my' || 'DoubleAvg') -- !query schema struct<> -- !query output -org.apache.spark.sql.catalyst.parser.ParseException +org.apache.spark.sql.AnalysisException { - "errorClass" : "INVALID_SQL_SYNTAX.MULTI_PART_NAME", - "sqlState" : "42000", + "errorClass" : "INVALID_TEMP_OBJ_QUALIFIER", + "sqlState" : "42602", "messageParameters" : { - "name" : "`default`.`myDoubleAvg`", - "statement" : "DROP TEMPORARY FUNCTION" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 1, - "stopIndex" : 63, - "fragment" : "DROP TEMPORARY FUNCTION IDENTIFIER('default.my' || 'DoubleAvg')" - } ] + "objectName" : "`myDoubleAvg`", + "objectType" : "FUNCTION", + "qualifier" : "`default`" + } } diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part3.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part3.sql.out index 6cfb2cb4b451d..bab66aa48be25 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part3.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part3.sql.out @@ -437,11 +437,10 @@ struct<> -- !query output org.apache.spark.sql.AnalysisException { - "errorClass" : "UNRESOLVED_ROUTINE", - 
"sqlState" : "42883", + "errorClass" : "NOT_A_SCALAR_FUNCTION", + "sqlState" : "42887", "messageParameters" : { - "routineName" : "`range`", - "searchPath" : "[`system`.`builtin`, `system`.`session`, `spark_catalog`.`default`]" + "functionName" : "`range`" }, "queryContext" : [ { "objectType" : "", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FunctionQualificationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FunctionQualificationSuite.scala new file mode 100644 index 0000000000000..5f299a5729b93 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/FunctionQualificationSuite.scala @@ -0,0 +1,623 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, Literal} +import org.apache.spark.sql.catalyst.parser.ParseException +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.IntegerType + +/** + * Comprehensive test suite for function qualification and resolution. + * Tests builtin/temp/persistent function name disambiguation with qualified names. + * This tests the system.builtin, system.session, and system.extension namespaces. + * + * Includes tests from: + * 1. function-qualification.sql golden file (SQL-accessible functions) + * 2. 
Extension function tests (programmatic registration via SparkSessionExtensions) + */ +class FunctionQualificationSuite extends QueryTest with SharedSparkSession { + + override protected def sparkConf = { + super.sparkConf.set("spark.sql.extensions", classOf[TestExtensions].getName) + } + + test("SECTION 1: Basic Qualification Tests") { + // Test builtin function with explicit qualification + checkAnswer(sql("SELECT system.builtin.abs(-5)"), Row(5)) + checkAnswer(sql("SELECT builtin.abs(-5)"), Row(5)) + + // Test builtin with case-insensitive qualification + checkAnswer(sql("SELECT BUILTIN.abs(-5)"), Row(5)) + checkAnswer(sql("SELECT System.Builtin.ABS(-5)"), Row(5)) + } + + test("SECTION 2: Temporary Function Creation and Qualification") { + // Create temporary function with unqualified name + sql("CREATE TEMPORARY FUNCTION my_func() RETURNS INT RETURN 42") + checkAnswer(sql("SELECT my_func()"), Row(42)) + + // Create with session qualification + sql("CREATE TEMPORARY FUNCTION session.my_func2() RETURNS STRING RETURN 'temp'") + checkAnswer(sql("SELECT my_func2()"), Row("temp")) + checkAnswer(sql("SELECT session.my_func2()"), Row("temp")) + + // Create with system.session qualification + sql("CREATE TEMPORARY FUNCTION system.session.my_func3() RETURNS INT RETURN 100") + checkAnswer(sql("SELECT my_func3()"), Row(100)) + checkAnswer(sql("SELECT system.session.my_func3()"), Row(100)) + + // Test case insensitivity with temp functions + checkAnswer(sql("SELECT SESSION.my_func()"), Row(42)) + checkAnswer(sql("SELECT SYSTEM.SESSION.my_func2()"), Row("temp")) + + // Clean up + sql("DROP TEMPORARY FUNCTION my_func") + sql("DROP TEMPORARY FUNCTION session.my_func2") + sql("DROP TEMPORARY FUNCTION system.session.my_func3") + } + + test("SECTION 3: Shadowing Behavior (Post-Security Fix)") { + // Create temp function with same name as builtin + sql("CREATE TEMPORARY FUNCTION abs() RETURNS INT RETURN 999") + + // Unqualified abs now resolves to BUILTIN (not temp!), due to security-focused order + checkAnswer(sql("SELECT abs(-5)"), Row(5)) + + // Temp function only accessible with explicit qualification + checkAnswer(sql("SELECT session.abs()"), Row(999)) + + // Builtin still accessible with qualification + checkAnswer(sql("SELECT builtin.abs(-10)"), Row(10)) + checkAnswer(sql("SELECT system.builtin.abs(-10)"), Row(10)) + sql("DROP TEMPORARY FUNCTION abs") + + // After drop, builtin still works unqualified + checkAnswer(sql("SELECT abs(-5)"), Row(5)) + } + + test("SECTION 4: Cross-Type Shadowing - temp table + builtin scalar") { + // Temp table function + builtin scalar function (NO conflict - builtin wins!) + sql("CREATE TEMPORARY FUNCTION abs() RETURNS TABLE(val INT) RETURN SELECT 42") + + // Builtin scalar abs still works unqualified (builtin resolves before temp table) + checkAnswer(sql("SELECT abs(-5)"), Row(5)) + + // Temp table function works in table context + checkAnswer(sql("SELECT * FROM abs()"), Row(42)) + + // Both accessible with explicit qualification + checkAnswer(sql("SELECT builtin.abs(-5)"), Row(5)) + checkAnswer(sql("SELECT * FROM session.abs()"), Row(42)) + sql("DROP TEMPORARY FUNCTION abs") + } + + test("SECTION 4: Cross-Type Shadowing - temp scalar + builtin table") { + // Temp scalar function + builtin table function (NO conflict - builtin wins!) 
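The conflict-free behavior asserted in both directions of this section falls out of the split-registry design: scalar and table functions live in separate registries, and each syntactic context walks the resolution path against its own registry, so the same unqualified name can be owned by different namespaces per context. A toy model of that property (all names and the Registries type are illustrative, not Spark API):

    object CrossTypeSketch {
      // Per-context registries keyed by function name; the value records
      // which namespace owns the entry.
      final case class Registries(scalar: Map[String, String], table: Map[String, String])

      def owner(r: Registries, name: String, tableContext: Boolean): Option[String] =
        if (tableContext) r.table.get(name) else r.scalar.get(name)

      def main(args: Array[String]): Unit = {
        val r = Registries(
          scalar = Map("range" -> "session"),   // CREATE TEMPORARY FUNCTION range()
          table  = Map("range" -> "builtin"))   // built-in range TVF
        println(owner(r, "range", tableContext = true))   // Some(builtin): FROM range(5)
        println(owner(r, "range", tableContext = false))  // Some(session): SELECT range()
      }
    }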
+ sql("CREATE TEMPORARY FUNCTION range() RETURNS INT RETURN 999") + + // Builtin table range still works unqualified in table context + checkAnswer( + sql("SELECT * FROM range(5)"), + Seq(Row(0), Row(1), Row(2), Row(3), Row(4))) + + // Temp scalar function works in scalar context + checkAnswer(sql("SELECT range()"), Row(999)) + + // Both accessible with explicit qualification + checkAnswer( + sql("SELECT * FROM builtin.range(5)"), + Seq(Row(0), Row(1), Row(2), Row(3), Row(4))) + checkAnswer(sql("SELECT session.range()"), Row(999)) + sql("DROP TEMPORARY FUNCTION range") + } + + test("SECTION 5: Cross-Type Error Detection - scalar in table context") { + // Scalar function cannot be used in table context + sql("CREATE TEMPORARY FUNCTION scalar_only() RETURNS INT RETURN 42") + checkAnswer(sql("SELECT scalar_only()"), Row(42)) + + checkError( + exception = intercept[AnalysisException] { + sql("SELECT * FROM scalar_only()") + }, + condition = "NOT_A_TABLE_FUNCTION", + parameters = Map("functionName" -> "`scalar_only`"), + context = ExpectedContext( + fragment = "scalar_only()", + start = 14, + stop = 26 + ) + ) + + sql("DROP TEMPORARY FUNCTION scalar_only") + } + + test("SECTION 5: Cross-Type Error Detection - table in scalar context") { + // Table function cannot be used in scalar context + sql("CREATE TEMPORARY FUNCTION table_only() RETURNS TABLE(val INT) RETURN SELECT 42") + checkAnswer(sql("SELECT * FROM table_only()"), Row(42)) + + checkError( + exception = intercept[AnalysisException] { + sql("SELECT table_only()") + }, + condition = "NOT_A_SCALAR_FUNCTION", + parameters = Map("functionName" -> "`table_only`"), + context = ExpectedContext( + fragment = "table_only()", + start = 7, + stop = 18 + ) + ) + + sql("DROP TEMPORARY FUNCTION table_only") + } + + test("SECTION 5: Cross-Type Error Detection - generator functions") { + // Generator functions work in both contexts + checkAnswer(sql("SELECT explode(array(1, 2, 3))"), Seq(Row(1), Row(2), Row(3))) + checkAnswer(sql("SELECT * FROM explode(array(1, 2, 3))"), Seq(Row(1), Row(2), Row(3))) + } + + test("SECTION 6: DDL Operations - DESCRIBE") { + // DESCRIBE builtin functions with qualification + val desc1 = sql("DESCRIBE FUNCTION builtin.abs") + assert(desc1.count() > 0) + + val desc2 = sql("DESCRIBE FUNCTION system.builtin.abs") + assert(desc2.count() > 0) + } + + test("SECTION 6: DDL Operations - DROP with qualified names") { + sql("CREATE TEMPORARY FUNCTION drop_test() RETURNS INT RETURN 1") + sql("DROP TEMPORARY FUNCTION session.drop_test") + } + + test("SECTION 6: DDL Operations - CREATE OR REPLACE") { + sql("CREATE TEMPORARY FUNCTION replace_test() RETURNS INT RETURN 1") + checkAnswer(sql("SELECT replace_test()"), Row(1)) + sql("CREATE OR REPLACE TEMPORARY FUNCTION session.replace_test() RETURNS INT RETURN 2") + checkAnswer(sql("SELECT replace_test()"), Row(2)) + sql("DROP TEMPORARY FUNCTION replace_test") + } + + test("SECTION 6: DDL Operations - CREATE OR REPLACE changes type") { + sql("CREATE TEMPORARY FUNCTION type_change() RETURNS INT RETURN 42") + checkAnswer(sql("SELECT type_change()"), Row(42)) + sql( + "CREATE OR REPLACE TEMPORARY FUNCTION type_change() " + + "RETURNS TABLE(val INT) RETURN SELECT 99") + checkAnswer(sql("SELECT * FROM type_change()"), Row(99)) + sql("DROP TEMPORARY FUNCTION type_change") + } + + test("SECTION 6: DDL Operations - IF NOT EXISTS not supported") { + // IF NOT EXISTS is not supported for temporary functions + checkError( + exception = intercept[ParseException] { + sql("CREATE TEMPORARY FUNCTION IF NOT 
EXISTS exists_test() RETURNS INT RETURN 1") + }, + condition = "INVALID_SQL_SYNTAX.CREATE_TEMP_FUNC_WITH_IF_NOT_EXISTS", + parameters = Map.empty[String, String], + context = ExpectedContext( + fragment = "CREATE TEMPORARY FUNCTION IF NOT EXISTS exists_test() RETURNS INT RETURN 1", + start = 0, + stop = 73 + ) + ) + + // SELECT on non-existent function should fail + checkError( + exception = intercept[AnalysisException] { + sql("SELECT exists_test()") + }, + condition = "UNRESOLVED_ROUTINE", + parameters = Map( + "routineName" -> "`exists_test`", + "searchPath" -> "[`system`.`builtin`, `system`.`session`, `spark_catalog`.`default`]" + ), + context = ExpectedContext( + fragment = "exists_test()", + start = 7, + stop = 19 + ) + ) + + checkError( + exception = intercept[ParseException] { + sql( + "CREATE TEMPORARY FUNCTION IF NOT EXISTS system.session.exists_test() " + + "RETURNS INT RETURN 2") + }, + condition = "INVALID_SQL_SYNTAX.CREATE_TEMP_FUNC_WITH_IF_NOT_EXISTS", + parameters = Map.empty[String, String], + context = ExpectedContext( + fragment = "CREATE TEMPORARY FUNCTION IF NOT EXISTS system.session.exists_test() " + + "RETURNS INT RETURN 2", + start = 0, + stop = 88 + ) + ) + + // SELECT on non-existent function should still fail + checkError( + exception = intercept[AnalysisException] { + sql("SELECT exists_test()") + }, + condition = "UNRESOLVED_ROUTINE", + parameters = Map( + "routineName" -> "`exists_test`", + "searchPath" -> "[`system`.`builtin`, `system`.`session`, `spark_catalog`.`default`]" + ), + context = ExpectedContext( + fragment = "exists_test()", + start = 7, + stop = 19 + ) + ) + } + + test("SECTION 6: DDL Operations - SHOW FUNCTIONS") { + sql("CREATE TEMPORARY FUNCTION show_test() RETURNS INT RETURN 1") + + val showTest = sql("SHOW FUNCTIONS LIKE 'show_test'").collect() + assert(showTest.length > 0) + + val showAbs = sql("SHOW FUNCTIONS LIKE 'abs'").collect() + assert(showAbs.length > 0) + + sql("DROP TEMPORARY FUNCTION show_test") + } + + test("SECTION 7: Error Cases - cannot create temp function with builtin namespace") { + checkError( + exception = intercept[ParseException] { + sql("CREATE TEMPORARY FUNCTION system.builtin.my_builtin() RETURNS INT RETURN 1") + }, + condition = "INVALID_TEMP_OBJ_QUALIFIER", + sqlState = "42602", + parameters = Map( + "objectName" -> "`my_builtin`", + "objectType" -> "FUNCTION", + "qualifier" -> "`system`.`builtin`" + ) + ) + } + + test("SECTION 7: Error Cases - cannot create temp function with invalid database") { + checkError( + exception = intercept[ParseException] { + sql("CREATE TEMPORARY FUNCTION mydb.my_func() RETURNS INT RETURN 1") + }, + condition = "INVALID_TEMP_OBJ_QUALIFIER", + sqlState = "42602", + parameters = Map( + "objectName" -> "`my_func`", + "objectType" -> "FUNCTION", + "qualifier" -> "`mydb`" + ) + ) + } + + test("SECTION 7: Error Cases - cannot drop builtin function") { + checkError( + exception = intercept[ParseException] { + sql("DROP TEMPORARY FUNCTION system.builtin.abs") + }, + condition = "INVALID_TEMP_OBJ_QUALIFIER", + sqlState = "42602", + parameters = Map( + "objectName" -> "`abs`", + "objectType" -> "FUNCTION", + "qualifier" -> "`system`.`builtin`" + ) + ) + } + + test("SECTION 7: Error Cases - cannot create duplicate functions") { + sql("CREATE TEMPORARY FUNCTION dup_test() RETURNS INT RETURN 42") + + checkError( + exception = intercept[AnalysisException] { + sql("CREATE TEMPORARY FUNCTION dup_test() RETURNS TABLE(val INT) RETURN SELECT 99") + }, + condition = "ROUTINE_ALREADY_EXISTS", + sqlState = 
"42723", + parameters = Map( + "existingRoutineType" -> "routine", + "newRoutineType" -> "routine", + "routineName" -> "`dup_test`" + ) + ) + + sql("DROP TEMPORARY FUNCTION dup_test") + } + + test("SECTION 7: Error Cases - non-existent function error") { + checkError( + exception = intercept[AnalysisException] { + sql("SELECT non_existent_func()") + }, + condition = "UNRESOLVED_ROUTINE", + parameters = Map( + "routineName" -> "`non_existent_func`", + "searchPath" -> "[`system`.`builtin`, `system`.`session`, `spark_catalog`.`default`]" + ), + context = ExpectedContext( + fragment = "non_existent_func()", + start = 7, + stop = 25 + ) + ) + } + + test("SECTION 8: Views - temp view can reference temp function") { + sql("CREATE TEMPORARY FUNCTION view_func() RETURNS STRING RETURN 'from_temp'") + sql("CREATE TEMPORARY VIEW temp_view AS SELECT view_func() as result") + checkAnswer(sql("SELECT * FROM temp_view"), Row("from_temp")) + sql("DROP VIEW temp_view") + sql("DROP TEMPORARY FUNCTION view_func") + } + + test("SECTION 8: Views - view with shadowing temp function") { + sql("CREATE TEMPORARY FUNCTION abs() RETURNS INT RETURN 777") + + // View must use qualified name to access temp function + sql("CREATE TEMPORARY VIEW shadow_view AS SELECT session.abs() as result") + checkAnswer(sql("SELECT * FROM shadow_view"), Row(777)) + + // Builtin accessible with qualification in view + sql("CREATE TEMPORARY VIEW builtin_view AS SELECT builtin.abs(-10) as result") + checkAnswer(sql("SELECT * FROM builtin_view"), Row(10)) + + sql("DROP VIEW shadow_view") + sql("DROP VIEW builtin_view") + sql("DROP TEMPORARY FUNCTION abs") + } + + test("SECTION 8: Views - multiple temp functions in same view") { + sql("CREATE TEMPORARY FUNCTION func1() RETURNS INT RETURN 1") + sql("CREATE TEMPORARY FUNCTION func2() RETURNS INT RETURN 2") + sql("CREATE TEMPORARY VIEW multi_func_view AS SELECT func1() + func2() as sum") + checkAnswer(sql("SELECT * FROM multi_func_view"), Row(3)) + sql("DROP VIEW multi_func_view") + sql("DROP TEMPORARY FUNCTION func1") + sql("DROP TEMPORARY FUNCTION func2") + } + + test("SECTION 8: Views - nested views with temp functions") { + sql("CREATE TEMPORARY FUNCTION nested_func() RETURNS INT RETURN 100") + sql("CREATE TEMPORARY VIEW inner_view AS SELECT nested_func() as val") + sql("CREATE TEMPORARY VIEW outer_view AS SELECT val * 2 FROM inner_view") + checkAnswer(sql("SELECT * FROM outer_view"), Row(200)) + sql("DROP VIEW outer_view") + sql("DROP VIEW inner_view") + sql("DROP TEMPORARY FUNCTION nested_func") + } + + test("SECTION 9: Multiple Functions - multiple qualified functions together") { + sql("CREATE TEMPORARY FUNCTION add10(x INT) RETURNS INT RETURN x + 10") + checkAnswer( + sql("SELECT builtin.abs(-5), session.add10(5), system.builtin.upper('hello')"), + Row(5, 15, "HELLO")) + sql("DROP TEMPORARY FUNCTION add10") + } + + test("SECTION 9: Multiple Functions - qualified aggregate function") { + // SQL functions cannot contain aggregate functions - this should error + checkError( + exception = intercept[AnalysisException] { + sql("CREATE TEMPORARY FUNCTION my_avg(x DOUBLE) RETURNS DOUBLE RETURN avg(x)") + }, + condition = "USER_DEFINED_FUNCTIONS.CANNOT_CONTAIN_COMPLEX_FUNCTIONS", + sqlState = "42601", + parameters = Map("queryText" -> "avg(x)") + ) + } + + test("SECTION 9: Multiple Functions - table function with qualified names") { + sql("CREATE TEMPORARY FUNCTION my_range() RETURNS TABLE(id INT) RETURN SELECT * FROM range(3)") + checkAnswer(sql("SELECT * FROM my_range()"), Seq(Row(0), 
Row(1), Row(2))) + checkAnswer(sql("SELECT * FROM session.my_range()"), Seq(Row(0), Row(1), Row(2))) + checkAnswer(sql("SELECT * FROM system.session.my_range()"), Seq(Row(0), Row(1), Row(2))) + sql("DROP TEMPORARY FUNCTION my_range") + } + + test("SECTION 10: COUNT(*) - unqualified and qualified") { + // Unqualified count(*) + checkAnswer(sql("SELECT count(*) FROM VALUES (1), (2), (3) AS t(a)"), Row(3)) + + // Qualified as builtin.count(*) + checkAnswer(sql("SELECT builtin.count(*) FROM VALUES (1), (2), (3) AS t(a)"), Row(3)) + + // Qualified as system.builtin.count(*) + checkAnswer(sql("SELECT system.builtin.count(*) FROM VALUES (1), (2), (3) AS t(a)"), Row(3)) + + // Case insensitive qualified count(*) + checkAnswer(sql("SELECT BUILTIN.COUNT(*) FROM VALUES (1), (2), (3) AS t(a)"), Row(3)) + checkAnswer(sql("SELECT System.Builtin.Count(*) FROM VALUES (1), (2), (3) AS t(a)"), Row(3)) + } + + test("SECTION 10: COUNT(*) - count(tbl.*) blocking") { + sql("CREATE TEMPORARY VIEW count_test_view AS SELECT 1 AS a, 2 AS b") + + // Unqualified count with table.* + checkError( + exception = intercept[AnalysisException] { + sql("SELECT count(count_test_view.*) FROM count_test_view") + }, + condition = "INVALID_USAGE_OF_STAR_WITH_TABLE_IDENTIFIER_IN_COUNT", + sqlState = "42000", + parameters = Map("tableName" -> "count_test_view") + ) + + // Qualified count with table.* + checkError( + exception = intercept[AnalysisException] { + sql("SELECT builtin.count(count_test_view.*) FROM count_test_view") + }, + condition = "INVALID_USAGE_OF_STAR_WITH_TABLE_IDENTIFIER_IN_COUNT", + sqlState = "42000", + parameters = Map("tableName" -> "count_test_view") + ) + + checkError( + exception = intercept[AnalysisException] { + sql("SELECT system.builtin.count(count_test_view.*) FROM count_test_view") + }, + condition = "INVALID_USAGE_OF_STAR_WITH_TABLE_IDENTIFIER_IN_COUNT", + sqlState = "42000", + parameters = Map("tableName" -> "count_test_view") + ) + + sql("DROP VIEW count_test_view") + } + + test("SECTION 11: Security - user cannot shadow current_user") { + // Baseline: current_user() works + val actualUser = sql("SELECT current_user()").collect().head.getString(0) + assert(actualUser != null) + checkAnswer(sql("SELECT current_user() IS NOT NULL"), Row(true)) + + // User creates temp function trying to shadow current_user + sql("CREATE TEMPORARY FUNCTION current_user() RETURNS STRING RETURN 'hacker'") + + // CRITICAL: Unqualified name still resolves to builtin (NOT the temp function) + val result = sql("SELECT current_user()").collect().head.getString(0) + assert(result == actualUser, s"Builtin was shadowed! 
Got $result, expected $actualUser") + checkAnswer(sql("SELECT current_user() IS NOT NULL"), Row(true)) + + // User's shadowed function only accessible via explicit qualification + checkAnswer(sql("SELECT session.current_user()"), Row("hacker")) + + sql("DROP TEMPORARY FUNCTION current_user") + } + + test("SECTION 11: Security - user cannot shadow abs") { + // Built-in abs works + checkAnswer(sql("SELECT builtin.abs(-5)"), Row(5)) + + // Create temp abs + sql("CREATE TEMPORARY FUNCTION abs() RETURNS INT RETURN 999") + + // Unqualified abs still resolves to builtin (security-focused order) + checkAnswer(sql("SELECT abs(-5)"), Row(5)) + + // Temp abs only accessible with qualification + checkAnswer(sql("SELECT session.abs()"), Row(999)) + + sql("DROP TEMPORARY FUNCTION abs") + } + + test("SECTION 11: Security - session_user and current_database") { + // Test session_user + sql("CREATE TEMPORARY FUNCTION session_user() RETURNS STRING RETURN 'fake_user'") + // Should be builtin, not temp + checkAnswer(sql("SELECT session_user() IS NOT NULL"), Row(true)) + sql("DROP TEMPORARY FUNCTION session_user") + + // Test current_database + sql("CREATE TEMPORARY FUNCTION current_database() RETURNS STRING RETURN 'fake_db'") + // Should be builtin, not temp + checkAnswer(sql("SELECT current_database() IS NOT NULL"), Row(true)) + sql("DROP TEMPORARY FUNCTION current_database") + } + + // ============================================================================ + // SECTION 12: Extension Function Tests + // ============================================================================ + // These tests verify the system.extension namespace and extension function behavior. + // Extension functions are registered programmatically via SparkSessionExtensions + // (e.g., Apache Sedona, Delta Lake) and cannot be tested via SQL golden files. + + test("SECTION 12: Extension - function can be called unqualified") { + // Extension function registered in TestExtensions + checkAnswer(sql("SELECT test_ext_func()"), Row(9999)) + } + + test("SECTION 12: Extension - resolution order: extension > builtin > session") { + // Test the critical security property: extension comes before builtin before session + sql("CREATE TEMPORARY FUNCTION session_func() RETURNS INT RETURN 1111") + + // Unqualified: resolves to session (no extension or builtin with this name) + checkAnswer(sql("SELECT session_func()"), Row(1111)) + + // Qualified + checkAnswer(sql("SELECT session.session_func()"), Row(1111)) + + // Cleanup + sql("DROP TEMPORARY FUNCTION session_func") + } + + test("SECTION 12: Extension - security property: temp cannot shadow current_user") { + // This test is already covered in SECTION 11, but we verify it works with extensions loaded + val actualUser = sql("SELECT current_user()").collect().head.getString(0) + + sql("CREATE TEMPORARY FUNCTION current_user() RETURNS STRING RETURN 'hacker'") + + // Unqualified call should still resolve to builtin (security!) + val unqualifiedResult = sql("SELECT current_user()").collect().head.getString(0) + assert(unqualifiedResult == actualUser, + s"Built-in current_user() was shadowed! 
Got '$unqualifiedResult', expected '$actualUser'") + + // But we can access the temp function via qualification + checkAnswer(sql("SELECT session.current_user()"), Row("hacker")) + + sql("DROP TEMPORARY FUNCTION current_user") + } + + test("SECTION 12: Extension - SHOW FUNCTIONS includes extension functions") { + val functions = sql("SHOW FUNCTIONS").collect().map(_.getString(0)) + assert(functions.contains("test_ext_func"), + s"Extension function test_ext_func not found in SHOW FUNCTIONS output") + } + + test("SECTION 12: Extension - DESCRIBE FUNCTION works") { + // Unqualified - extension functions should be describable + val desc = sql("DESCRIBE FUNCTION test_ext_func").collect() + assert(desc.nonEmpty, "DESCRIBE FUNCTION should return results for extension functions") + } +} + +/** + * Test extension that registers mock extension functions. + * This simulates what real extensions like Apache Sedona would do. + */ +class TestExtensions extends (SparkSessionExtensions => Unit) { + override def apply(extensions: SparkSessionExtensions): Unit = { + // Register a mock extension function + // Use the full 11-parameter ExpressionInfo constructor + extensions.injectFunction( + (org.apache.spark.sql.catalyst.FunctionIdentifier("test_ext_func"), + new ExpressionInfo( + "org.apache.spark.sql.FunctionQualificationSuite", // className + "", // db + "test_ext_func", // name + "Returns 9999 for testing", // usage + "", // arguments + "", // examples + "", // note + "", // group + "4.2.0", // since + "", // deprecated + ""), // source (empty is allowed) + (exprs: Seq[Expression]) => Literal(9999, IntegerType)) + ) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala index 540ca2b1ec887..5e36fe348501a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala @@ -80,6 +80,9 @@ class SessionStateSuite extends SparkFunSuite { test("fork new session and inherit function registry and udf") { val testFuncName1 = FunctionIdentifier("strlenScala") val testFuncName2 = FunctionIdentifier("addone") + // Temporary functions registered via spark.udf.register are stored with database="session" + val testFuncName1Qualified = FunctionIdentifier(testFuncName1.funcName, Some("session")) + val testFuncName2Qualified = FunctionIdentifier(testFuncName2.funcName, Some("session")) try { activeSession.udf.register(testFuncName1.funcName, (_: String).length + (_: Int)) val forkedSession = activeSession.cloneSession() @@ -88,16 +91,19 @@ class SessionStateSuite extends SparkFunSuite { assert(forkedSession ne activeSession) assert(forkedSession.sessionState.functionRegistry ne activeSession.sessionState.functionRegistry) - assert(forkedSession.sessionState.functionRegistry.lookupFunction(testFuncName1).nonEmpty) + assert(forkedSession.sessionState.functionRegistry + .lookupFunction(testFuncName1Qualified).nonEmpty) // independence - forkedSession.sessionState.functionRegistry.dropFunction(testFuncName1) - assert(activeSession.sessionState.functionRegistry.lookupFunction(testFuncName1).nonEmpty) + forkedSession.sessionState.functionRegistry.dropFunction(testFuncName1Qualified) + assert(activeSession.sessionState.functionRegistry + .lookupFunction(testFuncName1Qualified).nonEmpty) activeSession.udf.register(testFuncName2.funcName, (_: Int) + 1) - 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
index 66826a9ca762a..231b927d994b5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
@@ -172,8 +176,12 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper with Adapt
       extensions.injectFunction(MyExtensions.myFunction)
     }
     withSession(extensions) { session =>
-      assert(session.sessionState.functionRegistry
-        .lookupFunction(MyExtensions.myFunction._1).isDefined)
+      // Extension functions are registered with database="extension" to enable security
+      // (extension functions resolve before built-ins, which resolve before session temp
+      // functions)
+      val qualifiedIdent = FunctionIdentifier(
+        MyExtensions.myFunction._1.funcName,
+        Some("extension"))
+      assert(session.sessionState.functionRegistry.lookupFunction(qualifiedIdent).isDefined)
     }
   }
 
@@ -380,8 +384,10 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper with Adapt
       assert(session.sessionState.analyzer.extendedCheckRules.contains(MyCheckRule(session)))
       assert(session.sessionState.optimizer.batches.flatMap(_.rules).contains(MyRule(session)))
       assert(session.sessionState.sqlParser.isInstanceOf[MyParser])
-      assert(session.sessionState.functionRegistry
-        .lookupFunction(MyExtensions.myFunction._1).isDefined)
+      // Extension functions are registered with database="extension"
+      val qualifiedIdent = FunctionIdentifier(
+        MyExtensions.myFunction._1.funcName, Some("extension"))
+      assert(session.sessionState.functionRegistry.lookupFunction(qualifiedIdent).isDefined)
       assert(session.sessionState.columnarRules.contains(
         MyColumnarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule())))
     } finally {
@@ -408,10 +414,13 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper with Adapt
       assert(session.sessionState.optimizer.batches.flatMap(_.rules).filter(orderedRules.contains)
         .containsSlice(orderedRules ++ orderedRules)) // The optimizer rules are duplicated
       assert(session.sessionState.sqlParser === parser)
-      assert(session.sessionState.functionRegistry
-        .lookupFunction(MyExtensions.myFunction._1).isDefined)
-      assert(session.sessionState.functionRegistry
-        .lookupFunction(MyExtensions2.myFunction._1).isDefined)
+      // Extension functions are registered with database="extension"
+      val qualifiedIdent1 = FunctionIdentifier(
+        MyExtensions.myFunction._1.funcName, Some("extension"))
+      val qualifiedIdent2 = FunctionIdentifier(
+        MyExtensions2.myFunction._1.funcName, Some("extension"))
+      assert(session.sessionState.functionRegistry.lookupFunction(qualifiedIdent1).isDefined)
+      assert(session.sessionState.functionRegistry.lookupFunction(qualifiedIdent2).isDefined)
     } finally {
       stop(session)
     }
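
The comments in these hunks state the resolution order but never show it end to end. The following self-contained sketch is illustrative only: the registries are plain Maps standing in for Spark's FunctionRegistry namespaces, and the function names are placeholders. It demonstrates why, under this precedence, a session temporary function can never shadow a built-in.

object ResolutionOrderSketch {
  val extension = Map("myFunction" -> "extension impl")
  val builtin   = Map("abs" -> "builtin impl", "current_user" -> "builtin impl")
  val session   = Map("current_user" -> "temp impl", "session_func" -> "temp impl")

  // An unqualified name checks extension first, then builtin, then session temp
  // functions, so a session entry is only reachable when no higher-precedence
  // namespace defines the name.
  def resolveUnqualified(name: String): Option[String] =
    extension.get(name).orElse(builtin.get(name)).orElse(session.get(name))

  def main(args: Array[String]): Unit = {
    assert(resolveUnqualified("current_user").contains("builtin impl")) // not shadowed
    assert(resolveUnqualified("session_func").contains("temp impl"))    // falls through
  }
}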
@@ -437,8 +446,10 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper with Adapt
       val outerParser = session.sessionState.sqlParser
       assert(outerParser.isInstanceOf[MyParser])
       assert(outerParser.asInstanceOf[MyParser].delegate.isInstanceOf[MyParser])
-      assert(session.sessionState.functionRegistry
-        .lookupFunction(MyExtensions.myFunction._1).isDefined)
+      // Extension functions are registered with database="extension"
+      val qualifiedIdent = FunctionIdentifier(
+        MyExtensions.myFunction._1.funcName, Some("extension"))
+      assert(session.sessionState.functionRegistry.lookupFunction(qualifiedIdent).isDefined)
     } finally {
       stop(session)
     }
@@ -452,8 +463,9 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper with Adapt
         classOf[MyExtensions2Duplicate].getCanonicalName).mkString(","))
       .getOrCreate()
     try {
+      // Extension functions are registered with database="extension"
       val lastRegistered = session.sessionState.functionRegistry
-        .lookupFunction(FunctionIdentifier("myFunction2"))
+        .lookupFunction(FunctionIdentifier("myFunction2", Some("extension")))
       assert(lastRegistered.isDefined)
       assert(lastRegistered.get !== MyExtensions2.myFunction._2)
       assert(lastRegistered.get === MyExtensions2Duplicate.myFunction._2)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala
index a58e2d9a8a500..07ca9c6140d29 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala
@@ -322,7 +322,7 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL
         stop = 141))
   }
 
-  test("INVALID_SQL_SYNTAX.MULTI_PART_NAME: Create temporary function with multi-part name") {
+  test("INVALID_TEMP_OBJ_QUALIFIER: Create temporary function with invalid multi-part name") {
     val sqlText =
       """CREATE TEMPORARY FUNCTION ns.db.func as
         |'com.matthewrathbone.example.SimpleUDFExample' USING JAR '/path/to/jar1',
@@ -330,19 +330,16 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL
     checkError(
       exception = parseException(sqlText),
-      condition = "INVALID_SQL_SYNTAX.MULTI_PART_NAME",
-      sqlState = "42000",
+      condition = "INVALID_TEMP_OBJ_QUALIFIER",
+      sqlState = "42602",
       parameters = Map(
-        "statement" -> "CREATE TEMPORARY FUNCTION",
-        "name" -> "`ns`.`db`.`func`"),
-      context = ExpectedContext(
-        fragment = sqlText,
-        start = 0,
-        stop = 132))
+        "objectType" -> "FUNCTION",
+        "objectName" -> "`func`",
+        "qualifier" -> "`ns`.`db`"))
   }
 
-  test("INVALID_SQL_SYNTAX.CREATE_TEMP_FUNC_WITH_DATABASE: " +
-    "Specifying database while creating temporary function") {
+  test("INVALID_TEMP_OBJ_QUALIFIER: " +
+    "Specifying invalid database while creating temporary function") {
     val sqlText =
       """CREATE TEMPORARY FUNCTION db.func as
         |'com.matthewrathbone.example.SimpleUDFExample' USING JAR '/path/to/jar1',
@@ -350,28 +347,24 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL
     checkError(
       exception = parseException(sqlText),
-      condition = "INVALID_SQL_SYNTAX.CREATE_TEMP_FUNC_WITH_DATABASE",
-      sqlState = "42000",
-      parameters = Map("database" -> "`db`"),
-      context = ExpectedContext(
-        fragment = sqlText,
-        start = 0,
-        stop = 129))
+      condition = "INVALID_TEMP_OBJ_QUALIFIER",
+      sqlState = "42602",
+      parameters = Map(
+        "objectType" -> "FUNCTION",
+        "objectName" -> "`func`",
+        "qualifier" -> "`db`"))
   }
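
Note how the new parameters decompose the name: the old conditions reported the whole multi-part name, while INVALID_TEMP_OBJ_QUALIFIER reports the offending qualifier and the unqualified object name separately. A hedged sketch of that decomposition, assuming a qualified (multi-part) input; splitQualified is a hypothetical helper, not code from this patch.

object ErrorParamsSketch {
  // Back-quotes each part, then splits off the trailing object name.
  def splitQualified(nameParts: Seq[String]): (String, String) = {
    val quoted = nameParts.map(p => s"`$p`")
    (quoted.init.mkString("."), quoted.last)
  }

  def main(args: Array[String]): Unit = {
    // Matches the parameters asserted in the two tests above.
    assert(splitQualified(Seq("ns", "db", "func")) == ("`ns`.`db`", "`func`"))
    assert(splitQualified(Seq("db", "func")) == ("`db`", "`func`"))
  }
}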
test("INVALID_TEMP_OBJ_QUALIFIER: Drop temporary function with invalid qualification") { val sqlText = "DROP TEMPORARY FUNCTION db.func" checkError( exception = parseException(sqlText), - condition = "INVALID_SQL_SYNTAX.MULTI_PART_NAME", - sqlState = "42000", + condition = "INVALID_TEMP_OBJ_QUALIFIER", + sqlState = "42602", parameters = Map( - "statement" -> "DROP TEMPORARY FUNCTION", - "name" -> "`db`.`func`"), - context = ExpectedContext( - fragment = sqlText, - start = 0, - stop = 30)) + "objectType" -> "FUNCTION", + "objectName" -> "`func`", + "qualifier" -> "`db`")) } test("DUPLICATE_KEY: Found duplicate partition keys") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateSQLFunctionParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateSQLFunctionParserSuite.scala index 75b42c6440719..575017d2321f0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateSQLFunctionParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateSQLFunctionParserSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.command +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedIdentifier} import org.apache.spark.sql.catalyst.catalog.LanguageSQL @@ -113,25 +114,28 @@ class CreateSQLFunctionParserSuite extends AnalysisTest { parser.parsePlan("CREATE OR REPLACE TEMPORARY FUNCTION a() RETURNS INT RETURN 1"), createSQLFunctionCommand("a", exprText = Some("1"), replace = true)) - checkParseError( - "CREATE TEMPORARY FUNCTION a.b() RETURNS INT RETURN 1", - errorClass = "INVALID_SQL_SYNTAX.CREATE_TEMP_FUNC_WITH_DATABASE", - parameters = Map("database" -> "`a`"), - queryContext = Array( - ExpectedContext("CREATE TEMPORARY FUNCTION a.b() RETURNS INT RETURN 1", 0, 51) - ) - ) - - checkParseError( - "CREATE TEMPORARY FUNCTION a.b.c() RETURNS INT RETURN 1", - errorClass = "INVALID_SQL_SYNTAX.MULTI_PART_NAME", + // Now throws an AnalysisException (semantic error) instead of ParseException + val e1 = intercept[AnalysisException] { + parser.parsePlan("CREATE TEMPORARY FUNCTION a.b() RETURNS INT RETURN 1") + } + checkError( + exception = e1, + condition = "INVALID_TEMP_OBJ_QUALIFIER", parameters = Map( - "statement" -> "CREATE TEMPORARY FUNCTION", - "name" -> "`a`.`b`.`c`"), - queryContext = Array( - ExpectedContext("CREATE TEMPORARY FUNCTION a.b.c() RETURNS INT RETURN 1", 0, 53) - ) - ) + "objectType" -> "FUNCTION", + "objectName" -> "`b`", + "qualifier" -> "`a`")) + + val e2 = intercept[AnalysisException] { + parser.parsePlan("CREATE TEMPORARY FUNCTION a.b.c() RETURNS INT RETURN 1") + } + checkError( + exception = e2, + condition = "INVALID_TEMP_OBJ_QUALIFIER", + parameters = Map( + "objectType" -> "FUNCTION", + "objectName" -> "`c`", + "qualifier" -> "`a`.`b`")) checkParseError( "CREATE TEMPORARY FUNCTION IF NOT EXISTS a() RETURNS INT RETURN 1", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala index 1561336fdfa39..fe8e033f920b3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.command import org.apache.spark.SparkThrowable +import 
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, GlobalTempView, LocalTempView, SchemaCompensation, UnresolvedAttribute, UnresolvedIdentifier}
 import org.apache.spark.sql.catalyst.catalog.{ArchiveResource, FileResource, FunctionResource, JarResource}
 import org.apache.spark.sql.catalyst.dsl.expressions._
@@ -684,23 +685,23 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession {
     val sql1 = "DROP TEMPORARY FUNCTION a.b"
     checkError(
-      exception = parseException(sql1),
-      condition = "INVALID_SQL_SYNTAX.MULTI_PART_NAME",
-      parameters = Map("statement" -> "DROP TEMPORARY FUNCTION", "name" -> "`a`.`b`"),
-      context = ExpectedContext(
-        fragment = sql1,
-        start = 0,
-        stop = 26))
+      exception = intercept[AnalysisException](parser.parsePlan(sql1)),
+      condition = "INVALID_TEMP_OBJ_QUALIFIER",
+      parameters = Map(
+        "objectType" -> "FUNCTION",
+        "objectName" -> "`b`",
+        "qualifier" -> "`a`"),
+      queryContext = Array.empty)
 
     val sql2 = "DROP TEMPORARY FUNCTION IF EXISTS a.b"
     checkError(
-      exception = parseException(sql2),
-      condition = "INVALID_SQL_SYNTAX.MULTI_PART_NAME",
-      parameters = Map("statement" -> "DROP TEMPORARY FUNCTION", "name" -> "`a`.`b`"),
-      context = ExpectedContext(
-        fragment = sql2,
-        start = 0,
-        stop = 36))
+      exception = intercept[AnalysisException](parser.parsePlan(sql2)),
+      condition = "INVALID_TEMP_OBJ_QUALIFIER",
+      parameters = Map(
+        "objectType" -> "FUNCTION",
+        "objectName" -> "`b`",
+        "qualifier" -> "`a`"),
+      queryContext = Array.empty)
   }
 
   test("SPARK-32374: create temporary view with properties not allowed") {
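
Taken together, these suites pin down one rule: a temporary object name may be unqualified, or qualified with SESSION or SYSTEM.SESSION; anything else raises INVALID_TEMP_OBJ_QUALIFIER with the qualifier and object name reported separately. Below is a minimal sketch of that rule, assuming the message wording from the error-conditions change; the helper and exception type are illustrative, not the analyzer's actual implementation.

import java.util.Locale

object TempQualifierSketch {
  final case class InvalidTempObjQualifier(
      objectType: String, objectName: String, qualifier: String)
    extends Exception(
      s"$objectType $objectName cannot be qualified with $qualifier. " +
        "Temporary objects can only be qualified with SESSION or SYSTEM.SESSION.")

  // Returns the unqualified, back-quoted object name if the qualifier is legal.
  def validateTempName(objectType: String, nameParts: Seq[String]): String = {
    val quoted = nameParts.map(p => s"`$p`")
    nameParts.map(_.toLowerCase(Locale.ROOT)) match {
      case Seq(_) => quoted.last                      // unqualified: ok
      case Seq("session", _) => quoted.last           // SESSION.func: ok
      case Seq("system", "session", _) => quoted.last // SYSTEM.SESSION.func: ok
      case _ =>
        throw InvalidTempObjQualifier(objectType, quoted.last, quoted.init.mkString("."))
    }
  }

  def main(args: Array[String]): Unit = {
    assert(validateTempName("FUNCTION", Seq("func")) == "`func`")
    assert(validateTempName("FUNCTION", Seq("session", "func")) == "`func`")
    try {
      validateTempName("FUNCTION", Seq("ns", "db", "func"))
      assert(false, "expected InvalidTempObjQualifier")
    } catch {
      case e: InvalidTempObjQualifier =>
        // Matches the parameters asserted in the suites above.
        assert(e.qualifier == "`ns`.`db`" && e.objectName == "`func`")
    }
  }
}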