@@ -2102,6 +2102,10 @@ object CleanupAliases extends Rule[LogicalPlan] {
Window(cleanedWindowExprs, partitionSpec.map(trimAliases),
orderSpec.map(trimAliases(_).asInstanceOf[SortOrder]), child)

case InlineTable(rows) =>
val cleanedRows = rows.map(_.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression]))
InlineTable(cleanedRows)

// Operators that operate on objects should only have expressions from encoders, which should
// never have extra aliases.
case o: ObjectConsumer => o
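For intuition, a hedged sketch of what the new case does to one row expression (assumes spark-catalyst on the classpath and that trimNonTopLevelAliases is accessible outside the rule; behavior inferred from the rule's other cases):

    // The top-level alias on a row expression survives cleanup, while any
    // alias nested inside it is stripped.
    import org.apache.spark.sql.catalyst.analysis.CleanupAliases
    import org.apache.spark.sql.catalyst.expressions.{Add, Alias, Literal}

    val rowExpr = Alias(Add(Alias(Literal(1), "x")(), Literal(2)), "a")()
    val cleaned = CleanupAliases.trimNonTopLevelAliases(rowExpr)
    // cleaned is still an Alias named "a"; the inner Alias around Literal(1) is gone.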
@@ -298,6 +298,26 @@ trait CheckAnalysis extends PredicateHelper {
}
}

case InlineTable(rows) if rows.length > 1 =>
val expectedDataTypes = rows.head.map(_.dataType)
rows.zipWithIndex.tail.foreach { case (row, ri) =>
// Check the number of columns.
if (row.length != expectedDataTypes.length) {
failAnalysis(
s"An inline table must have the same number of columns on every row. " +
s"Row '${ri + 1}' has '${row.length}' columns while " +
s"'${expectedDataTypes.length}' columns were expected.")
}
// Check the data types of each column.
row.map(_.dataType).zip(expectedDataTypes).zipWithIndex.collect {
case ((dt1, dt2), ci) if dt1 != dt2 =>
failAnalysis(
s"Data type '$dt1' of column '${rows.head(ci).name}' at row '${ri + 1}' " +
s"does not match the expected data type '$dt2' for that column. " +
s"Expressions for an inline table's column must have the same data type.")
}
}

case _ => // Falls back to the following checks
}

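To make the new checks concrete, here is a hypothetical query that should reach the data-type branch above: the two rows resolve to INT and ARRAY<INT>, no common wider type exists, so type coercion leaves the table unresolved and analysis fails with the per-column mismatch message (a sketch; assumes a SparkSession named spark, and the exact wording may differ):

    // Hypothetical repro: column 1 is INT in row 1 but ARRAY<INT> in row 2.
    spark.sql("VALUES (1), (array(1))")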
@@ -243,6 +243,14 @@ object TypeCoercion {
s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved =>
val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(s.children)
s.makeCopy(Array(newChildren))

case s @ InlineTable(rows) if !s.resolved && s.expressionsResolved && s.validDimensions =>
val targetTypes = getWidestTypes(rows.map(_.map(_.dataType)))
if (targetTypes.nonEmpty) {
s.copy(rows = rows.map(widenTypes(_, targetTypes)))
} else {
s
}
}

/** Build new children with the widest types for each attribute among all the children */
@@ -251,12 +259,11 @@

// Get a sequence of data types, each of which is the widest type of this specific attribute
// in all the children
val targetTypes: Seq[DataType] =
getWidestTypes(children, attrIndex = 0, mutable.Queue[DataType]())
val targetTypes: Seq[DataType] = getWidestTypes(children.map(_.output.map(_.dataType)))

if (targetTypes.nonEmpty) {
// Add an extra Project if the targetTypes are different from the original types.
children.map(widenTypes(_, targetTypes))
children.map(child => Project(widenTypes(child.output, targetTypes), child))
} else {
// Unable to find a target type to widen, then just return the original set.
children
@@ -265,30 +272,31 @@

/** Get the widest type for each attribute in all the children */
@tailrec private def getWidestTypes(
children: Seq[LogicalPlan],
attrIndex: Int,
castedTypes: mutable.Queue[DataType]): Seq[DataType] = {
dataTypes: Seq[Seq[DataType]],
attrIndex: Int = 0,
castedTypes: mutable.Queue[DataType] = mutable.Queue.empty): Seq[DataType] = {
// Return the result after the widen data types have been found for all the children
if (attrIndex >= children.head.output.length) return castedTypes.toSeq
if (attrIndex >= dataTypes.head.length) return castedTypes

// For the attrIndex-th attribute, find the widest type
findWiderCommonType(children.map(_.output(attrIndex).dataType)) match {
findWiderCommonType(dataTypes.map(_(attrIndex))) match {
// If unable to find an appropriate widen type for this column, return an empty Seq
case None => Seq.empty[DataType]
// Otherwise, record the result in the queue and find the type for the next column
case Some(widenType) =>
castedTypes.enqueue(widenType)
getWidestTypes(children, attrIndex + 1, castedTypes)
getWidestTypes(dataTypes, attrIndex + 1, castedTypes)
}
}

/** Given a plan, add an extra project on top to widen some columns' data types. */
private def widenTypes(plan: LogicalPlan, targetTypes: Seq[DataType]): LogicalPlan = {
val casted = plan.output.zip(targetTypes).map {
case (e, dt) if e.dataType != dt => Alias(Cast(e, dt), e.name)()
/** Cast the expressions to the given dataTypes (if we need to). */
private def widenTypes(
expressions: Seq[NamedExpression],
targetTypes: Seq[DataType]): Seq[NamedExpression] = {
expressions.zip(targetTypes).map {
case (e, dt) if e.dataType != dt => Alias(Cast(e, dt), e.toString)()
case (e, _) => e
}
Project(casted, plan)
}
}

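The refactor lets getWidestTypes operate on a raw Seq[Seq[DataType]] so that Union children and inline-table rows can share it. A self-contained toy model of the same column walk (toy types standing in for Catalyst's DataType; not Spark code):

    // Walk columns left to right, widen each column across all rows, and
    // return Seq.empty as soon as one column has no common type.
    import scala.annotation.tailrec
    import scala.collection.mutable

    object WidestTypesSketch {
      sealed trait Ty
      case object IntTy extends Ty
      case object LongTy extends Ty
      case object StrTy extends Ty

      // Stand-in for findWiderCommonType.
      def wider(ts: Seq[Ty]): Option[Ty] = ts.distinct match {
        case Seq(t) => Some(t)
        case s if s.toSet == Set(IntTy, LongTy) => Some(LongTy)
        case _ => None
      }

      @tailrec
      def widest(
          rows: Seq[Seq[Ty]],
          i: Int = 0,
          acc: mutable.Queue[Ty] = mutable.Queue.empty): Seq[Ty] =
        if (i >= rows.head.length) acc.toSeq
        else wider(rows.map(_(i))) match {
          case None => Seq.empty
          case Some(t) => acc.enqueue(t); widest(rows, i + 1, acc)
        }

      def main(args: Array[String]): Unit = {
        // (INT, STRING) and (LONG, STRING) widen to (LONG, STRING).
        assert(widest(Seq(Seq(IntTy, StrTy), Seq(LongTy, StrTy))) == Seq(LongTy, StrTy))
        // INT vs STRING in column 0: no common type, so the result is empty.
        assert(widest(Seq(Seq(IntTy), Seq(StrTy))).isEmpty)
      }
    }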
@@ -22,7 +22,7 @@ import scala.collection.immutable.HashSet
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.api.java.function.FilterFunction
import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf}
import org.apache.spark.sql.catalyst.{CatalystConf, InternalRow, SimpleCatalystConf}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.expressions._
@@ -1563,17 +1563,21 @@ object DecimalAggregates extends Rule[LogicalPlan] {
}

/**
* Converts local operations (i.e. ones that don't require data exchange) on LocalRelation to
* another LocalRelation.
*
* This is relatively simple as it currently handles only a single case: Project.
* Converts local operations (i.e. ones that don't require data exchange) on LocalRelation or
* OneRowRelation to a new LocalRelation.
*/

Review comment (Contributor Author): Update comment.

object ConvertToLocalRelation extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case Project(projectList, LocalRelation(output, data))
if !projectList.exists(hasUnevaluableExpr) =>
val projection = new InterpretedProjection(projectList, output)
LocalRelation(projectList.map(_.toAttribute), data.map(projection))
case table: InlineTable =>
val data = table.rows.map { row =>
val projection = new InterpretedProjection(row)
projection(EmptyRow)
}
LocalRelation(table.output, data)
}

private def hasUnevaluableExpr(expr: Expression): Boolean = {
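The new case mirrors what the rule already does for Project: an inline table's row expressions are constants by construction, so an InterpretedProjection built without an input schema can evaluate each row against EmptyRow. A minimal sketch of evaluating one row that way (assumes spark-catalyst on the classpath):

    // Evaluate one foldable row the way the new case does.
    import org.apache.spark.sql.catalyst.expressions.{Alias, EmptyRow, InterpretedProjection, Literal}

    val row = Seq(Alias(Literal(1), "a")(), Alias(Literal("x"), "b")())
    val projection = new InterpretedProjection(row) // no input attributes needed
    val data = projection(EmptyRow)                 // an InternalRow holding 1 and "x"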
@@ -656,40 +656,36 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
* Create an inline table (a virtual table in Hive parlance).
*/
override def visitInlineTable(ctx: InlineTableContext): LogicalPlan = withOrigin(ctx) {
// Get the backing expressions.
val expressions = ctx.expression.asScala.map { eCtx =>
val e = expression(eCtx)
assert(e.foldable, "All expressions in an inline table must be constants.", eCtx)
e
}

// Validate and evaluate the rows.
val (structType, structConstructor) = expressions.head.dataType match {
case st: StructType =>
(st, (e: Expression) => e)
case dt =>
val st = CreateStruct(Seq(expressions.head)).dataType
(st, (e: Expression) => CreateStruct(Seq(e)))
}
val rows = expressions.map {
case expression =>
val safe = Cast(structConstructor(expression), structType)
safe.eval().asInstanceOf[InternalRow]
// Create expressions.
val rows = ctx.expression.asScala.map { e =>
expression(e) match {
case CreateStruct(children) => children
Review thread:
Contributor: @hvanhovell what's this about? Why do we need to expand the struct?
Contributor: Actually I think I understand what's happening here now.
Contributor Author: The parser creates rows by issuing CreateStruct calls, and InlineTable takes a Seq[Expression] per row, so we need to extract the children from the CreateStruct.
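For illustration, a minimal sketch of that extraction (assumes spark-catalyst on the classpath; in this version of Spark, CreateStruct is a case class, so it can be pattern-matched directly):

    // A parenthesized row such as (1, 'a') parses to a CreateStruct, so the row's
    // column expressions are its children; a bare value is a one-column row.
    import org.apache.spark.sql.catalyst.expressions.{CreateStruct, Expression, Literal}

    val parsedRow: Expression = CreateStruct(Seq(Literal(1), Literal("a")))
    val columns: Seq[Expression] = parsedRow match {
      case CreateStruct(children) => children // Seq(Literal(1), Literal("a"))
      case other => Seq(other)                // e.g. a row of VALUES 1, 2, 3
    }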

case child => Seq(child)
}
}

// Construct attributes.
val baseAttributes = structType.toAttributes.map(_.withNullability(true))
val attributes = if (ctx.identifierList != null) {
val aliases = visitIdentifierList(ctx.identifierList)
assert(aliases.size == baseAttributes.size,
"Number of aliases must match the number of fields in an inline table.", ctx)
baseAttributes.zip(aliases).map(p => p._1.withName(p._2))
// Resolve aliases.
val numExpectedColumns = rows.head.size
val aliases = if (ctx.identifierList != null) {
val names = visitIdentifierList(ctx.identifierList)
assert(names.size == numExpectedColumns,
Review thread:
Contributor: This is an error case users can hit; should we throw ParseException instead of assert?
Contributor Author: It uses a parser-only version of assert that throws a ParseException: https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala#L81
Contributor Author: Come to think of it, we might need to rename it, because people expect that assert calls can be elided. That is for a different PR, though.
Contributor: +1

s"Number of aliases '${names.size}' must match the number of fields " +
s"'$numExpectedColumns' in an inline table", ctx)
names
} else {
baseAttributes
Seq.tabulate(numExpectedColumns)(i => s"col${i + 1}")
}

// Create plan and add an alias if a name has been defined.
LocalRelation(attributes, rows).optionalMap(ctx.identifier)(aliasPlan)
// Create the inline table.
val table = InlineTable(rows.zipWithIndex.map { case (expressions, index) =>
assert(expressions.size == numExpectedColumns,
s"Number of values '${expressions.size}' in row '${index + 1}' does not match the " +
s"expected number of values '$numExpectedColumns' in a row", ctx)
expressions.zip(aliases).map {
case (expression, name) => Alias(expression, name)()
}
})
table.optionalMap(ctx.identifier)(aliasPlan)
}

/**
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans.logical

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans._
@@ -769,3 +769,41 @@ case object OneRowRelation extends LeafNode {
*/
override lazy val statistics: Statistics = Statistics(sizeInBytes = 1)
}

/**
* An inline table that holds a number of foldable expressions, which can be materialized into
* rows. This is semantically the same as a Union of one-row relations.
*/
case class InlineTable(rows: Seq[Seq[NamedExpression]]) extends LeafNode {
Review thread:
Contributor: Should we assert rows.nonEmpty?
Contributor Author: Yeah, I had these checks. The thing is that none of the LogicalPlans have such logic; it has all been centralized in CheckAnalysis, so I added it there. It might not be a bad idea to move this functionality into the separate plans in the longer run.

lazy val expressionsResolved: Boolean = rows.forall(_.forall(_.resolved))
Review thread:
Contributor: Do we really need this? QueryPlan.expressions already handles a seq of seqs of expressions.
Contributor Author: I needed that piece of code in two places (resolve and type coercion), so I made it a val. But I can remove this.


lazy val validDimensions: Boolean = {
val size = rows.headOption.map(_.size).getOrElse(0)
rows.tail.forall(_.size == size)
}

override lazy val resolved: Boolean = {
def allRowsCompatible: Boolean = {
val expectedDataTypes = rows.headOption.toSeq.flatMap(_.map(_.dataType))
rows.tail.forall { row =>
row.map(_.dataType).zip(expectedDataTypes).forall {
case (dt1, dt2) => dt1 == dt2
}
}
}
expressionsResolved && validDimensions && allRowsCompatible
}

override def maxRows: Option[Long] = Some(rows.size)

override def output: Seq[Attribute] = rows.transpose.map {
case column if column.forall(_.resolved) =>
column.head.toAttribute.withNullability(column.exists(_.nullable))
case column =>
UnresolvedAttribute(column.head.name)
}

override lazy val statistics: Statistics = {
Statistics(output.map(_.dataType.defaultSize).sum * rows.size)
}
}
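A hedged construction example for the new node (a sketch; assumes spark-catalyst on the classpath): two single-column rows whose aliased literals agree on the data type, so the plan is fully resolved and exposes one INT column.

    import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}

    val table = InlineTable(Seq(
      Seq(Alias(Literal(1), "a")()),
      Seq(Alias(Literal(2), "a")())))
    assert(table.resolved)                       // types match per column
    assert(table.output.map(_.name) == Seq("a")) // one column, named via the alias
    assert(table.maxRows == Some(2L))            // one per row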
@@ -21,8 +21,9 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor


@@ -52,4 +53,15 @@ class ConvertToLocalRelationSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}

test("InlineTable should be turned into a single LocalRelation") {
val testRelation = InlineTable(
Seq(Literal(1).as("a")) ::
Seq(Literal(2).as("a")) ::
Seq(Literal(3).as("a")) :: Nil)
val correctAnswer = LocalRelation(
LocalRelation('a.int.withNullability(false)).output,
InternalRow(1) :: InternalRow(2) :: InternalRow(3) :: Nil)
val optimized = Optimize.execute(testRelation.analyze)
comparePlans(optimized, correctAnswer)
}
}
@@ -17,7 +17,6 @@

package org.apache.spark.sql.catalyst.parser

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.UnresolvedGenerator
import org.apache.spark.sql.catalyst.expressions._
@@ -424,19 +423,25 @@ class PlanParserSuite extends PlanTest {
}

test("inline table") {
assertEqual("values 1, 2, 3, 4", LocalRelation.fromExternalRows(
Seq('col1.int),
Seq(1, 2, 3, 4).map(x => Row(x))))
def rows(names: String*)(values: Any*): Seq[Seq[NamedExpression]] = {
def row(values: Seq[Any]): Seq[NamedExpression] = values.zip(names).map {
case (value, name) => Alias(Literal(value), name)()
}
values.map {
case elements: Seq[Any] => row(elements)
case element => row(Seq(element))
}
}
assertEqual(
"values 1, 2, 3, 4",
InlineTable(rows("col1")(1, 2, 3, 4)))
assertEqual(
"values (1, 'a'), (2, 'b'), (3, 'c') as tbl(a, b)",
LocalRelation.fromExternalRows(
Seq('a.int, 'b.string),
Seq((1, "a"), (2, "b"), (3, "c")).map(x => Row(x._1, x._2))).as("tbl"))
intercept("values (a, 'a'), (b, 'b')",
"All expressions in an inline table must be constants.")
InlineTable(rows("a", "b")(Seq(1, "a"), Seq(2, "b"), Seq(3, "c"))).as("tbl"))
intercept("values (1, 'a'), (2, 'b') as tbl(a, b, c)",
"Number of aliases must match the number of fields in an inline table.")
intercept[ArrayIndexOutOfBoundsException](parsePlan("values (1, 'a'), (2, 'b', 5Y)"))
"Number of aliases", "must match the number of fields", "in an inline table")
intercept("values (1, 'a'), (2, 'b', 5Y)",
"Number of values", "in row", "does not match the expected number of values", "in a row")
}

test("simple select query with !> and !<") {