[SPARK-48782][SQL] Add support for procedures in catalogs
aokolnychyi committed Aug 30, 2024
1 parent 1e67659 commit 8905e1e
Showing 26 changed files with 972 additions and 17 deletions.
6 changes: 6 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
@@ -1439,6 +1439,12 @@
],
"sqlState" : "54001"
},
"FAILED_TO_LOAD_ROUTINE" : {
"message" : [
"Failed to load routine <routineName>."
],
"sqlState" : "38000"
},
"FEATURE_NOT_ENABLED" : {
"message" : [
"The feature <featureName> is not enabled. Consider setting the config <configKey> to <configValue> to enable this capability."
1 change: 1 addition & 0 deletions docs/sql-ref-ansi-compliance.md
@@ -426,6 +426,7 @@ Below is a list of all the keywords in Spark SQL.
|BY|non-reserved|non-reserved|reserved|
|BYTE|non-reserved|non-reserved|non-reserved|
|CACHE|non-reserved|non-reserved|non-reserved|
|CALL|reserved|non-reserved|reserved|
|CALLED|non-reserved|non-reserved|non-reserved|
|CASCADE|non-reserved|non-reserved|non-reserved|
|CASE|reserved|non-reserved|reserved|
@@ -146,6 +146,7 @@ BUCKETS: 'BUCKETS';
BY: 'BY';
BYTE: 'BYTE';
CACHE: 'CACHE';
CALL: 'CALL';
CALLED: 'CALLED';
CASCADE: 'CASCADE';
CASE: 'CASE';
@@ -275,6 +275,10 @@ statement
LEFT_PAREN columns=multipartIdentifierPropertyList RIGHT_PAREN
(OPTIONS options=propertyList)? #createIndex
| DROP INDEX (IF EXISTS)? identifier ON TABLE? identifierReference #dropIndex
| CALL identifierReference
LEFT_PAREN
(functionArgument (COMMA functionArgument)*)?
RIGHT_PAREN #call
| unsupportedHiveNativeCommands .*? #failNativeCommand
;
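As a quick illustration of the statements the new `call` rule accepts, here is a sketch using the SQL API; the catalog `cat`, the namespaces, and the procedure names are hypothetical and assume an active SparkSession `spark` with a catalog implementing ProcedureCatalog registered under that name:

```scala
// Hedged sketch: `cat` must be a catalog plugin that implements ProcedureCatalog.
spark.sql("CALL cat.ns.compact_table('db.tbl')")               // positional argument
spark.sql("CALL cat.ns.compact_table(table => 'db.tbl')")      // named argument
spark.sql("CALL cat.ns.compact_table('db.tbl', retries => 3)") // positional followed by named
spark.sql("CALL cat.ns.expire_snapshots()")                    // parentheses are required even with no arguments
```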

@@ -1802,6 +1806,7 @@ nonReserved
| BY
| BYTE
| CACHE
| CALL
| CALLED
| CASCADE
| CASE
@@ -32,6 +32,11 @@
*/
@Evolving
public interface ProcedureParameter {
/**
* A field metadata key that indicates whether an argument is passed by name.
*/
String BY_NAME_METADATA_KEY = "BY_NAME";

/**
* Creates a builder for an IN procedure parameter.
*
@@ -35,6 +35,12 @@ public interface UnboundProcedure extends Procedure {
* validate if the input types are compatible while binding or delegate that to Spark. Regardless,
* Spark will always perform the final validation of the arguments and rearrange them as needed
* based on {@link BoundProcedure#parameters() reported parameters}.
* <p>
* The provided {@code inputType} is based on the procedure arguments. If an argument is passed
* by name, its metadata will indicate this with {@link ProcedureParameter#BY_NAME_METADATA_KEY}
* set to {@code true}. In such cases, the field name will match the name of the target procedure
* parameter. If the argument is not named, {@link ProcedureParameter#BY_NAME_METADATA_KEY} will
* be set to {@code false} and the name will be assigned randomly.
*
* @param inputType the input types to bind to
* @return the bound procedure that is most suitable for the given input types
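To make the contract above concrete, a connector's `bind(StructType)` implementation can inspect the per-field metadata to separate named arguments from positional ones. A minimal sketch, assuming only the `BY_NAME_METADATA_KEY` constant and the standard `Metadata` accessors (the helper itself is hypothetical):

```scala
import org.apache.spark.sql.connector.catalog.procedures.ProcedureParameter
import org.apache.spark.sql.types.StructType

// Returns (field names passed by name, field names assigned positionally).
def splitByPassingStyle(inputType: StructType): (Seq[String], Seq[String]) = {
  val (named, positional) = inputType.fields.partition { field =>
    field.metadata.contains(ProcedureParameter.BY_NAME_METADATA_KEY) &&
      field.metadata.getBoolean(ProcedureParameter.BY_NAME_METADATA_KEY)
  }
  (named.map(_.name).toSeq, positional.map(_.name).toSeq)
}
```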
@@ -26,7 +26,7 @@ import scala.collection.mutable.ArrayBuffer
import scala.jdk.CollectionConverters._
import scala.util.{Failure, Random, Success, Try}

import org.apache.spark.{SparkException, SparkUnsupportedOperationException}
import org.apache.spark.{SparkException, SparkThrowable, SparkUnsupportedOperationException}
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.catalog._
@@ -50,6 +50,7 @@ import org.apache.spark.sql.connector.catalog.{View => _, _}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.connector.catalog.TableChange.{After, ColumnPosition}
import org.apache.spark.sql.connector.catalog.functions.{AggregateFunction => V2AggregateFunction, ScalarFunction, UnboundFunction}
import org.apache.spark.sql.connector.catalog.procedures.{BoundProcedure, ProcedureParameter, UnboundProcedure}
import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
@@ -310,6 +311,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
ExtractGenerator ::
ResolveGenerate ::
ResolveFunctions ::
ResolveProcedures ::
BindProcedures ::
ResolveTableSpec ::
ResolveAliases ::
ResolveSubquery ::
@@ -2611,6 +2614,73 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
}
}

/**
* A rule that resolves procedures.
*/
object ResolveProcedures extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning(
_.containsPattern(UNRESOLVED_PROCEDURE), ruleId) {
case Call(UnresolvedProcedure(CatalogAndIdentifier(catalog, ident)), args) =>
val procedureCatalog = catalog.asProcedureCatalog
val procedure = load(procedureCatalog, ident)
Call(ResolvedProcedure(procedureCatalog, ident, procedure), args)
}

private def load(catalog: ProcedureCatalog, ident: Identifier): UnboundProcedure = {
try {
catalog.loadProcedure(ident)
} catch {
case e: SparkThrowable =>
throw e
case e: Exception =>
val nameParts = catalog.name +: ident.asMultipartIdentifier
throw QueryCompilationErrors.failedToLoadRoutineError(nameParts, e)
}
}
}

/**
* A rule that binds procedures to the input types and rearranges arguments as needed.
*/
object BindProcedures extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case Call(ResolvedProcedure(catalog, ident, unboundProcedure: UnboundProcedure), args)
if args.forall(_.resolved) =>
val inputType = extractInputType(args)
val boundProcedure = unboundProcedure.bind(inputType)
validateParameterModes(boundProcedure)
val alignedArgs = NamedParametersSupport.defaultRearrange(boundProcedure, args)
Call(ResolvedProcedure(catalog, ident, boundProcedure), alignedArgs)
}

private def extractInputType(args: Seq[Expression]): StructType = {
val fields = args.zipWithIndex.map { case (arg, index) =>
arg match {
case NamedArgumentExpression(name, value) =>
val metadata = argMetadata(byName = true)
StructField(name, value.dataType, value.nullable, metadata)
case _ =>
val name = s"param$index"
val metadata = argMetadata(byName = false)
StructField(name, arg.dataType, arg.nullable, metadata)
}
}
StructType(fields)
}

private def argMetadata(byName: Boolean): Metadata = {
new MetadataBuilder()
.putBoolean(ProcedureParameter.BY_NAME_METADATA_KEY, byName)
.build()
}

private def validateParameterModes(procedure: BoundProcedure): Unit = {
procedure.parameters.find(_.mode != ProcedureParameter.Mode.IN).foreach { param =>
throw SparkException.internalError(s"Unsupported parameter mode: ${param.mode}")
}
}
}

/**
* This rule resolves and rewrites subqueries inside expressions.
*
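For reference, this is roughly the input type that `extractInputType` above builds for `CALL cat.ns.proc(1, retries => 3)`; a sketch with illustrative names and types:

```scala
import org.apache.spark.sql.connector.catalog.procedures.ProcedureParameter
import org.apache.spark.sql.types._

// Positional arguments get a synthetic name (param0, param1, ...) and
// BY_NAME = false; named arguments keep their name and get BY_NAME = true.
val positionalMeta = new MetadataBuilder()
  .putBoolean(ProcedureParameter.BY_NAME_METADATA_KEY, false)
  .build()
val namedMeta = new MetadataBuilder()
  .putBoolean(ProcedureParameter.BY_NAME_METADATA_KEY, true)
  .build()

val inputType = StructType(Seq(
  StructField("param0", IntegerType, nullable = false, positionalMeta),
  StructField("retries", IntegerType, nullable = false, namedMeta)))
```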
@@ -77,6 +77,7 @@ object AnsiTypeCoercion extends TypeCoercionBase {
override def typeCoercionRules: List[Rule[LogicalPlan]] =
UnpivotCoercion ::
WidenSetOperationTypes ::
ProcedureArgumentCoercion ::
new AnsiCombinedTypeCoercionRule(
CollationTypeCasts ::
InConversion ::
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.catalyst.trees.TreePattern.{LATERAL_COLUMN_ALIAS_REFERENCE, PLAN_EXPRESSION, UNRESOLVED_WINDOW_EXPRESSION}
import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, StringUtils, TypeUtils}
import org.apache.spark.sql.connector.catalog.{LookupCatalog, SupportsPartitionManagement}
import org.apache.spark.sql.connector.catalog.procedures.BoundProcedure
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -673,6 +674,18 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
varName,
c.defaultExpr.originalSQL)

case call @ Call(ResolvedProcedure(_, _, procedure: BoundProcedure), args)
if call.resolved =>
val inputTypes = procedure.parameters.map(_.dataType).toSeq
ExpectsInputTypes.checkInputDataTypes(args, inputTypes) match {
case TypeCheckResult.TypeCheckSuccess =>
// OK
case mismatch: TypeCheckResult.DataTypeMismatch =>
call.dataTypeMismatch("CALL", mismatch)
case _ =>
SparkException.internalError("Invalid input for procedure")
}

case _ => // Falls back to the following checks
}

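As an illustration of the new check, a call whose argument has no implicit cast to the bound parameter type fails analysis with a DATATYPE_MISMATCH error instead of reaching execution; the procedure name and signature below are hypothetical:

```scala
// Sketch: assumes an active SparkSession `spark` and that cat.ns.set_ttl binds
// with a single INT parameter. An array argument cannot be implicitly cast to
// INT, so CheckAnalysis rejects the statement.
spark.sql("CALL cat.ns.set_ttl(array(1, 2))")
```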
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.AlwaysProcess
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.connector.catalog.procedures.BoundProcedure
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.types.{AbstractArrayType, AbstractMapType, AbstractStringType, StringTypeAnyCollation}
@@ -202,6 +203,20 @@ abstract class TypeCoercionBase {
}
}

/**
* A type coercion rule that implicitly casts procedure arguments to expected types.
*/
object ProcedureArgumentCoercion extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case call @ Call(ResolvedProcedure(_, _, procedure: BoundProcedure), args) if call.resolved =>
val expectedDataTypes = procedure.parameters.map(_.dataType)
val coercedArgs = args.zip(expectedDataTypes).map {
case (arg, expectedType) => implicitCast(arg, expectedType).getOrElse(arg)
}
call.copy(args = coercedArgs)
}
}

/**
* Widens the data types of the [[Unpivot]] values.
*/
@@ -838,6 +853,7 @@ object TypeCoercion extends TypeCoercionBase {
override def typeCoercionRules: List[Rule[LogicalPlan]] =
UnpivotCoercion ::
WidenSetOperationTypes ::
ProcedureArgumentCoercion ::
new CombinedTypeCoercionRule(
CollationTypeCasts ::
InConversion ::
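Conversely, arguments that only need a widening cast are accepted, because `ProcedureArgumentCoercion` inserts the implicit cast before `CheckAnalysis` runs; another hedged sketch with a hypothetical procedure:

```scala
// Assumes an active SparkSession `spark` and that cat.ns.wait_for binds with a
// single LONG parameter. The INT literal 5 is implicitly cast to BIGINT, so the
// CheckAnalysis type check passes.
spark.sql("CALL cat.ns.wait_for(5)")
```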
@@ -67,9 +67,13 @@ package object analysis {
}

def dataTypeMismatch(expr: Expression, mismatch: DataTypeMismatch): Nothing = {
dataTypeMismatch(toSQLExpr(expr), mismatch)
}

def dataTypeMismatch(sqlExpr: String, mismatch: DataTypeMismatch): Nothing = {
throw new AnalysisException(
errorClass = s"DATATYPE_MISMATCH.${mismatch.errorSubClass}",
messageParameters = mismatch.messageParameters + ("sqlExpr" -> toSQLExpr(expr)),
messageParameters = mismatch.messageParameters + ("sqlExpr" -> sqlExpr),
origin = t.origin)
}

@@ -23,13 +23,14 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTable
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.{Attribute, LeafExpression, Unevaluable}
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}
import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UNRESOLVED_FUNC}
import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UNRESOLVED_FUNC, UNRESOLVED_PROCEDURE}
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.connector.catalog.{CatalogPlugin, FunctionCatalog, Identifier, Table, TableCatalog}
import org.apache.spark.sql.connector.catalog.{CatalogPlugin, FunctionCatalog, Identifier, ProcedureCatalog, Table, TableCatalog}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition
import org.apache.spark.sql.connector.catalog.functions.UnboundFunction
import org.apache.spark.sql.connector.catalog.procedures.Procedure
import org.apache.spark.sql.types.{DataType, StructField}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.util.ArrayImplicits._
@@ -135,6 +136,12 @@ case class UnresolvedFunctionName(
case class UnresolvedIdentifier(nameParts: Seq[String], allowTemp: Boolean = false)
extends UnresolvedLeafNode

/**
* A procedure identifier that should be resolved into [[ResolvedProcedure]].
*/
case class UnresolvedProcedure(nameParts: Seq[String]) extends UnresolvedLeafNode {
final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_PROCEDURE)
}

/**
* A resolved leaf node whose statistics has no meaning.
@@ -192,6 +199,12 @@ case class ResolvedFieldName(path: Seq[String], field: StructField) extends Fiel

case class ResolvedFieldPosition(position: ColumnPosition) extends FieldPosition

case class ResolvedProcedure(
catalog: ProcedureCatalog,
ident: Identifier,
procedure: Procedure) extends LeafNodeWithoutStats {
override def output: Seq[Attribute] = Nil
}

/**
* A plan containing resolved persistent views.
@@ -5491,6 +5491,28 @@ class AstBuilder extends DataTypeAstBuilder
ctx.EXISTS != null)
}

/**
* Creates a plan for invoking a procedure.
*
* For example:
* {{{
* CALL multi_part_name(v1, v2, ...);
* CALL multi_part_name(v1, param2 => v2, ...);
* CALL multi_part_name(param1 => v1, param2 => v2, ...);
* }}}
*/
override def visitCall(ctx: CallContext): LogicalPlan = withOrigin(ctx) {
val procedure = withIdentClause(ctx.identifierReference, UnresolvedProcedure)
val args = ctx.functionArgument.asScala.map {
case expr if expr.namedArgumentExpression != null =>
val namedExpr = expr.namedArgumentExpression
NamedArgumentExpression(namedExpr.key.getText, expression(namedExpr.value))
case expr =>
expression(expr)
}.toSeq
Call(procedure, args)
}

/**
* Create a TimestampAdd expression.
*/
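For reference, a sketch of the logical plan `visitCall` produces for the mixed form above; the package of `Call` is assumed from context (its definition is not shown in this excerpt), and the identifier and argument values are illustrative:

```scala
import org.apache.spark.sql.catalyst.analysis.UnresolvedProcedure
import org.apache.spark.sql.catalyst.expressions.{Literal, NamedArgumentExpression}
import org.apache.spark.sql.catalyst.plans.logical.Call

// CALL cat.ns.proc(1, retries => 3) parses to roughly this plan; the procedure
// identifier remains unresolved until the ResolveProcedures rule runs.
val plan = Call(
  UnresolvedProcedure(Seq("cat", "ns", "proc")),
  Seq(
    Literal(1),                                      // positional argument
    NamedArgumentExpression("retries", Literal(3)))) // named argument
```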