[SPARK-48782][SQL] Add support for executing procedures in catalogs #47943

Closed · wants to merge 4 commits
Changes from 2 commits
6 changes: 6 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
@@ -1444,6 +1444,12 @@
],
"sqlState" : "2203G"
},
"FAILED_TO_LOAD_ROUTINE" : {
"message" : [
"Failed to load routine <routineName>."
],
"sqlState" : "38000"
},
"FAILED_TO_PARSE_TOO_COMPLEX" : {
"message" : [
"The statement, including potential SQL functions and referenced views, was too complex to parse.",
1 change: 1 addition & 0 deletions docs/sql-ref-ansi-compliance.md
@@ -426,6 +426,7 @@ Below is a list of all the keywords in Spark SQL.
|BY|non-reserved|non-reserved|reserved|
|BYTE|non-reserved|non-reserved|non-reserved|
|CACHE|non-reserved|non-reserved|non-reserved|
|CALL|reserved|non-reserved|reserved|
|CALLED|non-reserved|non-reserved|non-reserved|
|CASCADE|non-reserved|non-reserved|non-reserved|
|CASE|reserved|non-reserved|reserved|
@@ -146,6 +146,7 @@ BUCKETS: 'BUCKETS';
BY: 'BY';
BYTE: 'BYTE';
CACHE: 'CACHE';
CALL: 'CALL';
CALLED: 'CALLED';
CASCADE: 'CASCADE';
CASE: 'CASE';
@@ -298,6 +298,10 @@ statement
LEFT_PAREN columns=multipartIdentifierPropertyList RIGHT_PAREN
(OPTIONS options=propertyList)? #createIndex
| DROP INDEX (IF EXISTS)? identifier ON TABLE? identifierReference #dropIndex
| CALL identifierReference
[Contributor Author] I can split this into a separate PR, if needed.

LEFT_PAREN
(functionArgument (COMMA functionArgument)*)?
RIGHT_PAREN #call
| unsupportedHiveNativeCommands .*? #failNativeCommand
;

@@ -1851,6 +1855,7 @@ nonReserved
| BY
| BYTE
| CACHE
| CALL
| CALLED
| CASCADE
| CASE
@@ -32,6 +32,11 @@
*/
@Evolving
public interface ProcedureParameter {
/**
* A field metadata key that indicates whether an argument is passed by name.
*/
String BY_NAME_METADATA_KEY = "BY_NAME";

/**
* Creates a builder for an IN procedure parameter.
*
@@ -35,6 +35,12 @@ public interface UnboundProcedure extends Procedure {
* validate if the input types are compatible while binding or delegate that to Spark. Regardless,
* Spark will always perform the final validation of the arguments and rearrange them as needed
* based on {@link BoundProcedure#parameters() reported parameters}.
* <p>
* The provided {@code inputType} is based on the procedure arguments. If an argument is passed
* by name, its metadata will indicate this with {@link ProcedureParameter#BY_NAME_METADATA_KEY}
* set to {@code true}. In such cases, the field name will match the name of the target procedure
* parameter. If the argument is not named, {@link ProcedureParameter#BY_NAME_METADATA_KEY} will
* be set to {@code false} and the name will be assigned randomly.
*
* @param inputType the input types to bind to
* @return the bound procedure that is most suitable for the given input types
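As an illustration of the contract described above (not part of this diff; the object and method names are hypothetical), a connector's bind implementation could inspect the incoming inputType roughly like this:

```scala
import org.apache.spark.sql.connector.catalog.procedures.ProcedureParameter
import org.apache.spark.sql.types.StructType

object BindInputInspector {
  // Hypothetical helper a connector might call from UnboundProcedure#bind.
  def describe(inputType: StructType): Seq[String] = {
    inputType.fields.toSeq.map { field =>
      val byName = field.metadata.contains(ProcedureParameter.BY_NAME_METADATA_KEY) &&
        field.metadata.getBoolean(ProcedureParameter.BY_NAME_METADATA_KEY)
      if (byName) {
        // The field name matches the target procedure parameter.
        s"named argument '${field.name}' of type ${field.dataType.sql}"
      } else {
        // The field name was generated by Spark and carries no meaning.
        s"positional argument of type ${field.dataType.sql}"
      }
    }
  }
}
```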
@@ -26,7 +26,7 @@ import scala.collection.mutable.ArrayBuffer
import scala.jdk.CollectionConverters._
import scala.util.{Failure, Random, Success, Try}

import org.apache.spark.{SparkException, SparkUnsupportedOperationException}
import org.apache.spark.{SparkException, SparkThrowable, SparkUnsupportedOperationException}
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.catalog._
@@ -50,6 +50,7 @@ import org.apache.spark.sql.connector.catalog.{View => _, _}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.connector.catalog.TableChange.{After, ColumnPosition}
import org.apache.spark.sql.connector.catalog.functions.{AggregateFunction => V2AggregateFunction, ScalarFunction, UnboundFunction}
import org.apache.spark.sql.connector.catalog.procedures.{BoundProcedure, ProcedureParameter, UnboundProcedure}
import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
@@ -310,6 +311,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
ExtractGenerator ::
ResolveGenerate ::
ResolveFunctions ::
ResolveProcedures ::
BindProcedures ::
ResolveTableSpec ::
ResolveAliases ::
ResolveSubquery ::
@@ -2611,6 +2614,71 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
}
}

/**
* A rule that resolves procedures.
*/
object ResolveProcedures extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning(
_.containsPattern(UNRESOLVED_PROCEDURE), ruleId) {
case Call(UnresolvedProcedure(CatalogAndIdentifier(catalog, ident)), args, execute) =>
val procedureCatalog = catalog.asProcedureCatalog
val procedure = load(procedureCatalog, ident)
Call(ResolvedProcedure(procedureCatalog, ident, procedure), args, execute)
}

private def load(catalog: ProcedureCatalog, ident: Identifier): UnboundProcedure = {
try {
catalog.loadProcedure(ident)
} catch {
case e: Exception if !e.isInstanceOf[SparkThrowable] =>
val nameParts = catalog.name +: ident.asMultipartIdentifier
throw QueryCompilationErrors.failedToLoadRoutineError(nameParts, e)
}
}
}

/**
* A rule that binds procedures to the input types and rearranges arguments as needed.
*/
object BindProcedures extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case Call(ResolvedProcedure(catalog, ident, unbound: UnboundProcedure), args, execute)
if args.forall(_.resolved) =>
val inputType = extractInputType(args)
val bound = unbound.bind(inputType)
validateParameterModes(bound)
val rearrangedArgs = NamedParametersSupport.defaultRearrange(bound, args)
Call(ResolvedProcedure(catalog, ident, bound), rearrangedArgs, execute)
}

private def extractInputType(args: Seq[Expression]): StructType = {
val fields = args.zipWithIndex.map { case (arg, index) =>
arg match {
case NamedArgumentExpression(name, value) =>
val metadata = argMetadata(byName = true)
StructField(name, value.dataType, value.nullable, metadata)
case _ =>
val name = s"param$index"
val metadata = argMetadata(byName = false)
StructField(name, arg.dataType, arg.nullable, metadata)
}
}
StructType(fields)
}

private def argMetadata(byName: Boolean): Metadata = {
new MetadataBuilder()
.putBoolean(ProcedureParameter.BY_NAME_METADATA_KEY, byName)
[Contributor] Shall we omit this metadata if the arg is not by name?
[Contributor Author] We can, skipped.

.build()
}

private def validateParameterModes(procedure: BoundProcedure): Unit = {
procedure.parameters.find(_.mode != ProcedureParameter.Mode.IN).foreach { param =>
[Contributor] Are we planning to support more parameter modes?
[Contributor Author] In the future, yes. There is no active work at the moment, as far as I know.

throw SparkException.internalError(s"Unsupported parameter mode: ${param.mode}")
}
}
}

/**
* This rule resolves and rewrites subqueries inside expressions.
*
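To make the new rules concrete, here is a sketch of the struct that extractInputType would build for a statement like CALL cat.my_proc(1, msg => 'hello') (illustrative only; the catalog, procedure, and parameter names are made up):

```scala
import org.apache.spark.sql.connector.catalog.procedures.ProcedureParameter
import org.apache.spark.sql.types._

// Equivalent of the inputType passed to UnboundProcedure#bind for the call above.
val inputType = StructType(Seq(
  // Positional argument: Spark generates the field name, BY_NAME = false.
  StructField("param0", IntegerType, nullable = false,
    metadata = new MetadataBuilder()
      .putBoolean(ProcedureParameter.BY_NAME_METADATA_KEY, false)
      .build()),
  // Named argument: the field name matches the target parameter, BY_NAME = true.
  StructField("msg", StringType, nullable = false,
    metadata = new MetadataBuilder()
      .putBoolean(ProcedureParameter.BY_NAME_METADATA_KEY, true)
      .build())
))
```

(As discussed in the review thread above, the metadata may be omitted for positional arguments in a later revision.)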
@@ -77,6 +77,7 @@ object AnsiTypeCoercion extends TypeCoercionBase {
override def typeCoercionRules: List[Rule[LogicalPlan]] =
UnpivotCoercion ::
WidenSetOperationTypes ::
ProcedureArgumentCoercion ::
new AnsiCombinedTypeCoercionRule(
CollationTypeCasts ::
InConversion ::
@@ -673,6 +673,14 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
varName,
c.defaultExpr.originalSQL)

case c: Call if c.resolved && c.bound && c.checkArgTypes().isFailure =>
c.checkArgTypes() match {
case mismatch: TypeCheckResult.DataTypeMismatch =>
c.dataTypeMismatch("CALL", mismatch)
case _ =>
throw SparkException.internalError("Invalid input for procedure")
}

case _ => // Falls back to the following checks
}

@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.AlwaysProcess
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.connector.catalog.procedures.BoundProcedure
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.types.{AbstractArrayType, AbstractMapType, AbstractStringType, StringTypeAnyCollation}
@@ -202,6 +203,20 @@ abstract class TypeCoercionBase {
}
}

/**
* A type coercion rule that implicitly casts procedure arguments to expected types.
*/
object ProcedureArgumentCoercion extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case c @ Call(ResolvedProcedure(_, _, procedure: BoundProcedure), args, _) if c.resolved =>
val expectedDataTypes = procedure.parameters.map(_.dataType)
val coercedArgs = args.zip(expectedDataTypes).map {
case (arg, expectedType) => implicitCast(arg, expectedType).getOrElse(arg)
}
c.copy(args = coercedArgs)
}
}

/**
* Widens the data types of the [[Unpivot]] values.
*/
@@ -838,6 +853,7 @@ object TypeCoercion extends TypeCoercionBase {
override def typeCoercionRules: List[Rule[LogicalPlan]] =
UnpivotCoercion ::
WidenSetOperationTypes ::
ProcedureArgumentCoercion ::
new CombinedTypeCoercionRule(
CollationTypeCasts ::
InConversion ::
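For intuition (illustrative only, not from this patch): if a bound procedure declares a LONG parameter and the CALL statement passes an integer literal, ProcedureArgumentCoercion replaces the argument with a cast, roughly:

```scala
import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
import org.apache.spark.sql.types.LongType

val arg = Literal(5)              // IntegerType argument from the CALL statement
val coerced = Cast(arg, LongType) // what implicitCast(arg, LongType) effectively yields here
```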
@@ -67,9 +67,13 @@ package object analysis {
}

def dataTypeMismatch(expr: Expression, mismatch: DataTypeMismatch): Nothing = {
dataTypeMismatch(toSQLExpr(expr), mismatch)
}

def dataTypeMismatch(sqlExpr: String, mismatch: DataTypeMismatch): Nothing = {
throw new AnalysisException(
errorClass = s"DATATYPE_MISMATCH.${mismatch.errorSubClass}",
messageParameters = mismatch.messageParameters + ("sqlExpr" -> toSQLExpr(expr)),
messageParameters = mismatch.messageParameters + ("sqlExpr" -> sqlExpr),
origin = t.origin)
}

@@ -23,13 +23,14 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTable
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.{Attribute, LeafExpression, Unevaluable}
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}
import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UNRESOLVED_FUNC}
import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UNRESOLVED_FUNC, UNRESOLVED_PROCEDURE}
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.connector.catalog.{CatalogPlugin, FunctionCatalog, Identifier, Table, TableCatalog}
import org.apache.spark.sql.connector.catalog.{CatalogPlugin, FunctionCatalog, Identifier, ProcedureCatalog, Table, TableCatalog}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition
import org.apache.spark.sql.connector.catalog.functions.UnboundFunction
import org.apache.spark.sql.connector.catalog.procedures.Procedure
import org.apache.spark.sql.types.{DataType, StructField}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.util.ArrayImplicits._
@@ -135,6 +136,12 @@ case class UnresolvedFunctionName(
case class UnresolvedIdentifier(nameParts: Seq[String], allowTemp: Boolean = false)
extends UnresolvedLeafNode

/**
* A procedure identifier that should be resolved into [[ResolvedProcedure]].
*/
case class UnresolvedProcedure(nameParts: Seq[String]) extends UnresolvedLeafNode {
final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_PROCEDURE)
}

/**
* A resolved leaf node whose statistics has no meaning.
@@ -192,6 +199,12 @@ case class ResolvedFieldName(path: Seq[String], field: StructField) extends Fiel

case class ResolvedFieldPosition(position: ColumnPosition) extends FieldPosition

case class ResolvedProcedure(
catalog: ProcedureCatalog,
ident: Identifier,
procedure: Procedure) extends LeafNodeWithoutStats {
override def output: Seq[Attribute] = Nil
}

/**
* A plan containing resolved persistent views.
@@ -5697,6 +5697,28 @@ class AstBuilder extends DataTypeAstBuilder
ctx.EXISTS != null)
}

/**
* Creates a plan for invoking a procedure.
*
* For example:
* {{{
* CALL multi_part_name(v1, v2, ...);
* CALL multi_part_name(v1, param2 => v2, ...);
* CALL multi_part_name(param1 => v1, param2 => v2, ...);
* }}}
*/
override def visitCall(ctx: CallContext): LogicalPlan = withOrigin(ctx) {
val procedure = withIdentClause(ctx.identifierReference, UnresolvedProcedure)
val args = ctx.functionArgument.asScala.map {
case expr if expr.namedArgumentExpression != null =>
val namedExpr = expr.namedArgumentExpression
NamedArgumentExpression(namedExpr.key.getText, expression(namedExpr.value))
case expr =>
expression(expr)
}.toSeq
Call(procedure, args)
}

/**
* Create a TimestampAdd expression.
*/
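A usage sketch for the new CALL syntax (the catalog, namespace, and procedure names are hypothetical, and spark is assumed to be an active SparkSession):

```scala
// Each statement parses into Call(UnresolvedProcedure(nameParts), args),
// with named arguments represented as NamedArgumentExpression.
spark.sql("CALL cat.db.refresh_table('t1')")                   // positional argument
spark.sql("CALL cat.db.refresh_table(tbl => 't1')")            // named argument
spark.sql("CALL cat.db.copy(source => 't1', target => 't2')")  // all arguments by name
```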
@@ -0,0 +1,28 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.plans.logical

/**
* A logical plan node that requires execution during analysis.
*/
trait ExecutableDuringAnalysis extends LogicalPlan {
/**
* Returns the logical plan node that should be used for EXPLAIN.
*/
def stageForExplain(): LogicalPlan
}
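A hypothetical sketch of how a node could satisfy this trait (not taken from this diff; MyCommand and its execute flag are illustrative): for EXPLAIN, the node returns a copy of itself that is analyzed but never run.

```scala
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{ExecutableDuringAnalysis, LeafNode, LogicalPlan}

// Illustrative only: a leaf node that skips execution when staged for EXPLAIN.
case class MyCommand(execute: Boolean = true) extends LeafNode with ExecutableDuringAnalysis {
  override def output: Seq[Attribute] = Nil

  override def stageForExplain(): LogicalPlan = copy(execute = false)
}
```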