diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 660f523698e7..11a7dd2d4ca3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2102,6 +2102,10 @@ object CleanupAliases extends Rule[LogicalPlan] {
       Window(cleanedWindowExprs, partitionSpec.map(trimAliases),
         orderSpec.map(trimAliases(_).asInstanceOf[SortOrder]), child)

+    case InlineTable(rows) =>
+      val cleanedRows = rows.map(_.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression]))
+      InlineTable(cleanedRows)
+
     // Operators that operate on objects should only have expressions from encoders, which should
     // never have extra aliases.
     case o: ObjectConsumer => o
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 41b7e62d8cce..cd390eebe7a7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -298,6 +298,26 @@ trait CheckAnalysis extends PredicateHelper {
              }
            }

+          case InlineTable(rows) if rows.length > 1 =>
+            val expectedDataTypes = rows.head.map(_.dataType)
+            rows.zipWithIndex.tail.foreach { case (row, ri) =>
+              // Check the number of columns.
+              if (row.length != expectedDataTypes.length) {
+                failAnalysis(
+                  s"An inline table must have the same number of columns on every row. " +
+                    s"Row '${ri + 1}' has '${row.length}' columns while " +
+                    s"'${expectedDataTypes.length}' columns were expected.")
+              }
+              // Check the data
+              row.map(_.dataType).zip(expectedDataTypes).zipWithIndex.collect {
+                case ((dt1, dt2), ci) if dt1 != dt2 =>
+                  failAnalysis(
+                    s"Data type '$dt1' of column '${rows.head(ci).name}' at row '${ri + 1}' " +
+                      s"does not match the expected data type '$dt2' for that column. " +
+                      s"Expressions for an inline table's column must have the same data type.")
+              }
+            }
+
           case _ => // Fallbacks to the following checks
         }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 021952e7166f..8ce44dc0c39b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -243,6 +243,14 @@ object TypeCoercion {
           s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved =>
         val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(s.children)
         s.makeCopy(Array(newChildren))
+
+      case s @ InlineTable(rows) if !s.resolved && s.expressionsResolved && s.validDimensions =>
+        val targetTypes = getWidestTypes(rows.map(_.map(_.dataType)))
+        if (targetTypes.nonEmpty) {
+          s.copy(rows = rows.map(widenTypes(_, targetTypes)))
+        } else {
+          s
+        }
     }

     /** Build new children with the widest types for each attribute among all the children */
@@ -251,12 +259,11 @@

       // Get a sequence of data types, each of which is the widest type of this specific attribute
       // in all the children
-      val targetTypes: Seq[DataType] =
-        getWidestTypes(children, attrIndex = 0, mutable.Queue[DataType]())
+      val targetTypes: Seq[DataType] = getWidestTypes(children.map(_.output.map(_.dataType)))

       if (targetTypes.nonEmpty) {
         // Add an extra Project if the targetTypes are different from the original types.
-        children.map(widenTypes(_, targetTypes))
+        children.map(child => Project(widenTypes(child.output, targetTypes), child))
       } else {
         // Unable to find a target type to widen, then just return the original set.
         children
@@ -265,30 +272,31 @@

     /** Get the widest type for each attribute in all the children */
     @tailrec private def getWidestTypes(
-        children: Seq[LogicalPlan],
-        attrIndex: Int,
-        castedTypes: mutable.Queue[DataType]): Seq[DataType] = {
+        dataTypes: Seq[Seq[DataType]],
+        attrIndex: Int = 0,
+        castedTypes: mutable.Queue[DataType] = mutable.Queue.empty): Seq[DataType] = {
       // Return the result after the widen data types have been found for all the children
-      if (attrIndex >= children.head.output.length) return castedTypes.toSeq
+      if (attrIndex >= dataTypes.head.length) return castedTypes

       // For the attrIndex-th attribute, find the widest type
-      findWiderCommonType(children.map(_.output(attrIndex).dataType)) match {
+      findWiderCommonType(dataTypes.map(_(attrIndex))) match {
         // If unable to find an appropriate widen type for this column, return an empty Seq
         case None => Seq.empty[DataType]
         // Otherwise, record the result in the queue and find the type for the next column
         case Some(widenType) =>
           castedTypes.enqueue(widenType)
-          getWidestTypes(children, attrIndex + 1, castedTypes)
+          getWidestTypes(dataTypes, attrIndex + 1, castedTypes)
       }
     }

-    /** Given a plan, add an extra project on top to widen some columns' data types. */
-    private def widenTypes(plan: LogicalPlan, targetTypes: Seq[DataType]): LogicalPlan = {
-      val casted = plan.output.zip(targetTypes).map {
-        case (e, dt) if e.dataType != dt => Alias(Cast(e, dt), e.name)()
+    /** Cast the expressions to the given dataTypes (if we need to). */
+    private def widenTypes(
+        expressions: Seq[NamedExpression],
+        targetTypes: Seq[DataType]): Seq[NamedExpression] = {
+      expressions.zip(targetTypes).map {
+        case (e, dt) if e.dataType != dt => Alias(Cast(e, dt), e.toString)()
         case (e, _) => e
       }
-      Project(casted, plan)
     }
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 75130007b963..04f304376a76 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -22,7 +22,7 @@ import scala.collection.immutable.HashSet
 import scala.collection.mutable.ArrayBuffer

 import org.apache.spark.api.java.function.FilterFunction
-import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf}
+import org.apache.spark.sql.catalyst.{CatalystConf, InternalRow, SimpleCatalystConf}
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
 import org.apache.spark.sql.catalyst.expressions._
@@ -1563,17 +1563,21 @@ object DecimalAggregates extends Rule[LogicalPlan] {
 }

 /**
- * Converts local operations (i.e. ones that don't require data exchange) on LocalRelation to
- * another LocalRelation.
- *
- * This is relatively simple as it currently handles only a single case: Project.
+ * Converts local operations (i.e. ones that don't require data exchange) on LocalRelation or
+ * OneRowRelation to a new LocalRelation.
  */
 object ConvertToLocalRelation extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
     case Project(projectList, LocalRelation(output, data))
         if !projectList.exists(hasUnevaluableExpr) =>
       val projection = new InterpretedProjection(projectList, output)
       LocalRelation(projectList.map(_.toAttribute), data.map(projection))
+    case table: InlineTable =>
+      val data = table.rows.map { row =>
+        val projection = new InterpretedProjection(row)
+        projection(EmptyRow)
+      }
+      LocalRelation(table.output, data)
   }

   private def hasUnevaluableExpr(expr: Expression): Boolean = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 679adf2717b5..a99f08b092bb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -656,40 +656,36 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
    * Create an inline table (a virtual table in Hive parlance).
    */
   override def visitInlineTable(ctx: InlineTableContext): LogicalPlan = withOrigin(ctx) {
-    // Get the backing expressions.
-    val expressions = ctx.expression.asScala.map { eCtx =>
-      val e = expression(eCtx)
-      assert(e.foldable, "All expressions in an inline table must be constants.", eCtx)
-      e
-    }
-
-    // Validate and evaluate the rows.
-    val (structType, structConstructor) = expressions.head.dataType match {
-      case st: StructType =>
-        (st, (e: Expression) => e)
-      case dt =>
-        val st = CreateStruct(Seq(expressions.head)).dataType
-        (st, (e: Expression) => CreateStruct(Seq(e)))
-    }
-    val rows = expressions.map {
-      case expression =>
-        val safe = Cast(structConstructor(expression), structType)
-        safe.eval().asInstanceOf[InternalRow]
+    // Create expressions.
+    val rows = ctx.expression.asScala.map { e =>
+      expression(e) match {
+        case CreateStruct(children) => children
+        case child => Seq(child)
+      }
     }

-    // Construct attributes.
-    val baseAttributes = structType.toAttributes.map(_.withNullability(true))
-    val attributes = if (ctx.identifierList != null) {
-      val aliases = visitIdentifierList(ctx.identifierList)
-      assert(aliases.size == baseAttributes.size,
-        "Number of aliases must match the number of fields in an inline table.", ctx)
-      baseAttributes.zip(aliases).map(p => p._1.withName(p._2))
+    // Resolve aliases.
+    val numExpectedColumns = rows.head.size
+    val aliases = if (ctx.identifierList != null) {
+      val names = visitIdentifierList(ctx.identifierList)
+      assert(names.size == numExpectedColumns,
+        s"Number of aliases '${names.size}' must match the number of fields " +
+          s"'$numExpectedColumns' in an inline table", ctx)
+      names
     } else {
-      baseAttributes
+      Seq.tabulate(numExpectedColumns)(i => s"col${i + 1}")
     }

-    // Create plan and add an alias if a name has been defined.
-    LocalRelation(attributes, rows).optionalMap(ctx.identifier)(aliasPlan)
+    // Create the inline table.
+    val table = InlineTable(rows.zipWithIndex.map { case (expressions, index) =>
+      assert(expressions.size == numExpectedColumns,
+        s"Number of values '${expressions.size}' in row '${index + 1}' does not match the " +
+          s"expected number of values '$numExpectedColumns' in a row", ctx)
+      expressions.zip(aliases).map {
+        case (expression, name) => Alias(expression, name)()
+      }
+    })
+    table.optionalMap(ctx.identifier)(aliasPlan)
   }

   /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index eb612c4c12c7..68443baf582d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans.logical

 import scala.collection.mutable.ArrayBuffer

-import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
+import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, UnresolvedAttribute}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans._
@@ -769,3 +769,41 @@ case object OneRowRelation extends LeafNode {
    */
   override lazy val statistics: Statistics = Statistics(sizeInBytes = 1)
 }
+
+/**
+ * An inline table that holds a number of foldable expressions, which can be materialized into
+ * rows. This is semantically the same as a Union of one row relations.
+ */
+case class InlineTable(rows: Seq[Seq[NamedExpression]]) extends LeafNode {
+  lazy val expressionsResolved: Boolean = rows.forall(_.forall(_.resolved))
+
+  lazy val validDimensions: Boolean = {
+    val size = rows.headOption.map(_.size).getOrElse(0)
+    rows.tail.forall(_.size == size)
+  }
+
+  override lazy val resolved: Boolean = {
+    def allRowsCompatible: Boolean = {
+      val expectedDataTypes = rows.headOption.toSeq.flatMap(_.map(_.dataType))
+      rows.tail.forall { row =>
+        row.map(_.dataType).zip(expectedDataTypes).forall {
+          case (dt1, dt2) => dt1 == dt2
+        }
+      }
+    }
+    expressionsResolved && validDimensions && allRowsCompatible
+  }
+
+  override def maxRows: Option[Long] = Some(rows.size)
+
+  override def output: Seq[Attribute] = rows.transpose.map {
+    case column if column.forall(_.resolved) =>
+      column.head.toAttribute.withNullability(column.exists(_.nullable))
+    case column =>
+      UnresolvedAttribute(column.head.name)
+  }
+
+  override lazy val statistics: Statistics = {
+    Statistics(output.map(_.dataType.defaultSize).sum * rows.size)
+  }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala
index 049a19b86f7c..118f5627ff37 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala
@@ -21,8 +21,9 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.Literal
 import org.apache.spark.sql.catalyst.plans.PlanTest
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.RuleExecutor

@@ -52,4 +53,15 @@ class ConvertToLocalRelationSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }

+  test("InlineTable should be turned into a single LocalRelation") {
+    val testRelation = InlineTable(
+      Seq(Literal(1).as("a")) ::
+      Seq(Literal(2).as("a")) ::
+      Seq(Literal(3).as("a")) :: Nil)
+    val correctAnswer = LocalRelation(
+      LocalRelation('a.int.withNullability(false)).output,
+      InternalRow(1) :: InternalRow(2) :: InternalRow(3) :: Nil)
+    val optimized = Optimize.execute(testRelation.analyze)
+    comparePlans(optimized, correctAnswer)
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
index fbe236e19626..c9ba4d3a8e9f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -17,7 +17,6 @@

 package org.apache.spark.sql.catalyst.parser

-import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.FunctionIdentifier
 import org.apache.spark.sql.catalyst.analysis.UnresolvedGenerator
 import org.apache.spark.sql.catalyst.expressions._
@@ -424,19 +423,25 @@ class PlanParserSuite extends PlanTest {
   }

   test("inline table") {
-    assertEqual("values 1, 2, 3, 4", LocalRelation.fromExternalRows(
-      Seq('col1.int),
-      Seq(1, 2, 3, 4).map(x => Row(x))))
+    def rows(names: String*)(values: Any*): Seq[Seq[NamedExpression]] = {
+      def row(values: Seq[Any]): Seq[NamedExpression] = values.zip(names).map {
+        case (value, name) => Alias(Literal(value), name)()
+      }
+      values.map {
+        case elements: Seq[Any] => row(elements)
+        case element => row(Seq(element))
+      }
+    }
+    assertEqual(
+      "values 1, 2, 3, 4",
+      InlineTable(rows("col1")(1, 2, 3, 4)))
     assertEqual(
       "values (1, 'a'), (2, 'b'), (3, 'c') as tbl(a, b)",
-      LocalRelation.fromExternalRows(
-        Seq('a.int, 'b.string),
-        Seq((1, "a"), (2, "b"), (3, "c")).map(x => Row(x._1, x._2))).as("tbl"))
-    intercept("values (a, 'a'), (b, 'b')",
-      "All expressions in an inline table must be constants.")
+      InlineTable(rows("a", "b")(Seq(1, "a"), Seq(2, "b"), Seq(3, "c"))).as("tbl"))
     intercept("values (1, 'a'), (2, 'b') as tbl(a, b, c)",
-      "Number of aliases must match the number of fields in an inline table.")
-    intercept[ArrayIndexOutOfBoundsException](parsePlan("values (1, 'a'), (2, 'b', 5Y)"))
+      "Number of aliases", "must match the number of fields", "in an inline table")
+    intercept("values (1, 'a'), (2, 'b', 5Y)",
+      "Number of values", "in row", "does not match the expected number of values", "in a row")
   }

   test("simple select query with !> and !<") {
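
A minimal usage sketch of the parser-side behavior this patch introduces (not part of the patch). It assumes a Catalyst build with the change applied; the object name is illustrative, and only the existing CatalystSqlParser.parsePlan and TreeNode.treeString APIs are used:

import org.apache.spark.sql.catalyst.parser.CatalystSqlParser

object InlineTableParseExample {
  def main(args: Array[String]): Unit = {
    // With this change, a VALUES clause is parsed into an unresolved InlineTable of aliased
    // literal expressions instead of being eagerly evaluated into a LocalRelation; type
    // coercion and ConvertToLocalRelation materialize it later during analysis/optimization.
    val plan = CatalystSqlParser.parsePlan("VALUES (1, 'a'), (2, 'b') AS tbl(a, b)")
    // Expected shape (per the new visitInlineTable): an aliased InlineTable whose rows are
    // Seq(Seq(Alias(Literal(1), "a"), Alias(Literal("a"), "b")), Seq(...)).
    println(plan.treeString)
  }
}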