diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 8d7a6bc4b573d..341b9e5f55c60 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTypes}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.{AggregateEstimation, EstimationUtils, ProjectEstimation}
+import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.{AggregateEstimation, EstimationUtils, JoinEstimation, ProjectEstimation}
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 
@@ -340,14 +340,22 @@ case class Join(
     case _ => resolvedExceptNatural
   }
 
-  override def computeStats(conf: CatalystConf): Statistics = joinType match {
-    case LeftAnti | LeftSemi =>
-      // LeftSemi and LeftAnti won't ever be bigger than left
-      left.stats(conf).copy()
-    case _ =>
-      // make sure we don't propagate isBroadcastable in other joins, because
-      // they could explode the size.
-      super.computeStats(conf).copy(isBroadcastable = false)
+  override def computeStats(conf: CatalystConf): Statistics = {
+    def simpleEstimation: Statistics = joinType match {
+      case LeftAnti | LeftSemi =>
+        // LeftSemi and LeftAnti won't ever be bigger than the left side
+        left.stats(conf)
+      case _ =>
+        // Make sure we don't propagate isBroadcastable in other joins, because
+        // they could explode the size.
+        super.computeStats(conf).copy(isBroadcastable = false)
+    }
+
+    if (conf.cboEnabled) {
+      JoinEstimation.estimate(conf, this).getOrElse(simpleEstimation)
+    } else {
+      simpleEstimation
+    }
   }
 }
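The `cboEnabled` flag above is the only switch between the old and new paths: with CBO off, `computeStats` keeps the previous behavior, and with CBO on it first tries `JoinEstimation.estimate` and falls back to `simpleEstimation` only when that returns `None`. A minimal sketch of toggling it, using the `SimpleCatalystConf` that the test base below also uses (the plan value is a hypothetical stand-in for any resolved `Join`):

    import org.apache.spark.sql.catalyst.SimpleCatalystConf

    // cboEnabled = true routes Join.computeStats through JoinEstimation.estimate;
    // with cboEnabled = false the simple size-based estimate is kept.
    val cboConf = SimpleCatalystConf(caseSensitiveAnalysis = true, cboEnabled = true)
    // joinPlan.stats(cboConf)  // hypothetical plan; returns the CBO-estimated Statistics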
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
index e8b794212c10d..4d18b28be8663 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
@@ -17,10 +17,12 @@
 
 package org.apache.spark.sql.catalyst.plans.logical.statsEstimation
 
+import scala.math.BigDecimal.RoundingMode
+
 import org.apache.spark.sql.catalyst.CatalystConf
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
-import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan}
-import org.apache.spark.sql.types.StringType
+import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan, Statistics}
+import org.apache.spark.sql.types.{DataType, StringType}
 
 object EstimationUtils {
 
@@ -29,6 +31,20 @@ object EstimationUtils {
   def rowCountsExist(conf: CatalystConf, plans: LogicalPlan*): Boolean =
     plans.forall(_.stats(conf).rowCount.isDefined)
 
+  /** Check if each attribute has a column stat in the corresponding statistics. */
+  def columnStatsExist(statsAndAttr: (Statistics, Attribute)*): Boolean = {
+    statsAndAttr.forall { case (stats, attr) =>
+      stats.attributeStats.contains(attr)
+    }
+  }
+
+  def nullColumnStat(dataType: DataType, rowCount: BigInt): ColumnStat = {
+    ColumnStat(distinctCount = 0, min = None, max = None, nullCount = rowCount,
+      avgLen = dataType.defaultSize, maxLen = dataType.defaultSize)
+  }
+
+  def ceil(bigDecimal: BigDecimal): BigInt = bigDecimal.setScale(0, RoundingMode.CEILING).toBigInt()
+
   /** Get column stats for output attributes. */
   def getOutputMap(inputMap: AttributeMap[ColumnStat], output: Seq[Attribute])
     : AttributeMap[ColumnStat] = {
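The three new helpers are small but load-bearing for the join estimator: `columnStatsExist` gates the equi-join path, `nullColumnStat` models a join side that contributes only nulls, and `ceil` rounds fractional row estimates up so a positive selectivity never truncates to zero rows. A minimal sketch of their behavior (the concrete numbers are illustrative only):

    import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
    import org.apache.spark.sql.types.IntegerType

    // 15 rows scaled by selectivity 1/5 gives exactly 3; a selectivity of 1/4
    // would give ceil(3.75) = 4 rather than rounding down to 3.
    assert(ceil(BigDecimal(15) / BigDecimal(5)) == BigInt(3))
    assert(ceil(BigDecimal(15) / BigDecimal(4)) == BigInt(4))

    // A column that is null on all 5 rows: no distinct values, no min/max,
    // nullCount equals the row count, and sizes fall back to the type default.
    val cs = nullColumnStat(IntegerType, rowCount = 5)
    assert(cs.distinctCount == 0 && cs.nullCount == 5 && cs.avgLen == 4)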
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala
new file mode 100644
index 0000000000000..982a5a8bb89be
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala
@@ -0,0 +1,307 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans.logical.statsEstimation
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.CatalystConf
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Expression}
+import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, Statistics}
+import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
+
+
+object JoinEstimation extends Logging {
+  /**
+   * Estimate statistics after join. Return `None` if the join type is not supported, or we don't
+   * have enough statistics for estimation.
+   */
+  def estimate(conf: CatalystConf, join: Join): Option[Statistics] = {
+    join.joinType match {
+      case Inner | Cross | LeftOuter | RightOuter | FullOuter =>
+        InnerOuterEstimation(conf, join).doEstimate()
+      case LeftSemi | LeftAnti =>
+        LeftSemiAntiEstimation(conf, join).doEstimate()
+      case _ =>
+        logDebug(s"[CBO] Unsupported join type: ${join.joinType}")
+        None
+    }
+  }
+}
+
+case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging {
+
+  private val leftStats = join.left.stats(conf)
+  private val rightStats = join.right.stats(conf)
+
+  /**
+   * Estimate output size and number of rows after a join operator, and update output column stats.
+   */
+  def doEstimate(): Option[Statistics] = join match {
+    case _ if !rowCountsExist(conf, join.left, join.right) =>
+      None
+
+    case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) =>
+      // 1. Compute join selectivity
+      val joinKeyPairs = extractJoinKeysWithColStats(leftKeys, rightKeys)
+      val selectivity = joinSelectivity(joinKeyPairs)
+
+      // 2. Estimate the number of output rows
+      val leftRows = leftStats.rowCount.get
+      val rightRows = rightStats.rowCount.get
+      val innerJoinedRows = ceil(BigDecimal(leftRows * rightRows) * selectivity)
+
+      // Make sure outputRows won't be too small based on join type.
+      val outputRows = joinType match {
+        case LeftOuter =>
+          // All rows from the left side should be in the result.
+          leftRows.max(innerJoinedRows)
+        case RightOuter =>
+          // All rows from the right side should be in the result.
+          rightRows.max(innerJoinedRows)
+        case FullOuter =>
+          // T(A FOJ B) = T(A LOJ B) + T(A ROJ B) - T(A IJ B)
+          leftRows.max(innerJoinedRows) + rightRows.max(innerJoinedRows) - innerJoinedRows
+        case _ =>
+          // Don't change for inner or cross join
+          innerJoinedRows
+      }
+
+      // 3. Update statistics based on the output of join
+      val inputAttrStats = AttributeMap(
+        leftStats.attributeStats.toSeq ++ rightStats.attributeStats.toSeq)
+      val attributesWithStat = join.output.filter(a => inputAttrStats.contains(a))
+      val (fromLeft, fromRight) = attributesWithStat.partition(join.left.outputSet.contains(_))
+
+      val outputStats: Seq[(Attribute, ColumnStat)] = if (outputRows == 0) {
+        // The output is empty, we don't need to keep column stats.
+        Nil
+      } else if (innerJoinedRows == 0) {
+        joinType match {
+          // For outer joins, if the inner join part is empty, the number of output rows is the
+          // same as that of the outer side. Column stats of join keys from the outer side
+          // stay unchanged, while column stats of join keys from the other side should be
+          // updated based on the added null values.
+          case LeftOuter =>
+            fromLeft.map(a => (a, inputAttrStats(a))) ++
+              fromRight.map(a => (a, nullColumnStat(a.dataType, leftRows)))
+          case RightOuter =>
+            fromRight.map(a => (a, inputAttrStats(a))) ++
+              fromLeft.map(a => (a, nullColumnStat(a.dataType, rightRows)))
+          case FullOuter =>
+            fromLeft.map { a =>
+              val oriColStat = inputAttrStats(a)
+              (a, oriColStat.copy(nullCount = oriColStat.nullCount + rightRows))
+            } ++ fromRight.map { a =>
+              val oriColStat = inputAttrStats(a)
+              (a, oriColStat.copy(nullCount = oriColStat.nullCount + leftRows))
+            }
+          case _ => Nil
+        }
+      } else {
+        val joinKeyStats = getIntersectedStats(joinKeyPairs)
+        join.joinType match {
+          // For outer joins, don't update column stats from the outer side.
+          case LeftOuter =>
+            fromLeft.map(a => (a, inputAttrStats(a))) ++
+              updateAttrStats(outputRows, fromRight, inputAttrStats, joinKeyStats)
+          case RightOuter =>
+            updateAttrStats(outputRows, fromLeft, inputAttrStats, joinKeyStats) ++
+              fromRight.map(a => (a, inputAttrStats(a)))
+          case FullOuter =>
+            inputAttrStats.toSeq
+          case _ =>
+            // Update column stats from both sides for inner or cross join.
+            updateAttrStats(outputRows, attributesWithStat, inputAttrStats, joinKeyStats)
+        }
+      }
+
+      val outputAttrStats = AttributeMap(outputStats)
+      Some(Statistics(
+        sizeInBytes = getOutputSize(join.output, outputRows, outputAttrStats),
+        rowCount = Some(outputRows),
+        attributeStats = outputAttrStats,
+        isBroadcastable = false))
+
+    case _ =>
+      // When there is no equi-join condition, we estimate it like a cartesian product.
+      val inputAttrStats = AttributeMap(
+        leftStats.attributeStats.toSeq ++ rightStats.attributeStats.toSeq)
+      // Propagate the original column stats
+      val outputRows = leftStats.rowCount.get * rightStats.rowCount.get
+      Some(Statistics(
+        sizeInBytes = getOutputSize(join.output, outputRows, inputAttrStats),
+        rowCount = Some(outputRows),
+        attributeStats = inputAttrStats,
+        isBroadcastable = false))
+  }
+
+  // scalastyle:off
+  /**
+   * The number of rows of A inner join B on A.k1 = B.k1 is estimated by this basic formula:
+   * T(A IJ B) = T(A) * T(B) / max(V(A.k1), V(B.k1)), where V is the number of distinct values of
+   * that column. The underlying assumption for this formula is: each value of the smaller domain
+   * is included in the larger domain.
+   * Generally, an inner join with multiple join keys can also be estimated based on the above
+   * formula:
+   * T(A IJ B) = T(A) * T(B) / (max(V(A.k1), V(B.k1)) * max(V(A.k2), V(B.k2)) * ... * max(V(A.kn), V(B.kn)))
+   * However, the denominator can become very large and excessively reduce the result, so we use a
+   * conservative strategy and take only the largest max(V(A.ki), V(B.ki)) as the denominator.
+   */
+  // scalastyle:on
+  def joinSelectivity(joinKeyPairs: Seq[(AttributeReference, AttributeReference)]): BigDecimal = {
+    var ndvDenom: BigInt = -1
+    var i = 0
+    while (i < joinKeyPairs.length && ndvDenom != 0) {
+      val (leftKey, rightKey) = joinKeyPairs(i)
+      // Check if the two sides are disjoint
+      val leftKeyStats = leftStats.attributeStats(leftKey)
+      val rightKeyStats = rightStats.attributeStats(rightKey)
+      val lRange = Range(leftKeyStats.min, leftKeyStats.max, leftKey.dataType)
+      val rRange = Range(rightKeyStats.min, rightKeyStats.max, rightKey.dataType)
+      if (Range.isIntersected(lRange, rRange)) {
+        // Get the largest ndv among pairs of join keys
+        val maxNdv = leftKeyStats.distinctCount.max(rightKeyStats.distinctCount)
+        if (maxNdv > ndvDenom) ndvDenom = maxNdv
+      } else {
+        // Set ndvDenom to zero to indicate that this join should have no output
+        ndvDenom = 0
+      }
+      i += 1
+    }
+
+    if (ndvDenom < 0) {
+      // There are no join keys or no column stats for any of the join key pairs, so we estimate
+      // like a cartesian product.
+      1
+    } else if (ndvDenom == 0) {
+      // One of the join key pairs is disjoint, thus the two sides of the join are disjoint.
+      0
+    } else {
+      1 / BigDecimal(ndvDenom)
+    }
+  }
+
+  /**
+   * Propagate or update column stats for output attributes.
+   * 1. For a cartesian product, all values are preserved, so there's no need to change column stats.
+   * 2. For other cases, a) update the max/min of join keys based on their intersected range, and
+   *    b) update the distinct count of other attributes based on the number of output rows after join.
+   */
+  private def updateAttrStats(
+      outputRows: BigInt,
+      attributes: Seq[Attribute],
+      oldAttrStats: AttributeMap[ColumnStat],
+      joinKeyStats: AttributeMap[ColumnStat]): Seq[(Attribute, ColumnStat)] = {
+    val outputAttrStats = new ArrayBuffer[(Attribute, ColumnStat)]()
+    val leftRows = leftStats.rowCount.get
+    val rightRows = rightStats.rowCount.get
+    if (outputRows == leftRows * rightRows) {
+      // Cartesian product, just propagate the original column stats
+      attributes.foreach(a => outputAttrStats += a -> oldAttrStats(a))
+    } else {
+      val leftRatio =
+        if (leftRows != 0) BigDecimal(outputRows) / BigDecimal(leftRows) else BigDecimal(0)
+      val rightRatio =
+        if (rightRows != 0) BigDecimal(outputRows) / BigDecimal(rightRows) else BigDecimal(0)
+      attributes.foreach { a =>
+        // Check if this attribute is a join key
+        if (joinKeyStats.contains(a)) {
+          outputAttrStats += a -> joinKeyStats(a)
+        } else {
+          val oldColStat = oldAttrStats(a)
+          val oldNdv = oldColStat.distinctCount
+          // We only change (scale down) the number of distinct values if the number of rows
+          // decreases after the join, because a join won't produce new values even if the number
+          // of rows increases.
+          val newNdv = if (join.left.outputSet.contains(a) && leftRatio < 1) {
+            ceil(BigDecimal(oldNdv) * leftRatio)
+          } else if (join.right.outputSet.contains(a) && rightRatio < 1) {
+            ceil(BigDecimal(oldNdv) * rightRatio)
+          } else {
+            oldNdv
+          }
+          // TODO: support nullCount updates for specific outer joins
+          outputAttrStats += a -> oldColStat.copy(distinctCount = newNdv)
+        }
+      }
+    }
+    outputAttrStats
+  }
+
+  /** Get intersected column stats for join keys. */
+  private def getIntersectedStats(joinKeyPairs: Seq[(AttributeReference, AttributeReference)])
+    : AttributeMap[ColumnStat] = {
+
+    val intersectedStats = new mutable.HashMap[Attribute, ColumnStat]()
+    joinKeyPairs.foreach { case (leftKey, rightKey) =>
+      val leftKeyStats = leftStats.attributeStats(leftKey)
+      val rightKeyStats = rightStats.attributeStats(rightKey)
+      val lRange = Range(leftKeyStats.min, leftKeyStats.max, leftKey.dataType)
+      val rRange = Range(rightKeyStats.min, rightKeyStats.max, rightKey.dataType)
+      // When we reach here, the join selectivity is not zero, so each pair of join keys
+      // should intersect.
+      assert(Range.isIntersected(lRange, rRange))
+
+      // Update intersected column stats
+      assert(leftKey.dataType.sameType(rightKey.dataType))
+      val minNdv = leftKeyStats.distinctCount.min(rightKeyStats.distinctCount)
+      val (newMin, newMax) = Range.intersect(lRange, rRange, leftKey.dataType)
+      intersectedStats.put(leftKey,
+        leftKeyStats.copy(distinctCount = minNdv, min = newMin, max = newMax, nullCount = 0))
+      intersectedStats.put(rightKey,
+        rightKeyStats.copy(distinctCount = minNdv, min = newMin, max = newMax, nullCount = 0))
+    }
+    AttributeMap(intersectedStats.toSeq)
+  }
+
+  private def extractJoinKeysWithColStats(
+      leftKeys: Seq[Expression],
+      rightKeys: Seq[Expression]): Seq[(AttributeReference, AttributeReference)] = {
+    leftKeys.zip(rightKeys).collect {
+      // Currently we don't deal with equi-join conditions on expressions, like key1 = key2 + 5.
+      // Note: join keys from EqualNullSafe also fall into this case (Coalesce); consider
+      // supporting it in the future by using `nullCount` in column stats.
+      case (lk: AttributeReference, rk: AttributeReference)
+        if columnStatsExist((leftStats, lk), (rightStats, rk)) => (lk, rk)
+    }
+  }
+}
+
+case class LeftSemiAntiEstimation(conf: CatalystConf, join: Join) {
+  def doEstimate(): Option[Statistics] = {
+    // TODO: It's error-prone to estimate cardinalities for LeftSemi and LeftAnti based on basic
+    // column stats. For now we just propagate the statistics from the left side. We should do
+    // more accurate estimation when advanced stats (e.g. histograms) are available.
+    if (rowCountsExist(conf, join.left)) {
+      val leftStats = join.left.stats(conf)
+      // Propagate the row count and column stats from the left side
+      val outputRows = leftStats.rowCount.get
+      Some(Statistics(
+        sizeInBytes = getOutputSize(join.output, outputRows, leftStats.attributeStats),
+        rowCount = Some(outputRows),
+        attributeStats = leftStats.attributeStats,
+        isBroadcastable = false))
+    } else {
+      None
+    }
+  }
+}
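To make the selectivity formula concrete, take the fixtures from JoinEstimationSuite further down: table1 has T = 5 rows with V(key-1-5) = 5 over [1, 5], and table2 has T = 3 rows with V(key-1-2) = 2 over [1, 2]. The ranges intersect, so T(table1 IJ table2) = ceil(5 * 3 / max(5, 2)) = 3, which is exactly the row count the "inner join" test asserts. The outer-join clamps follow the same arithmetic: in the disjoint full-outer-join test the inner part is 0, so the estimate is max(5, 0) + max(3, 0) - 0 = 8 rows, an instance of T(A FOJ B) = T(A LOJ B) + T(A ROJ B) - T(A IJ B).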
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala
new file mode 100644
index 0000000000000..5aa6b9353bc4c
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans.logical.statsEstimation
+
+import java.math.{BigDecimal => JDecimal}
+import java.sql.{Date, Timestamp}
+
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.types.{BooleanType, DateType, TimestampType, _}
+
+
+/** Value range of a column. */
+trait Range
+
+/** For simplicity we use decimal to unify operations on numeric ranges. */
+case class NumericRange(min: JDecimal, max: JDecimal) extends Range
+
+/**
+ * This version of Spark does not have min/max for binary/string types, so we define their
+ * default behavior with this class.
+ */
+class DefaultRange extends Range
+
+/** This is for columns with only null values. */
+class NullRange extends Range
+
+object Range {
+  def apply(min: Option[Any], max: Option[Any], dataType: DataType): Range = dataType match {
+    case StringType | BinaryType => new DefaultRange()
+    case _ if min.isEmpty || max.isEmpty => new NullRange()
+    case _ => toNumericRange(min.get, max.get, dataType)
+  }
+
+  def isIntersected(r1: Range, r2: Range): Boolean = (r1, r2) match {
+    case (_, _: DefaultRange) | (_: DefaultRange, _) =>
+      // The DefaultRange represents string/binary types which do not have max/min stats;
+      // we assume they intersect, to be conservative in the estimation.
+      true
+    case (_, _: NullRange) | (_: NullRange, _) =>
+      false
+    case (n1: NumericRange, n2: NumericRange) =>
+      n1.min.compareTo(n2.max) <= 0 && n1.max.compareTo(n2.min) >= 0
+  }
+
+  /**
+   * Intersected results of two ranges. This is only valid for two overlapping ranges.
+   * The outputs are the intersected min/max values.
+   */
+  def intersect(r1: Range, r2: Range, dt: DataType): (Option[Any], Option[Any]) = {
+    (r1, r2) match {
+      case (_, _: DefaultRange) | (_: DefaultRange, _) =>
+        // binary/string types don't support intersecting.
+        (None, None)
+      case (n1: NumericRange, n2: NumericRange) =>
+        val newRange = NumericRange(n1.min.max(n2.min), n1.max.min(n2.max))
+        val (newMin, newMax) = fromNumericRange(newRange, dt)
+        (Some(newMin), Some(newMax))
+    }
+  }
+
+  /**
+   * For simplicity we use decimal to unify operations on numeric types; the two methods below
+   * are the contract of conversion.
+   */
+  private def toNumericRange(min: Any, max: Any, dataType: DataType): NumericRange = {
+    dataType match {
+      case _: NumericType =>
+        NumericRange(new JDecimal(min.toString), new JDecimal(max.toString))
+      case BooleanType =>
+        val min1 = if (min.asInstanceOf[Boolean]) 1 else 0
+        val max1 = if (max.asInstanceOf[Boolean]) 1 else 0
+        NumericRange(new JDecimal(min1), new JDecimal(max1))
+      case DateType =>
+        val min1 = DateTimeUtils.fromJavaDate(min.asInstanceOf[Date])
+        val max1 = DateTimeUtils.fromJavaDate(max.asInstanceOf[Date])
+        NumericRange(new JDecimal(min1), new JDecimal(max1))
+      case TimestampType =>
+        val min1 = DateTimeUtils.fromJavaTimestamp(min.asInstanceOf[Timestamp])
+        val max1 = DateTimeUtils.fromJavaTimestamp(max.asInstanceOf[Timestamp])
+        NumericRange(new JDecimal(min1), new JDecimal(max1))
+    }
+  }
+
+  private def fromNumericRange(n: NumericRange, dataType: DataType): (Any, Any) = {
+    dataType match {
+      case _: IntegralType =>
+        (n.min.longValue(), n.max.longValue())
+      case FloatType | DoubleType =>
+        (n.min.doubleValue(), n.max.doubleValue())
+      case _: DecimalType =>
+        (n.min, n.max)
+      case BooleanType =>
+        (n.min.longValue() == 1, n.max.longValue() == 1)
+      case DateType =>
+        (DateTimeUtils.toJavaDate(n.min.intValue()), DateTimeUtils.toJavaDate(n.max.intValue()))
+      case TimestampType =>
+        (DateTimeUtils.toJavaTimestamp(n.min.longValue()),
+          DateTimeUtils.toJavaTimestamp(n.max.longValue()))
+    }
+  }
+}
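A minimal sketch of how the Range helpers compose for an integer join key (the values are illustrative; this only exercises the NumericRange path):

    import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.Range
    import org.apache.spark.sql.types.IntegerType

    // [1, 5] and [2, 4] overlap; the intersection is clamped to [2, 4] and
    // converted back from decimal, widened to Long by fromNumericRange.
    val r1 = Range(Some(1), Some(5), IntegerType)
    val r2 = Range(Some(2), Some(4), IntegerType)
    assert(Range.isIntersected(r1, r2))
    assert(Range.intersect(r1, r2, IntegerType) == (Some(2L), Some(4L)))

    // A NullRange (no min/max, i.e. an all-null column) never intersects,
    // which is how the estimator detects an empty join result.
    assert(!Range.isIntersected(Range(None, None, IntegerType), r1))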
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala
new file mode 100644
index 0000000000000..f62df842fa50a
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala
@@ -0,0 +1,327 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.statsEstimation
+
+import java.sql.{Date, Timestamp}
+
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeMap, AttributeReference, EqualTo}
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, Project, Statistics}
+import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
+import org.apache.spark.sql.types.{DateType, TimestampType, _}
+
+
+class JoinEstimationSuite extends StatsEstimationTestBase {
+
+  /** Set up tables and their columns for testing */
+  private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq(
+    attr("key-1-5") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), nullCount = 0,
+      avgLen = 4, maxLen = 4),
+    attr("key-5-9") -> ColumnStat(distinctCount = 5, min = Some(5), max = Some(9), nullCount = 0,
+      avgLen = 4, maxLen = 4),
+    attr("key-1-2") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0,
+      avgLen = 4, maxLen = 4),
+    attr("key-2-4") -> ColumnStat(distinctCount = 3, min = Some(2), max = Some(4), nullCount = 0,
+      avgLen = 4, maxLen = 4),
+    attr("key-2-3") -> ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), nullCount = 0,
+      avgLen = 4, maxLen = 4)
+  ))
+
+  private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1)
+  private val nameToColInfo: Map[String, (Attribute, ColumnStat)] =
+    columnInfo.map(kv => kv._1.name -> kv)
+
+  // Suppose table1 (key-1-5 int, key-5-9 int) has 5 records: (1, 9), (2, 8), (3, 7), (4, 6), (5, 5)
+  private val table1 = StatsTestPlan(
+    outputList = Seq("key-1-5", "key-5-9").map(nameToAttr),
+    rowCount = 5,
+    attributeStats = AttributeMap(Seq("key-1-5", "key-5-9").map(nameToColInfo)))
+
+  // Suppose table2 (key-1-2 int, key-2-4 int) has 3 records: (1, 2), (2, 3), (2, 4)
+  private val table2 = StatsTestPlan(
+    outputList = Seq("key-1-2", "key-2-4").map(nameToAttr),
+    rowCount = 3,
+    attributeStats = AttributeMap(Seq("key-1-2", "key-2-4").map(nameToColInfo)))
+
+  // Suppose table3 (key-1-2 int, key-2-3 int) has 2 records: (1, 2), (2, 3)
+  private val table3 = StatsTestPlan(
+    outputList = Seq("key-1-2", "key-2-3").map(nameToAttr),
+    rowCount = 2,
+    attributeStats = AttributeMap(Seq("key-1-2", "key-2-3").map(nameToColInfo)))
+
+  test("cross join") {
+    // table1 (key-1-5 int, key-5-9 int): (1, 9), (2, 8), (3, 7), (4, 6), (5, 5)
+    // table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4)
+    val join = Join(table1, table2, Cross, None)
+    val expectedStats = Statistics(
+      sizeInBytes = 5 * 3 * (8 + 4 * 4),
+      rowCount = Some(5 * 3),
+      // Keep the column stats from both sides unchanged.
+      attributeStats = AttributeMap(
+        Seq("key-1-5", "key-5-9", "key-1-2", "key-2-4").map(nameToColInfo)))
+    assert(join.stats(conf) == expectedStats)
+  }
+
+  test("disjoint inner join") {
+    // table1 (key-1-5 int, key-5-9 int): (1, 9), (2, 8), (3, 7), (4, 6), (5, 5)
+    // table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4)
+    // key-5-9 and key-2-4 are disjoint
+    val join = Join(table1, table2, Inner,
+      Some(EqualTo(nameToAttr("key-5-9"), nameToAttr("key-2-4"))))
+    val expectedStats = Statistics(
+      sizeInBytes = 1,
+      rowCount = Some(0),
+      attributeStats = AttributeMap(Nil))
+    assert(join.stats(conf) == expectedStats)
+  }
+
+  test("disjoint left outer join") {
+    // table1 (key-1-5 int, key-5-9 int): (1, 9), (2, 8), (3, 7), (4, 6), (5, 5)
+    // table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4)
+    // key-5-9 and key-2-4 are disjoint
+    val join = Join(table1, table2, LeftOuter,
+      Some(EqualTo(nameToAttr("key-5-9"), nameToAttr("key-2-4"))))
+    val expectedStats = Statistics(
+      sizeInBytes = 5 * (8 + 4 * 4),
+      rowCount = Some(5),
+      attributeStats = AttributeMap(Seq("key-1-5", "key-5-9").map(nameToColInfo) ++
+        // Null count for right side columns = left row count
+        Seq(nameToAttr("key-1-2") -> nullColumnStat(nameToAttr("key-1-2").dataType, 5),
+          nameToAttr("key-2-4") -> nullColumnStat(nameToAttr("key-2-4").dataType, 5))))
+    assert(join.stats(conf) == expectedStats)
+  }
+
+  test("disjoint right outer join") {
+    // table1 (key-1-5 int, key-5-9 int): (1, 9), (2, 8), (3, 7), (4, 6), (5, 5)
+    // table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4)
+    // key-5-9 and key-2-4 are disjoint
+    val join = Join(table1, table2, RightOuter,
+      Some(EqualTo(nameToAttr("key-5-9"), nameToAttr("key-2-4"))))
+    val expectedStats = Statistics(
+      sizeInBytes = 3 * (8 + 4 * 4),
+      rowCount = Some(3),
+      attributeStats = AttributeMap(Seq("key-1-2", "key-2-4").map(nameToColInfo) ++
+        // Null count for left side columns = right row count
+        Seq(nameToAttr("key-1-5") -> nullColumnStat(nameToAttr("key-1-5").dataType, 3),
+          nameToAttr("key-5-9") -> nullColumnStat(nameToAttr("key-5-9").dataType, 3))))
+    assert(join.stats(conf) == expectedStats)
+  }
+
+  test("disjoint full outer join") {
+    // table1 (key-1-5 int, key-5-9 int): (1, 9), (2, 8), (3, 7), (4, 6), (5, 5)
+    // table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4)
+    // key-5-9 and key-2-4 are disjoint
+    val join = Join(table1, table2, FullOuter,
+      Some(EqualTo(nameToAttr("key-5-9"), nameToAttr("key-2-4"))))
+    val expectedStats = Statistics(
+      sizeInBytes = (5 + 3) * (8 + 4 * 4),
+      rowCount = Some(5 + 3),
+      attributeStats = AttributeMap(
+        // Update null count in column stats.
+        Seq(nameToAttr("key-1-5") -> columnInfo(nameToAttr("key-1-5")).copy(nullCount = 3),
+          nameToAttr("key-5-9") -> columnInfo(nameToAttr("key-5-9")).copy(nullCount = 3),
+          nameToAttr("key-1-2") -> columnInfo(nameToAttr("key-1-2")).copy(nullCount = 5),
+          nameToAttr("key-2-4") -> columnInfo(nameToAttr("key-2-4")).copy(nullCount = 5))))
+    assert(join.stats(conf) == expectedStats)
+  }
+
+  test("inner join") {
+    // table1 (key-1-5 int, key-5-9 int): (1, 9), (2, 8), (3, 7), (4, 6), (5, 5)
+    // table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4)
+    val join = Join(table1, table2, Inner,
+      Some(EqualTo(nameToAttr("key-1-5"), nameToAttr("key-1-2"))))
+    // Update column stats for the equi-join keys (key-1-5 and key-1-2).
+    val joinedColStat = ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0,
+      avgLen = 4, maxLen = 4)
+    // Update the column stat for another column if #outputRow / #sideRow < 1 (key-5-9), or keep
+    // it unchanged (key-2-4).
+    val colStatForkey59 = nameToColInfo("key-5-9")._2.copy(distinctCount = 5 * 3 / 5)
+
+    val expectedStats = Statistics(
+      sizeInBytes = 3 * (8 + 4 * 4),
+      rowCount = Some(3),
+      attributeStats = AttributeMap(
+        Seq(nameToAttr("key-1-5") -> joinedColStat, nameToAttr("key-1-2") -> joinedColStat,
+          nameToAttr("key-5-9") -> colStatForkey59, nameToColInfo("key-2-4"))))
+    assert(join.stats(conf) == expectedStats)
+  }
+
+  test("inner join with multiple equi-join keys") {
+    // table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4)
+    // table3 (key-1-2 int, key-2-3 int): (1, 2), (2, 3)
+    val join = Join(table2, table3, Inner, Some(
+      And(EqualTo(nameToAttr("key-1-2"), nameToAttr("key-1-2")),
+        EqualTo(nameToAttr("key-2-4"), nameToAttr("key-2-3")))))
+
+    // Update column stats for join keys.
+    val joinedColStat1 = ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0,
+      avgLen = 4, maxLen = 4)
+    val joinedColStat2 = ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), nullCount = 0,
+      avgLen = 4, maxLen = 4)
+
+    val expectedStats = Statistics(
+      sizeInBytes = 2 * (8 + 4 * 4),
+      rowCount = Some(2),
+      attributeStats = AttributeMap(
+        Seq(nameToAttr("key-1-2") -> joinedColStat1, nameToAttr("key-1-2") -> joinedColStat1,
+          nameToAttr("key-2-4") -> joinedColStat2, nameToAttr("key-2-3") -> joinedColStat2)))
+    assert(join.stats(conf) == expectedStats)
+  }
+
+  test("left outer join") {
+    // table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4)
+    // table3 (key-1-2 int, key-2-3 int): (1, 2), (2, 3)
+    val join = Join(table3, table2, LeftOuter,
+      Some(EqualTo(nameToAttr("key-2-3"), nameToAttr("key-2-4"))))
+    val joinedColStat = ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), nullCount = 0,
+      avgLen = 4, maxLen = 4)
+
+    val expectedStats = Statistics(
+      sizeInBytes = 2 * (8 + 4 * 4),
+      rowCount = Some(2),
+      // Keep the column stats from the left side unchanged.
+      attributeStats = AttributeMap(
+        Seq(nameToColInfo("key-1-2"), nameToColInfo("key-2-3"),
+          nameToColInfo("key-1-2"), nameToAttr("key-2-4") -> joinedColStat)))
+    assert(join.stats(conf) == expectedStats)
+  }
+
+  test("right outer join") {
+    // table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4)
+    // table3 (key-1-2 int, key-2-3 int): (1, 2), (2, 3)
+    val join = Join(table2, table3, RightOuter,
+      Some(EqualTo(nameToAttr("key-2-4"), nameToAttr("key-2-3"))))
+    val joinedColStat = ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), nullCount = 0,
+      avgLen = 4, maxLen = 4)
+
+    val expectedStats = Statistics(
+      sizeInBytes = 2 * (8 + 4 * 4),
+      rowCount = Some(2),
+      // Keep the column stats from the right side unchanged.
+      attributeStats = AttributeMap(
+        Seq(nameToColInfo("key-1-2"), nameToAttr("key-2-4") -> joinedColStat,
+          nameToColInfo("key-1-2"), nameToColInfo("key-2-3"))))
+    assert(join.stats(conf) == expectedStats)
+  }
+
+  test("full outer join") {
+    // table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4)
+    // table3 (key-1-2 int, key-2-3 int): (1, 2), (2, 3)
+    val join = Join(table2, table3, FullOuter,
+      Some(EqualTo(nameToAttr("key-2-4"), nameToAttr("key-2-3"))))
+
+    val expectedStats = Statistics(
+      sizeInBytes = 3 * (8 + 4 * 4),
+      rowCount = Some(3),
+      // Keep the column stats from both sides unchanged.
+      attributeStats = AttributeMap(Seq(nameToColInfo("key-1-2"), nameToColInfo("key-2-4"),
+        nameToColInfo("key-1-2"), nameToColInfo("key-2-3"))))
+    assert(join.stats(conf) == expectedStats)
+  }
+
+  test("left semi/anti join") {
+    // table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4)
+    // table3 (key-1-2 int, key-2-3 int): (1, 2), (2, 3)
+    Seq(LeftSemi, LeftAnti).foreach { jt =>
+      val join = Join(table2, table3, jt,
+        Some(EqualTo(nameToAttr("key-2-4"), nameToAttr("key-2-3"))))
+      // For now we just propagate the statistics from the left side for left semi/anti join.
+      val expectedStats = Statistics(
+        sizeInBytes = 3 * (8 + 4 * 2),
+        rowCount = Some(3),
+        attributeStats = AttributeMap(Seq(nameToColInfo("key-1-2"), nameToColInfo("key-2-4"))))
+      assert(join.stats(conf) == expectedStats)
+    }
+  }
+
+  test("test join keys of different types") {
+    /** Columns in a table with only one row */
+    def genColumnData: mutable.LinkedHashMap[Attribute, ColumnStat] = {
+      val dec = new java.math.BigDecimal("1.000000000000000000")
+      val date = Date.valueOf("2016-05-08")
+      val timestamp = Timestamp.valueOf("2016-05-08 00:00:01")
+      mutable.LinkedHashMap[Attribute, ColumnStat](
+        AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = 1,
+          min = Some(false), max = Some(false), nullCount = 0, avgLen = 1, maxLen = 1),
+        AttributeReference("cbyte", ByteType)() -> ColumnStat(distinctCount = 1,
+          min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 1, maxLen = 1),
+        AttributeReference("cshort", ShortType)() -> ColumnStat(distinctCount = 1,
+          min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 2, maxLen = 2),
+        AttributeReference("cint", IntegerType)() -> ColumnStat(distinctCount = 1,
+          min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 4, maxLen = 4),
+        AttributeReference("clong", LongType)() -> ColumnStat(distinctCount = 1,
+          min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 8, maxLen = 8),
+        AttributeReference("cdouble", DoubleType)() -> ColumnStat(distinctCount = 1,
+          min = Some(1.0), max = Some(1.0), nullCount = 0, avgLen = 8, maxLen = 8),
+        AttributeReference("cfloat", FloatType)() -> ColumnStat(distinctCount = 1,
+          min = Some(1.0), max = Some(1.0), nullCount = 0, avgLen = 4, maxLen = 4),
+        AttributeReference("cdec", DecimalType.SYSTEM_DEFAULT)() -> ColumnStat(distinctCount = 1,
+          min = Some(dec), max = Some(dec), nullCount = 0, avgLen = 16, maxLen = 16),
+        AttributeReference("cstring", StringType)() -> ColumnStat(distinctCount = 1,
+          min = None, max = None, nullCount = 0, avgLen = 3, maxLen = 3),
+        AttributeReference("cbinary", BinaryType)() -> ColumnStat(distinctCount = 1,
+          min = None, max = None, nullCount = 0, avgLen = 3, maxLen = 3),
+        AttributeReference("cdate", DateType)() -> ColumnStat(distinctCount = 1,
+          min = Some(date), max = Some(date), nullCount = 0, avgLen = 4, maxLen = 4),
+        AttributeReference("ctimestamp", TimestampType)() -> ColumnStat(distinctCount = 1,
+          min = Some(timestamp), max = Some(timestamp), nullCount = 0, avgLen = 8, maxLen = 8)
+      )
+    }
+
+    val columnInfo1 = genColumnData
+    val columnInfo2 = genColumnData
+    val table1 = StatsTestPlan(
+      outputList = columnInfo1.keys.toSeq,
+      rowCount = 1,
+      attributeStats = AttributeMap(columnInfo1.toSeq))
+    val table2 = StatsTestPlan(
+      outputList = columnInfo2.keys.toSeq,
+      rowCount = 1,
+      attributeStats = AttributeMap(columnInfo2.toSeq))
+    val joinKeys = table1.output.zip(table2.output)
+    joinKeys.foreach { case (key1, key2) =>
+      withClue(s"For data type ${key1.dataType}") {
+        // All values in the two tables are the same, so the column stats after the join are also the same.
+        val join = Join(Project(Seq(key1), table1), Project(Seq(key2), table2), Inner,
+          Some(EqualTo(key1, key2)))
+        val expectedStats = Statistics(
+          sizeInBytes = 1 * (8 + 2 * getColSize(key1, columnInfo1(key1))),
+          rowCount = Some(1),
+          attributeStats = AttributeMap(Seq(key1 -> columnInfo1(key1), key2 -> columnInfo1(key1))))
+        assert(join.stats(conf) == expectedStats)
+      }
+    }
+  }
+
+  test("join with null column") {
+    val (nullColumn, nullColStat) = (attr("cnull"),
+      ColumnStat(distinctCount = 0, min = None, max = None, nullCount = 1, avgLen = 4,
+        maxLen = 4))
+    val nullTable = StatsTestPlan(
+      outputList = Seq(nullColumn),
+      rowCount = 1,
+      attributeStats = AttributeMap(Seq(nullColumn -> nullColStat)))
+    val join = Join(table1, nullTable, Inner, Some(EqualTo(nameToAttr("key-1-5"), nullColumn)))
+    val expectedStats = Statistics(
+      sizeInBytes = 1,
+      rowCount = Some(0),
+      attributeStats = AttributeMap(Nil))
+    assert(join.stats(conf) == expectedStats)
+  }
+}
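All of the sizeInBytes expectations above follow the same getOutputSize shape: output row count times (8 bytes of per-row overhead plus the per-column sizes from getColSize). For example, the inner join that keeps 3 rows of four int columns asserts 3 * (8 + 4 * 4) = 72 bytes, and the empty-result cases pin sizeInBytes to 1 rather than 0.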
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala
index ae102a48451e8..f408dc4153586 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.statsEstimation
 
 import java.sql.{Date, Timestamp}
 
-import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeMap, AttributeReference}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, AttributeReference}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.types._
 
@@ -95,12 +95,7 @@ class ProjectEstimationSuite extends StatsEstimationTestBase {
       AttributeReference("ctimestamp", TimestampType)() -> ColumnStat(distinctCount = 2,
         min = Some(t1), max = Some(t2), nullCount = 0, avgLen = 8, maxLen = 8)
     ))
-    val columnSizes = columnInfo.map { case (attr, colStat) =>
-      (attr, attr.dataType match {
-        case StringType => colStat.avgLen + 8 + 4
-        case _ => colStat.avgLen
-      })
-    }
+    val columnSizes: Map[Attribute, Long] = columnInfo.map(kv => (kv._1, getColSize(kv._1, kv._2)))
     val child = StatsTestPlan(
       outputList = columnInfo.keys.toSeq,
       rowCount = 2,
@@ -108,11 +103,13 @@ class ProjectEstimationSuite extends StatsEstimationTestBase {
 
     // Row with single column
     columnInfo.keys.foreach { attr =>
-      checkProjectStats(
-        child = child,
-        projectAttrMap = AttributeMap(attr -> columnInfo(attr) :: Nil),
-        expectedSize = 2 * (8 + columnSizes(attr)),
-        expectedRowCount = 2)
+      withClue(s"For data type ${attr.dataType}") {
+        checkProjectStats(
+          child = child,
+          projectAttrMap = AttributeMap(attr -> columnInfo(attr) :: Nil),
+          expectedSize = 2 * (8 + columnSizes(attr)),
+          expectedRowCount = 2)
+      }
     }
 
     // Row with multiple columns
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala
index a5fac4ba6f03c..c56b41ce37636 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala
@@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf}
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference}
 import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LeafNode, LogicalPlan, Statistics}
-import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.sql.types.{IntegerType, StringType}
 
 class StatsEstimationTestBase extends SparkFunSuite {
 
@@ -29,6 +29,12 @@ class StatsEstimationTestBase extends SparkFunSuite {
   /** Enable stats estimation based on CBO. */
   protected val conf = SimpleCatalystConf(caseSensitiveAnalysis = true, cboEnabled = true)
 
+  def getColSize(attribute: Attribute, colStat: ColumnStat): Long = attribute.dataType match {
+    // For UTF8String: base + offset + numBytes
+    case StringType => colStat.avgLen + 8 + 4
+    case _ => colStat.avgLen
+  }
+
   def attr(colName: String): AttributeReference = AttributeReference(colName, IntegerType)()
 
   /** Convert (column name, column stat) pairs to an AttributeMap based on plan output. */
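One worked case for the getColSize helper added above: a string column with avgLen = 3 (like cstring in JoinEstimationSuite) costs 3 + 8 + 4 = 15 bytes per row to cover the UTF8String base and offset overhead, while fixed-width types cost exactly their avgLen; the per-row size expectations in both suites build on this.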