From fa7e213122dd1ad9cfbf6f65c2fb0b609bc9feb9 Mon Sep 17 00:00:00 2001 From: wangzhenhua Date: Thu, 12 Jan 2017 10:08:54 +0800 Subject: [PATCH 01/12] join estimation --- .../plans/logical/basicLogicalOperators.scala | 14 + .../logical/estimation/EstimationUtils.scala | 74 +++++ .../logical/estimation/JoinEstimation.scala | 311 ++++++++++++++++++ .../plans/logical/estimation/Range.scala | 122 +++++++ .../sql/estimation/JoinEstimationSuite.scala | 161 +++++++++ 5 files changed, 682 insertions(+) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/EstimationUtils.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/JoinEstimation.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/Range.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/estimation/JoinEstimationSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 8d7a6bc4b573..474efdf9d102 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -24,6 +24,8 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.{AggregateEstimation, EstimationUtils, ProjectEstimation} +import org.apache.spark.sql.catalyst.plans.logical.estimation.JoinEstimation +import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.{AggregateEstimation, ProjectEstimation} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -349,6 +351,18 @@ case class Join( // they could explode the size. super.computeStats(conf).copy(isBroadcastable = false) } + + override lazy val statistics: Statistics = JoinEstimation.estimate(this).getOrElse( + joinType match { + case LeftAnti | LeftSemi => + // LeftSemi and LeftAnti won't ever be bigger than left + left.statistics + case _ => + // make sure we don't propagate isBroadcastable in other joins, because + // they could explode the size. + super.statistics.copy(isBroadcastable = false) + } + ) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/EstimationUtils.scala new file mode 100644 index 000000000000..f42dd29015bf --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/EstimationUtils.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans.logical.estimation
+
+import scala.math.BigDecimal.RoundingMode
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Expression}
+import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan, Statistics}
+import org.apache.spark.sql.types.StringType
+
+
+object EstimationUtils extends Logging {
+
+  /** Check if each plan has rowCount in its statistics. */
+  def rowCountsExist(plans: LogicalPlan*): Boolean =
+    plans.forall(_.statistics.rowCount.isDefined)
+
+  /** Check if each attribute has column stat in the corresponding statistics. */
+  def columnStatsExist(statsAndAttr: (Statistics, Attribute)*): Boolean = {
+    statsAndAttr.forall { case (stats, attr) =>
+      stats.attributeStats.contains(attr)
+    }
+  }
+
+  /** Get column stats for output attributes. */
+  def getOutputMap(inputMap: AttributeMap[ColumnStat], output: Seq[Attribute])
+    : AttributeMap[ColumnStat] = {
+    AttributeMap(output.flatMap(a => inputMap.get(a).map(a -> _)))
+  }
+
+  def ceil(bigDecimal: BigDecimal): BigInt = bigDecimal.setScale(0, RoundingMode.CEILING).toBigInt()
+
+  def getRowSize(attributes: Seq[Attribute], attrStats: AttributeMap[ColumnStat]): Long = {
+    // We assign a generic overhead for a Row object; the actual overhead differs from one
+    // Row format to another.
+    8 + attributes.map { attr =>
+      if (attrStats.contains(attr)) {
+        attr.dataType match {
+          case StringType =>
+            // UTF8String: base + offset + numBytes
+            attrStats(attr).avgLen + 8 + 4
+          case _ =>
+            attrStats(attr).avgLen
+        }
+      } else {
+        attr.dataType.defaultSize
+      }
+    }.sum
+  }
+}
+
+/** Attribute Reference extractor */
+object ExtractAttr {
+  def unapply(exp: Expression): Option[AttributeReference] = exp match {
+    case ar: AttributeReference => Some(ar)
+    case _ => None
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/JoinEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/JoinEstimation.scala
new file mode 100644
index 000000000000..86494e21820b
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/JoinEstimation.scala
@@ -0,0 +1,311 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.spark.sql.catalyst.plans.logical.estimation + +import scala.collection.mutable + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Expression} +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, Statistics} +import org.apache.spark.sql.catalyst.plans.logical.estimation.EstimationUtils._ +import org.apache.spark.sql.types.DataType + + +object JoinEstimation extends Logging { + /** + * Estimate statistics after join. Return `None` if the join type is not supported, or we don't + * have enough statistics for estimation. + */ + def estimate(join: Join): Option[Statistics] = { + join.joinType match { + case Inner | Cross | LeftOuter | RightOuter | FullOuter => + InnerOuterEstimation(join).doEstimate() + case LeftSemi | LeftAnti => + LeftSemiAntiEstimation(join).doEstimate() + case _ => + logDebug(s"Unsupported join type: ${join.joinType}") + None + } + } +} + +case class InnerOuterEstimation(join: Join) extends Logging { + + /** + * Estimate output size and number of rows after a join operator, and update output column stats. + */ + def doEstimate(): Option[Statistics] = join match { + case _ if !rowCountsExist(join.left, join.right) => + None + + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) => + // 1. Compute join selectivity + val leftStats = left.statistics + val rightStats = right.statistics + val joinKeyPairs = extractJoinKeys(leftKeys, rightKeys) + val selectivity = joinSelectivity(joinKeyPairs, leftStats, rightStats) + + // 2. Estimate the number of output rows + val leftRows = leftStats.rowCount.get + val rightRows = rightStats.rowCount.get + val innerRows = ceil(BigDecimal(leftRows * rightRows) * selectivity) + + // Make sure outputRows won't be too small based on join type. + val outputRows = joinType match { + case LeftOuter => + // All rows from left side should be in the result. + leftRows.max(innerRows) + case RightOuter => + // All rows from right side should be in the result. + rightRows.max(innerRows) + case FullOuter => + // Simulate full outer join as obtaining the number of elements in the union of two + // finite sets: A \cup B = A + B - A \cap B => A FOJ B = A + B - A IJ B. + // But the "inner join" part can be much larger than A \cap B, making the simulated + // result much smaller. To prevent this, we choose the larger one between the simulated + // part and the inner part. + (leftRows + rightRows - innerRows).max(innerRows) + case _ => + // Don't change for inner or cross join + innerRows + } + + // 3. Update statistics based on the output of join + val intersectedStats = if (selectivity == 0) { + AttributeMap[ColumnStat](Nil) + } else { + updateIntersectedStats(joinKeyPairs, leftStats, rightStats) + } + val inputAttrStats = AttributeMap( + join.left.statistics.attributeStats.toSeq ++ join.right.statistics.attributeStats.toSeq) + val attributesWithStat = join.output.filter(a => inputAttrStats.contains(a)) + val (fromLeft, fromRight) = attributesWithStat.partition(join.left.outputSet.contains(_)) + val outputStats: Map[Attribute, ColumnStat] = join.joinType match { + case LeftOuter => + // Don't update column stats for attributes from left side. 
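+        // (all rows from the left side are kept in the output, so their original stats still hold)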
+        fromLeft.map(a => (a, inputAttrStats(a))).toMap ++
+          updateAttrStats(outputRows, fromRight, inputAttrStats, intersectedStats)
+      case RightOuter =>
+        // Don't update column stats for attributes from right side.
+        updateAttrStats(outputRows, fromLeft, inputAttrStats, intersectedStats) ++
+          fromRight.map(a => (a, inputAttrStats(a))).toMap
+      case FullOuter =>
+        // Don't update column stats for attributes from both sides.
+        attributesWithStat.map(a => (a, inputAttrStats(a))).toMap
+      case _ =>
+        // Update column stats from both sides for inner or cross join.
+        updateAttrStats(outputRows, attributesWithStat, inputAttrStats, intersectedStats)
+    }
+    val outputAttrStats = AttributeMap(outputStats.toSeq)
+
+      Some(Statistics(
+        sizeInBytes = outputRows * getRowSize(join.output, outputAttrStats),
+        rowCount = Some(outputRows),
+        attributeStats = outputAttrStats,
+        isBroadcastable = false))
+
+    case _ =>
+      // When there is no equi-join condition, we do estimation like cartesian product.
+      val inputAttrStats = AttributeMap(
+        join.left.statistics.attributeStats.toSeq ++ join.right.statistics.attributeStats.toSeq)
+      // Propagate the original column stats
+      val outputAttrStats = getOutputMap(inputAttrStats, join.output)
+      val outputRows = join.left.statistics.rowCount.get * join.right.statistics.rowCount.get
+      Some(Statistics(
+        sizeInBytes = outputRows * getRowSize(join.output, outputAttrStats),
+        rowCount = Some(outputRows),
+        attributeStats = outputAttrStats,
+        isBroadcastable = false))
+  }
+
+  // scalastyle:off
+  /**
+   * The number of rows of A inner join B on A.k1 = B.k1 is estimated by this basic formula:
+   * T(A IJ B) = T(A) * T(B) / max(V(A.k1), V(B.k1)), where V is the number of distinct values of
+   * that column. The underlying assumption for this formula is: each value of the smaller domain
+   * is included in the larger domain.
+   * For example, if T(A) = 100 with V(A.k1) = 10, and T(B) = 50 with V(B.k1) = 20, then
+   * T(A IJ B) = 100 * 50 / 20 = 250.
+   * Generally, an inner join with multiple join keys can also be estimated based on the above
+   * formula:
+   * T(A IJ B) = T(A) * T(B) / (max(V(A.k1), V(B.k1)) * max(V(A.k2), V(B.k2)) * ... * max(V(A.kn), V(B.kn)))
+   * However, the denominator can become very large and excessively reduce the result, so we use a
+   * conservative strategy and take only the largest max(V(A.ki), V(B.ki)) as the denominator.
+   */
+  // scalastyle:on
+  def joinSelectivity(
+      joinKeyPairs: Seq[(AttributeReference, AttributeReference)],
+      leftStats: Statistics,
+      rightStats: Statistics): BigDecimal = {
+
+    var ndvDenom: BigInt = -1
+    var i = 0
+    while (i < joinKeyPairs.length && ndvDenom != 0) {
+      val (leftKey, rightKey) = joinKeyPairs(i)
+      // Do estimation if we have enough statistics
+      if (columnStatsExist((leftStats, leftKey), (rightStats, rightKey))) {
+        // Check if the two sides are disjoint
+        val leftKeyStats = leftStats.attributeStats(leftKey)
+        val rightKeyStats = rightStats.attributeStats(rightKey)
+        val lRange = Range(leftKeyStats.min, leftKeyStats.max, leftKey.dataType)
+        val rRange = Range(rightKeyStats.min, rightKeyStats.max, rightKey.dataType)
+        if (Range.isIntersected(lRange, rRange)) {
+          // Get the largest ndv among pairs of join keys
+          val maxNdv = leftKeyStats.distinctCount.max(rightKeyStats.distinctCount)
+          if (maxNdv > ndvDenom) ndvDenom = maxNdv
+        } else {
+          // Set ndvDenom to zero to indicate that this join should have no output
+          ndvDenom = 0
+        }
+      }
+      i += 1
+    }
+
+    if (ndvDenom < 0) {
+      // There are no join keys, or no column stats for any of the join key pairs, so we estimate
+      // the join as a cartesian product.
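+      // (a selectivity of 1 leaves the estimated output at T(A) * T(B) rows)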
+      1
+    } else if (ndvDenom == 0) {
+      // One of the join key pairs is disjoint, thus the two sides of the join are disjoint.
+      0
+    } else {
+      1 / BigDecimal(ndvDenom)
+    }
+  }
+
+  /** Update column stats for output attributes. */
+  private def updateAttrStats(
+      outputRows: BigInt,
+      attributes: Seq[Attribute],
+      oldAttrStats: AttributeMap[ColumnStat],
+      joinKeyStats: AttributeMap[ColumnStat]): AttributeMap[ColumnStat] = {
+    val outputAttrStats = new mutable.HashMap[Attribute, ColumnStat]()
+    val leftRows = join.left.statistics.rowCount.get
+    val rightRows = join.right.statistics.rowCount.get
+    if (outputRows == 0) {
+      // empty output
+      attributes.foreach(a => outputAttrStats.put(a, emptyColumnStat(a.dataType)))
+    } else if (outputRows == leftRows * rightRows) {
+      // The join is estimated as a cartesian product, so we propagate the original column stats
+      attributes.foreach(a => outputAttrStats.put(a, oldAttrStats(a)))
+    } else {
+      val leftRatio =
+        if (leftRows != 0) BigDecimal(outputRows) / BigDecimal(leftRows) else BigDecimal(0)
+      val rightRatio =
+        if (rightRows != 0) BigDecimal(outputRows) / BigDecimal(rightRows) else BigDecimal(0)
+      attributes.foreach { a =>
+        // Check if this attribute is a join key
+        if (joinKeyStats.contains(a)) {
+          outputAttrStats.put(a, joinKeyStats(a))
+        } else {
+          val oldCS = oldAttrStats(a)
+          val oldNdv = oldCS.distinctCount
+          // We only change (scale down) the number of distinct values if the number of rows
+          // decreases after the join, because a join won't produce new values even if the number
+          // of rows increases.
+          val newNdv = if (join.left.outputSet.contains(a) && leftRatio < 1) {
+            ceil(BigDecimal(oldNdv) * leftRatio)
+          } else if (join.right.outputSet.contains(a) && rightRatio < 1) {
+            ceil(BigDecimal(oldNdv) * rightRatio)
+          } else {
+            oldNdv
+          }
+          // TODO: support nullCount updates for specific outer joins
+          outputAttrStats.put(a, oldCS.copy(distinctCount = newNdv))
+        }
+      }
+    }
+    AttributeMap(outputAttrStats.toSeq)
+  }
+
+  /** Update intersected column stats for join keys.
*/ + private def updateIntersectedStats( + joinKeyPairs: Seq[(AttributeReference, AttributeReference)], + leftStats: Statistics, + rightStats: Statistics): AttributeMap[ColumnStat] = { + val intersectedStats = new mutable.HashMap[Attribute, ColumnStat]() + joinKeyPairs.foreach { case (leftKey, rightKey) => + // Do estimation if we have enough statistics + if (columnStatsExist((leftStats, leftKey), (rightStats, rightKey))) { + // Check if the two sides are disjoint + val leftKeyStats = leftStats.attributeStats(leftKey) + val rightKeyStats = rightStats.attributeStats(rightKey) + val lRange = Range(leftKeyStats.min, leftKeyStats.max, leftKey.dataType) + val rRange = Range(rightKeyStats.min, rightKeyStats.max, rightKey.dataType) + if (Range.isIntersected(lRange, rRange)) { + // Update intersected column stats + val minNdv = leftKeyStats.distinctCount.min(rightKeyStats.distinctCount) + val (newMin1, newMax1, newMin2, newMax2) = + Range.intersect(lRange, rRange, leftKey.dataType, rightKey.dataType) + intersectedStats.put(leftKey, intersectedColumnStat(leftKeyStats, minNdv, + newMin1, newMax1)) + intersectedStats.put(rightKey, intersectedColumnStat(rightKeyStats, minNdv, + newMin2, newMax2)) + } + } + } + AttributeMap(intersectedStats.toSeq) + } + + private def emptyColumnStat(dataType: DataType): ColumnStat = { + ColumnStat(distinctCount = 0, min = None, max = None, nullCount = 0, + avgLen = dataType.defaultSize, maxLen = dataType.defaultSize) + } + + private def intersectedColumnStat( + origin: ColumnStat, + newDistinctCount: BigInt, + newMin: Option[Any], + newMax: Option[Any]): ColumnStat = { + origin.copy(distinctCount = newDistinctCount, min = newMin, max = newMax, nullCount = 0) + } + + private def extractJoinKeys( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression]): Seq[(AttributeReference, AttributeReference)] = { + leftKeys.zip(rightKeys).flatMap { + case (ExtractAttr(left), ExtractAttr(right)) => Some((left, right)) + case (left, right) => + // Currently we don't deal with equal joins like key1 = key2 + 5. + // Note: join keys from EqualNullSafe also fall into this case (Coalesce), consider to + // support it in the future by using `nullCount` in column stats. + logDebug(s"Unsupported equi-join expression: left key: $left, right key: $right") + None + } + } +} + +case class LeftSemiAntiEstimation(join: Join) { + def doEstimate(): Option[Statistics] = { + // TODO: It's error-prone to estimate cardinalities for LeftSemi and LeftAnti based on basic + // column stats. Now we just propagate the statistics from left side. We should do more + // accurate estimation when advanced stats (e.g. histograms) are available. 
+    if (rowCountsExist(join.left)) {
+      // Propagate the original column stats for cartesian product
+      val outputAttrStats = getOutputMap(join.left.statistics.attributeStats, join.output)
+      val outputRows = join.left.statistics.rowCount.get
+      Some(Statistics(
+        sizeInBytes = outputRows * getRowSize(join.output, outputAttrStats),
+        rowCount = Some(outputRows),
+        attributeStats = outputAttrStats,
+        isBroadcastable = false))
+    } else {
+      None
+    }
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/Range.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/Range.scala
new file mode 100644
index 000000000000..8fe84688f435
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/Range.scala
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans.logical.estimation
+
+import java.math.{BigDecimal => JDecimal}
+import java.sql.{Date, Timestamp}
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.types.{BooleanType, DateType, TimestampType, _}
+
+
+/** Value range of a column. */
+trait Range
+
+/** For simplicity we use decimal to unify operations of numeric ranges. */
+case class NumericRange(min: JDecimal, max: JDecimal) extends Range
+
+/**
+ * This version of Spark does not have min/max for binary/string types, so we define their
+ * default behavior with this class.
+ */
+class DefaultRange extends Range
+
+/** This is for columns with only null values. */
+class NullRange extends Range
+
+object Range {
+  def apply(min: Option[Any], max: Option[Any], dataType: DataType): Range = dataType match {
+    case StringType | BinaryType => new DefaultRange()
+    case _ if min.isEmpty || max.isEmpty => new NullRange()
+    case _ => toNumericRange(min.get, max.get, dataType)
+  }
+
+  def isIntersected(r1: Range, r2: Range): Boolean = (r1, r2) match {
+    case (_, _: DefaultRange) | (_: DefaultRange, _) =>
+      // Skip the overlap check for binary/string types
+      true
+    case (_, _: NullRange) | (_: NullRange, _) =>
+      false
+    case (n1: NumericRange, n2: NumericRange) =>
+      n1.min.compareTo(n2.max) <= 0 && n1.max.compareTo(n2.min) >= 0
+  }
+
+  /** This is only called for two overlapping ranges. */
+  def intersect(
+      r1: Range,
+      r2: Range,
+      dt1: DataType,
+      dt2: DataType): (Option[Any], Option[Any], Option[Any], Option[Any]) = {
+    (r1, r2) match {
+      case (_, _: DefaultRange) | (_: DefaultRange, _) =>
+        // binary/string types don't support intersecting.
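+        // (fall back to unknown min/max for both sides)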
+        (None, None, None, None)
+      case (n1: NumericRange, n2: NumericRange) =>
+        val newRange = NumericRange(n1.min.max(n2.min), n1.max.min(n2.max))
+        val (newMin1, newMax1) = fromNumericRange(newRange, dt1)
+        val (newMin2, newMax2) = fromNumericRange(newRange, dt2)
+        (Some(newMin1), Some(newMax1), Some(newMin2), Some(newMax2))
+    }
+  }
+
+  /**
+   * For simplicity we use decimal to unify operations of numeric types; the two methods below
+   * define the contract of conversion.
+   */
+  private def toNumericRange(min: Any, max: Any, dataType: DataType): NumericRange = {
+    dataType match {
+      case _: NumericType =>
+        NumericRange(new JDecimal(min.toString), new JDecimal(max.toString))
+      case BooleanType =>
+        val min1 = if (min.asInstanceOf[Boolean]) 1 else 0
+        val max1 = if (max.asInstanceOf[Boolean]) 1 else 0
+        NumericRange(new JDecimal(min1), new JDecimal(max1))
+      case DateType =>
+        val min1 = DateTimeUtils.fromJavaDate(min.asInstanceOf[Date])
+        val max1 = DateTimeUtils.fromJavaDate(max.asInstanceOf[Date])
+        NumericRange(new JDecimal(min1), new JDecimal(max1))
+      case TimestampType =>
+        val min1 = DateTimeUtils.fromJavaTimestamp(min.asInstanceOf[Timestamp])
+        val max1 = DateTimeUtils.fromJavaTimestamp(max.asInstanceOf[Timestamp])
+        NumericRange(new JDecimal(min1), new JDecimal(max1))
+      case _ =>
+        throw new AnalysisException(s"Type $dataType is not castable to numeric in estimation.")
+    }
+  }
+
+  private def fromNumericRange(n: NumericRange, dataType: DataType): (Any, Any) = {
+    dataType match {
+      case _: IntegralType =>
+        (n.min.longValue(), n.max.longValue())
+      case FloatType | DoubleType =>
+        (n.min.doubleValue(), n.max.doubleValue())
+      case _: DecimalType =>
+        (n.min, n.max)
+      case BooleanType =>
+        (n.min.longValue() == 1, n.max.longValue() == 1)
+      case DateType =>
+        (DateTimeUtils.toJavaDate(n.min.intValue()), DateTimeUtils.toJavaDate(n.max.intValue()))
+      case TimestampType =>
+        (DateTimeUtils.toJavaTimestamp(n.min.longValue()),
+          DateTimeUtils.toJavaTimestamp(n.max.longValue()))
+      case _ =>
+        throw new AnalysisException(s"Type $dataType is not castable from numeric in estimation.")
+    }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/estimation/JoinEstimationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/estimation/JoinEstimationSuite.scala
new file mode 100644
index 000000000000..9ba0d7f7642f
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/estimation/JoinEstimationSuite.scala
@@ -0,0 +1,161 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.spark.sql.estimation + +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.expressions.AttributeMap +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, Statistics} +import org.apache.spark.sql.catalyst.plans.logical.estimation.EstimationUtils._ +import org.apache.spark.sql.test.SharedSQLContext + + +class JoinEstimationSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + /** Data for one-column tables */ + private val joinEstimationTestData1 = Seq( + ("join_est_test1", "key1", Seq[Int](1, 1, 2, 2, 3, 3, 3)), + ("join_est_test2", "key2", Seq[Int](5, 6))) + + /** Data for two-column tables */ + private val joinEstimationTestData2 = Seq( + ("join_est_test3", Seq("key31", "key32"), + Seq[(Int, Int)]((1, 9), (2, 8), (3, 7), (4, 6), (5, 5))), + ("join_est_test4", Seq("key41", "key42"), + Seq[(Int, Int)]((1, 3), (2, 4)))) + + /** Original column stats */ + val colStatForKey1 = ColumnStat(3, Some(1), Some(3), 0, 4, 4) + val colStatForKey2 = ColumnStat(2, Some(5), Some(6), 0, 4, 4) + val colStatForKey31 = ColumnStat(5, Some(1), Some(5), 0, 4, 4) + val colStatForKey32 = ColumnStat(5, Some(5), Some(9), 0, 4, 4) + val colStatForKey41 = ColumnStat(2, Some(1), Some(2), 0, 4, 4) + val colStatForKey42 = ColumnStat(2, Some(3), Some(4), 0, 4, 4) + + override def beforeAll(): Unit = { + super.beforeAll() + // Create tables and collect statistics + joinEstimationTestData1.foreach { case (table, column, data) => + data.toDF(column).write.saveAsTable(table) + sql(s"analyze table $table compute STATISTICS FOR COLUMNS $column") + } + joinEstimationTestData2.foreach { case (table, columns, data) => + data.toDF(columns: _*).write.saveAsTable(table) + sql(s"analyze table $table compute STATISTICS FOR COLUMNS ${columns.mkString(", ")}") + } + } + + override def afterAll(): Unit = { + joinEstimationTestData1.foreach { case (table, _, _) => sql(s"DROP TABLE IF EXISTS $table") } + joinEstimationTestData2.foreach { case (table, _, _) => sql(s"DROP TABLE IF EXISTS $table") } + super.afterAll() + } + + + test("estimate inner join") { + val innerJoinSql = + "select count(1) from join_est_test1 join join_est_test3 on key1 = key31" + // Update column stats from both sides. + val joinedColStat = ColumnStat(3, Some(1), Some(3), 0, 4, 4) + val colStats = Seq("key1" -> joinedColStat, "key31" -> joinedColStat) + validateEstimatedStats(innerJoinSql, colStats) + } + + test("update column stats for join keys and non-join keys") { + val innerJoinSql = + "select count(1) from join_est_test3 join join_est_test4 on key31 = key41 and key32 > key42" + // Update column stats for both join keys. + // Update non-join column stat if #outputRow / #sideRow < 1, otherwise keep it unchanged. + val joinedColStat = ColumnStat(2, Some(1), Some(2), 0, 4, 4) + val colStats = Seq("key31" -> joinedColStat, "key41" -> joinedColStat, + "key32" -> colStatForKey32.copy(distinctCount = 2), "key42" -> colStatForKey42) + validateEstimatedStats(innerJoinSql, colStats) + } + + test("estimate disjoint inner join") { + val innerJoinSql = + "select count(1) from join_est_test1 join join_est_test2 on key1 = key2" + // Empty column stats for both sides. 
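+    // (key1 has range [1, 3] and key2 has range [5, 6], so the join key ranges are disjoint)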
+ val emptyColStat = ColumnStat(0, None, None, 0, 4, 4) + val colStats = Seq("key1" -> emptyColStat, "key2" -> emptyColStat) + validateEstimatedStats(innerJoinSql, colStats) + } + + test("estimate cross join without equal conditions") { + val crossJoinSql = + "select count(1) from join_est_test1 cross join join_est_test2 on key1 < key2" + // Keep the column stat from both sides unchanged. + val colStats = Seq("key1" -> colStatForKey1, "key2" -> colStatForKey2) + validateEstimatedStats(crossJoinSql, colStats) + } + + test("estimate left outer join") { + val leftOuterJoinSql = + "select count(1) from join_est_test3 left join join_est_test4 on key31 = key41" + // Keep the column stat from left side unchanged. + val joinedColStat = ColumnStat(2, Some(1), Some(2), 0, 4, 4) + val colStats = Seq("key31" -> colStatForKey31, "key41" -> joinedColStat) + validateEstimatedStats(leftOuterJoinSql, colStats) + } + + test("estimate right outer join") { + val rightOuterJoinSql = + "select count(1) from join_est_test4 right join join_est_test3 on key41 = key31" + // Keep the column stat from right side unchanged. + val joinedColStat = ColumnStat(2, Some(1), Some(2), 0, 4, 4) + val colStats = Seq("key41" -> joinedColStat, "key31" -> colStatForKey31) + validateEstimatedStats(rightOuterJoinSql, colStats) + } + + test("estimate full outer join") { + val fullOuterJoinSql = + "select count(1) from join_est_test3 full join join_est_test4 on key31 = key41" + // Keep the column stat from both sides unchanged. + val colStats = Seq("key31" -> colStatForKey31, "key41" -> colStatForKey41) + validateEstimatedStats(fullOuterJoinSql, colStats) + } + + test("estimate left semi/anti join") { + val joinTypeStrings = Seq("left semi", "left anti") + joinTypeStrings.foreach { str => + val joinSql = + s"select count(1) from join_est_test3 $str join join_est_test4 on key31 = key41" + // For now we just propagate the statistics from left side for left semi/anti join. 
+ val colStats = Seq("key31" -> colStatForKey31) + validateEstimatedStats(joinSql, colStats, Some(5)) + } + } + + private def validateEstimatedStats( + joinSql: String, + expectedColStats: Seq[(String, ColumnStat)], + rowCount: Option[Long] = None) : Unit = { + val logicalPlan = sql(joinSql).queryExecution.optimizedPlan + val joinNode = logicalPlan.collect { case join: Join => join }.head + val expectedRowCount = rowCount.getOrElse(sql(joinSql).collect().head.getLong(0)) + val nameToAttr = joinNode.output.map(a => (a.name, a)).toMap + val expectedAttrStats = + AttributeMap(expectedColStats.map(kv => nameToAttr(kv._1) -> kv._2)) + val expectedStats = Statistics( + sizeInBytes = expectedRowCount * getRowSize(joinNode.output, expectedAttrStats), + rowCount = Some(expectedRowCount), + attributeStats = expectedAttrStats, + isBroadcastable = false) + assert(joinNode.statistics == expectedStats) + } +} From a4d7d53b812ee0ea1583cad58e5bec28c5fabb4a Mon Sep 17 00:00:00 2001 From: wangzhenhua Date: Fri, 13 Jan 2017 10:10:47 +0800 Subject: [PATCH 02/12] rebase --- .../plans/logical/basicLogicalOperators.scala | 32 ++- .../logical/estimation/EstimationUtils.scala | 74 ------- .../statsEstimation/EstimationUtils.scala | 23 +- .../JoinEstimation.scala | 45 ++-- .../Range.scala | 4 +- .../statsEstimation/JoinEstimationSuite.scala | 202 ++++++++++++++++++ .../sql/estimation/JoinEstimationSuite.scala | 161 -------------- 7 files changed, 261 insertions(+), 280 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/EstimationUtils.scala rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/{estimation => statsEstimation}/JoinEstimation.scala (89%) rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/{estimation => statsEstimation}/Range.scala (96%) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/estimation/JoinEstimationSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 474efdf9d102..341b9e5f55c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -23,9 +23,7 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTypes} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.{AggregateEstimation, EstimationUtils, ProjectEstimation} -import org.apache.spark.sql.catalyst.plans.logical.estimation.JoinEstimation -import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.{AggregateEstimation, ProjectEstimation} +import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.{AggregateEstimation, EstimationUtils, JoinEstimation, ProjectEstimation} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -342,27 +340,23 @@ case class Join( case _ => resolvedExceptNatural } - override def computeStats(conf: CatalystConf): Statistics = joinType match { - case LeftAnti | LeftSemi => - // LeftSemi and 
LeftAnti won't ever be bigger than left - left.stats(conf).copy() - case _ => - // make sure we don't propagate isBroadcastable in other joins, because - // they could explode the size. - super.computeStats(conf).copy(isBroadcastable = false) - } - - override lazy val statistics: Statistics = JoinEstimation.estimate(this).getOrElse( - joinType match { + override def computeStats(conf: CatalystConf): Statistics = { + def simpleEstimation: Statistics = joinType match { case LeftAnti | LeftSemi => // LeftSemi and LeftAnti won't ever be bigger than left - left.statistics + left.stats(conf) case _ => - // make sure we don't propagate isBroadcastable in other joins, because + // Make sure we don't propagate isBroadcastable in other joins, because // they could explode the size. - super.statistics.copy(isBroadcastable = false) + super.computeStats(conf).copy(isBroadcastable = false) + } + + if (conf.cboEnabled) { + JoinEstimation.estimate(conf, this).getOrElse(simpleEstimation) + } else { + simpleEstimation } - ) + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/EstimationUtils.scala deleted file mode 100644 index f42dd29015bf..000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/EstimationUtils.scala +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.plans.logical.estimation - -import scala.math.BigDecimal.RoundingMode - -import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Expression} -import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan, Statistics} -import org.apache.spark.sql.types.StringType - - -object EstimationUtils extends Logging { - - /** Check if each plan has rowCount in its statistics. */ - def rowCountsExist(plans: LogicalPlan*): Boolean = - plans.forall(_.statistics.rowCount.isDefined) - - /** Check if each attribute has column stat in the corresponding statistics. */ - def columnStatsExist(statsAndAttr: (Statistics, Attribute)*): Boolean = { - statsAndAttr.forall { case (stats, attr) => - stats.attributeStats.contains(attr) - } - } - - /** Get column stats for output attributes. 
*/
-  def getOutputMap(inputMap: AttributeMap[ColumnStat], output: Seq[Attribute])
-    : AttributeMap[ColumnStat] = {
-    AttributeMap(output.flatMap(a => inputMap.get(a).map(a -> _)))
-  }
-
-  def ceil(bigDecimal: BigDecimal): BigInt = bigDecimal.setScale(0, RoundingMode.CEILING).toBigInt()
-
-  def getRowSize(attributes: Seq[Attribute], attrStats: AttributeMap[ColumnStat]): Long = {
-    // We assign a generic overhead for a Row object; the actual overhead differs from one
-    // Row format to another.
-    8 + attributes.map { attr =>
-      if (attrStats.contains(attr)) {
-        attr.dataType match {
-          case StringType =>
-            // UTF8String: base + offset + numBytes
-            attrStats(attr).avgLen + 8 + 4
-          case _ =>
-            attrStats(attr).avgLen
-        }
-      } else {
-        attr.dataType.defaultSize
-      }
-    }.sum
-  }
-}
-
-/** Attribute Reference extractor */
-object ExtractAttr {
-  def unapply(exp: Expression): Option[AttributeReference] = exp match {
-    case ar: AttributeReference => Some(ar)
-    case _ => None
-  }
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
index e8b794212c10..3d85c0f6ecb7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
@@ -17,9 +17,11 @@
 
 package org.apache.spark.sql.catalyst.plans.logical.statsEstimation
 
+import scala.math.BigDecimal.RoundingMode
+
 import org.apache.spark.sql.catalyst.CatalystConf
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
-import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Expression}
+import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan, Statistics}
 import org.apache.spark.sql.types.StringType
 
 
@@ -29,6 +31,15 @@ object EstimationUtils {
   def rowCountsExist(conf: CatalystConf, plans: LogicalPlan*): Boolean =
     plans.forall(_.stats(conf).rowCount.isDefined)
 
+  /** Check if each attribute has column stat in the corresponding statistics. */
+  def columnStatsExist(statsAndAttr: (Statistics, Attribute)*): Boolean = {
+    statsAndAttr.forall { case (stats, attr) =>
+      stats.attributeStats.contains(attr)
+    }
+  }
+
+  def ceil(bigDecimal: BigDecimal): BigInt = bigDecimal.setScale(0, RoundingMode.CEILING).toBigInt()
+
   /** Get column stats for output attributes.
*/ def getOutputMap(inputMap: AttributeMap[ColumnStat], output: Seq[Attribute]) : AttributeMap[ColumnStat] = { @@ -60,3 +71,11 @@ object EstimationUtils { if (outputRowCount > 0) outputRowCount * sizePerRow else 1 } } + +/** Attribute Reference extractor */ +object ExtractAttr { + def unapply(exp: Expression): Option[AttributeReference] = exp match { + case ar: AttributeReference => Some(ar) + case _ => None + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/JoinEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala similarity index 89% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/JoinEstimation.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala index 86494e21820b..6f8587181549 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/JoinEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala @@ -15,16 +15,17 @@ * limitations under the License. */ -package org.apache.spark.sql.catalyst.plans.logical.estimation +package org.apache.spark.sql.catalyst.plans.logical.statsEstimation import scala.collection.mutable import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Expression} import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, Statistics} -import org.apache.spark.sql.catalyst.plans.logical.estimation.EstimationUtils._ +import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ import org.apache.spark.sql.types.DataType @@ -33,12 +34,12 @@ object JoinEstimation extends Logging { * Estimate statistics after join. Return `None` if the join type is not supported, or we don't * have enough statistics for estimation. */ - def estimate(join: Join): Option[Statistics] = { + def estimate(conf: CatalystConf, join: Join): Option[Statistics] = { join.joinType match { case Inner | Cross | LeftOuter | RightOuter | FullOuter => - InnerOuterEstimation(join).doEstimate() + InnerOuterEstimation(conf, join).doEstimate() case LeftSemi | LeftAnti => - LeftSemiAntiEstimation(join).doEstimate() + LeftSemiAntiEstimation(conf, join).doEstimate() case _ => logDebug(s"Unsupported join type: ${join.joinType}") None @@ -46,19 +47,20 @@ object JoinEstimation extends Logging { } } -case class InnerOuterEstimation(join: Join) extends Logging { +case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging { + + private val leftStats = join.left.stats(conf) + private val rightStats = join.right.stats(conf) /** * Estimate output size and number of rows after a join operator, and update output column stats. */ def doEstimate(): Option[Statistics] = join match { - case _ if !rowCountsExist(join.left, join.right) => + case _ if !rowCountsExist(conf, join.left, join.right) => None case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) => // 1. 
Compute join selectivity - val leftStats = left.statistics - val rightStats = right.statistics val joinKeyPairs = extractJoinKeys(leftKeys, rightKeys) val selectivity = joinSelectivity(joinKeyPairs, leftStats, rightStats) @@ -94,7 +96,7 @@ case class InnerOuterEstimation(join: Join) extends Logging { updateIntersectedStats(joinKeyPairs, leftStats, rightStats) } val inputAttrStats = AttributeMap( - join.left.statistics.attributeStats.toSeq ++ join.right.statistics.attributeStats.toSeq) + leftStats.attributeStats.toSeq ++ rightStats.attributeStats.toSeq) val attributesWithStat = join.output.filter(a => inputAttrStats.contains(a)) val (fromLeft, fromRight) = attributesWithStat.partition(join.left.outputSet.contains(_)) val outputStats: Map[Attribute, ColumnStat] = join.joinType match { @@ -116,7 +118,7 @@ case class InnerOuterEstimation(join: Join) extends Logging { val outputAttrStats = AttributeMap(outputStats.toSeq) Some(Statistics( - sizeInBytes = outputRows * getRowSize(join.output, outputAttrStats), + sizeInBytes = getOutputSize(join.output, outputAttrStats, outputRows), rowCount = Some(outputRows), attributeStats = outputAttrStats, isBroadcastable = false)) @@ -124,12 +126,12 @@ case class InnerOuterEstimation(join: Join) extends Logging { case _ => // When there is no equi-join condition, we do estimation like cartesian product. val inputAttrStats = AttributeMap( - join.left.statistics.attributeStats.toSeq ++ join.right.statistics.attributeStats.toSeq) + leftStats.attributeStats.toSeq ++ rightStats.attributeStats.toSeq) // Propagate the original column stats val outputAttrStats = getOutputMap(inputAttrStats, join.output) - val outputRows = join.left.statistics.rowCount.get * join.right.statistics.rowCount.get + val outputRows = leftStats.rowCount.get * rightStats.rowCount.get Some(Statistics( - sizeInBytes = outputRows * getRowSize(join.output, outputAttrStats), + sizeInBytes = getOutputSize(join.output, outputAttrStats, outputRows), rowCount = Some(outputRows), attributeStats = outputAttrStats, isBroadcastable = false)) @@ -195,8 +197,8 @@ case class InnerOuterEstimation(join: Join) extends Logging { oldAttrStats: AttributeMap[ColumnStat], joinKeyStats: AttributeMap[ColumnStat]): AttributeMap[ColumnStat] = { val outputAttrStats = new mutable.HashMap[Attribute, ColumnStat]() - val leftRows = join.left.statistics.rowCount.get - val rightRows = join.right.statistics.rowCount.get + val leftRows = leftStats.rowCount.get + val rightRows = rightStats.rowCount.get if (outputRows == 0) { // empty output attributes.foreach(a => outputAttrStats.put(a, emptyColumnStat(a.dataType))) @@ -290,17 +292,18 @@ case class InnerOuterEstimation(join: Join) extends Logging { } } -case class LeftSemiAntiEstimation(join: Join) { +case class LeftSemiAntiEstimation(conf: CatalystConf, join: Join) { def doEstimate(): Option[Statistics] = { // TODO: It's error-prone to estimate cardinalities for LeftSemi and LeftAnti based on basic // column stats. Now we just propagate the statistics from left side. We should do more // accurate estimation when advanced stats (e.g. histograms) are available. 
- if (rowCountsExist(join.left)) { + if (rowCountsExist(conf, join.left)) { + val leftStats = join.left.stats(conf) // Propagate the original column stats for cartesian product - val outputAttrStats = getOutputMap(join.left.statistics.attributeStats, join.output) - val outputRows = join.left.statistics.rowCount.get + val outputAttrStats = getOutputMap(leftStats.attributeStats, join.output) + val outputRows = leftStats.rowCount.get Some(Statistics( - sizeInBytes = outputRows * getRowSize(join.output, outputAttrStats), + sizeInBytes = getOutputSize(join.output, outputAttrStats, outputRows), rowCount = Some(outputRows), attributeStats = outputAttrStats, isBroadcastable = false)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/Range.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala similarity index 96% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/Range.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala index 8fe84688f435..41b09fdbf277 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/Range.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.catalyst.plans.logical.estimation +package org.apache.spark.sql.catalyst.plans.logical.statsEstimation import java.math.{BigDecimal => JDecimal} import java.sql.{Date, Timestamp} @@ -115,8 +115,6 @@ object Range { case TimestampType => (DateTimeUtils.toJavaTimestamp(n.min.longValue()), DateTimeUtils.toJavaTimestamp(n.max.longValue())) - case _ => - throw new AnalysisException(s"Type $dataType is not castable from numeric in estimation.") } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala new file mode 100644 index 000000000000..4c97d4ec52a4 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala @@ -0,0 +1,202 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.statsEstimation + +import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeMap, EqualTo} +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, Statistics} + + +class JoinEstimationSuite extends StatsEstimationTestBase { + + /** Set up tables and its columns for testing */ + private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( + attr("key11") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), nullCount = 0, + avgLen = 4, maxLen = 4), + attr("key12") -> ColumnStat(distinctCount = 5, min = Some(5), max = Some(9), nullCount = 0, + avgLen = 4, maxLen = 4), + attr("key21") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0, + avgLen = 4, maxLen = 4), + attr("key22") -> ColumnStat(distinctCount = 3, min = Some(2), max = Some(4), nullCount = 0, + avgLen = 4, maxLen = 4), + attr("key31") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0, + avgLen = 4, maxLen = 4), + attr("key32") -> ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), nullCount = 0, + avgLen = 4, maxLen = 4) + )) + + private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1) + private val nameToColInfo: Map[String, (Attribute, ColumnStat)] = + columnInfo.map(kv => kv._1.name -> kv) + + // Suppose table1 (key11 int, key12 int) has 5 records: (1, 9), (2, 8), (3, 7), (4, 6), (5, 5) + private val table1 = StatsTestPlan( + outputList = Seq("key11", "key12").map(nameToAttr), + rowCount = 5, + attributeStats = AttributeMap(Seq("key11", "key12").map(nameToColInfo))) + + // Suppose table2 (key21 int, key22 int) has 3 records: (1, 2), (2, 3), (2, 4) + private val table2 = StatsTestPlan( + outputList = Seq("key21", "key22").map(nameToAttr), + rowCount = 3, + attributeStats = AttributeMap(Seq("key21", "key22").map(nameToColInfo))) + + // Suppose table3 (key31 int, key32 int) has 2 records: (1, 2), (2, 3) + private val table3 = StatsTestPlan( + outputList = Seq("key31", "key32").map(nameToAttr), + rowCount = 2, + attributeStats = AttributeMap(Seq("key31", "key32").map(nameToColInfo))) + + test("cross join") { + // table1 (key11 int, key12 int): (1, 9), (2, 8), (3, 7), (4, 6), (5, 5) + // table2 (key21 int, key22 int): (1, 2), (2, 3), (2, 4) + val join = Join(table1, table2, Cross, None) + val expectedStats = Statistics( + sizeInBytes = 5 * 3 * (8 + 4 * 4), + rowCount = Some(5 * 3), + // Keep the column stat from both sides unchanged. + attributeStats = AttributeMap(Seq("key11", "key12", "key21", "key22").map(nameToColInfo))) + assert(join.stats(conf) == expectedStats) + } + + test("disjoint inner join") { + // table1 (key11 int, key12 int): (1, 9), (2, 8), (3, 7), (4, 6), (5, 5) + // table2 (key21 int, key22 int): (1, 2), (2, 3), (2, 4) + // key12 and key22 are disjoint + val join = Join(table1, table2, Inner, Some( + And(EqualTo(nameToAttr("key11"), nameToAttr("key21")), + EqualTo(nameToAttr("key12"), nameToAttr("key22"))))) + // Empty column stats for all output columns. 
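+    // (key12 has range [5, 9] while key22 has range [2, 4], so no rows can match)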
+ val emptyColStat = ColumnStat(distinctCount = 0, min = None, max = None, nullCount = 0, + avgLen = 4, maxLen = 4) + + val expectedStats = Statistics( + sizeInBytes = 1, + rowCount = Some(0), + attributeStats = AttributeMap( + Seq("key11", "key12", "key21", "key22").map(c => (nameToAttr(c), emptyColStat)))) + assert(join.stats(conf) == expectedStats) + } + + test("inner join") { + // table1 (key11 int, key12 int): (1, 9), (2, 8), (3, 7), (4, 6), (5, 5) + // table2 (key21 int, key22 int): (1, 2), (2, 3), (2, 4) + val join = Join(table1, table2, Inner, Some(EqualTo(nameToAttr("key11"), nameToAttr("key21")))) + // Update column stats for equi-join keys (key11 and key21). + val joinedColStat = ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0, + avgLen = 4, maxLen = 4) + // Update column stat for other column if #outputRow / #sideRow < 1 (key12), or keep it + // unchanged (key22). + val colStatForKey12 = nameToColInfo("key12")._2.copy(distinctCount = 5 * 3 / 5) + + val expectedStats = Statistics( + sizeInBytes = 3 * (8 + 4 * 4), + rowCount = Some(3), + attributeStats = AttributeMap( + Seq(nameToAttr("key11") -> joinedColStat, nameToAttr("key21") -> joinedColStat, + nameToAttr("key12") -> colStatForKey12, nameToColInfo("key22")))) + assert(join.stats(conf) == expectedStats) + } + + test("inner join with multiple equi-join keys") { + // table2 (key21 int, key22 int): (1, 2), (2, 3), (2, 4) + // table3 (key31 int, key32 int): (1, 2), (2, 3) + val join = Join(table2, table3, Inner, Some( + And(EqualTo(nameToAttr("key21"), nameToAttr("key31")), + EqualTo(nameToAttr("key22"), nameToAttr("key32"))))) + + // Update column stats for join keys. + val joinedColStat1 = ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0, + avgLen = 4, maxLen = 4) + val joinedColStat2 = ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), nullCount = 0, + avgLen = 4, maxLen = 4) + + val expectedStats = Statistics( + sizeInBytes = 2 * (8 + 4 * 4), + rowCount = Some(2), + attributeStats = AttributeMap( + Seq(nameToAttr("key21") -> joinedColStat1, nameToAttr("key31") -> joinedColStat1, + nameToAttr("key22") -> joinedColStat2, nameToAttr("key32") -> joinedColStat2))) + assert(join.stats(conf) == expectedStats) + } + + test("left outer join") { + // table2 (key21 int, key22 int): (1, 2), (2, 3), (2, 4) + // table3 (key31 int, key32 int): (1, 2), (2, 3) + val join = Join(table3, table2, LeftOuter, + Some(EqualTo(nameToAttr("key32"), nameToAttr("key22")))) + val joinedColStat = ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), nullCount = 0, + avgLen = 4, maxLen = 4) + + val expectedStats = Statistics( + sizeInBytes = 2 * (8 + 4 * 4), + rowCount = Some(2), + // Keep the column stat from left side unchanged. + attributeStats = AttributeMap( + Seq(nameToColInfo("key31"), nameToColInfo("key32"), + nameToColInfo("key21"), nameToAttr("key22") -> joinedColStat))) + assert(join.stats(conf) == expectedStats) + } + + test("right outer join") { + // table2 (key21 int, key22 int): (1, 2), (2, 3), (2, 4) + // table3 (key31 int, key32 int): (1, 2), (2, 3) + val join = Join(table2, table3, RightOuter, + Some(EqualTo(nameToAttr("key22"), nameToAttr("key32")))) + val joinedColStat = ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), nullCount = 0, + avgLen = 4, maxLen = 4) + + val expectedStats = Statistics( + sizeInBytes = 2 * (8 + 4 * 4), + rowCount = Some(2), + // Keep the column stat from right side unchanged. 
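+      // (table3's key31 and key32 form the right side of this join)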
+ attributeStats = AttributeMap( + Seq(nameToColInfo("key21"), nameToAttr("key22") -> joinedColStat, + nameToColInfo("key31"), nameToColInfo("key32")))) + assert(join.stats(conf) == expectedStats) + } + + test("full outer join") { + // table2 (key21 int, key22 int): (1, 2), (2, 3), (2, 4) + // table3 (key31 int, key32 int): (1, 2), (2, 3) + val join = Join(table2, table3, FullOuter, + Some(EqualTo(nameToAttr("key22"), nameToAttr("key32")))) + + val expectedStats = Statistics( + sizeInBytes = 3 * (8 + 4 * 4), + rowCount = Some(3), + // Keep the column stat from both sides unchanged. + attributeStats = AttributeMap(Seq(nameToColInfo("key21"), nameToColInfo("key22"), + nameToColInfo("key31"), nameToColInfo("key32")))) + assert(join.stats(conf) == expectedStats) + } + + test("left semi/anti join") { + // table2 (key21 int, key22 int): (1, 2), (2, 3), (2, 4) + // table3 (key31 int, key32 int): (1, 2), (2, 3) + Seq(LeftSemi, LeftAnti).foreach { jt => + val join = Join(table2, table3, jt, Some(EqualTo(nameToAttr("key22"), nameToAttr("key32")))) + // For now we just propagate the statistics from left side for left semi/anti join. + val expectedStats = Statistics( + sizeInBytes = 3 * (8 + 4 * 2), + rowCount = Some(3), + attributeStats = AttributeMap(Seq(nameToColInfo("key21"), nameToColInfo("key22")))) + assert(join.stats(conf) == expectedStats) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/estimation/JoinEstimationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/estimation/JoinEstimationSuite.scala deleted file mode 100644 index 9ba0d7f7642f..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/estimation/JoinEstimationSuite.scala +++ /dev/null @@ -1,161 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.estimation - -import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.catalyst.expressions.AttributeMap -import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, Statistics} -import org.apache.spark.sql.catalyst.plans.logical.estimation.EstimationUtils._ -import org.apache.spark.sql.test.SharedSQLContext - - -class JoinEstimationSuite extends QueryTest with SharedSQLContext { - import testImplicits._ - - /** Data for one-column tables */ - private val joinEstimationTestData1 = Seq( - ("join_est_test1", "key1", Seq[Int](1, 1, 2, 2, 3, 3, 3)), - ("join_est_test2", "key2", Seq[Int](5, 6))) - - /** Data for two-column tables */ - private val joinEstimationTestData2 = Seq( - ("join_est_test3", Seq("key31", "key32"), - Seq[(Int, Int)]((1, 9), (2, 8), (3, 7), (4, 6), (5, 5))), - ("join_est_test4", Seq("key41", "key42"), - Seq[(Int, Int)]((1, 3), (2, 4)))) - - /** Original column stats */ - val colStatForKey1 = ColumnStat(3, Some(1), Some(3), 0, 4, 4) - val colStatForKey2 = ColumnStat(2, Some(5), Some(6), 0, 4, 4) - val colStatForKey31 = ColumnStat(5, Some(1), Some(5), 0, 4, 4) - val colStatForKey32 = ColumnStat(5, Some(5), Some(9), 0, 4, 4) - val colStatForKey41 = ColumnStat(2, Some(1), Some(2), 0, 4, 4) - val colStatForKey42 = ColumnStat(2, Some(3), Some(4), 0, 4, 4) - - override def beforeAll(): Unit = { - super.beforeAll() - // Create tables and collect statistics - joinEstimationTestData1.foreach { case (table, column, data) => - data.toDF(column).write.saveAsTable(table) - sql(s"analyze table $table compute STATISTICS FOR COLUMNS $column") - } - joinEstimationTestData2.foreach { case (table, columns, data) => - data.toDF(columns: _*).write.saveAsTable(table) - sql(s"analyze table $table compute STATISTICS FOR COLUMNS ${columns.mkString(", ")}") - } - } - - override def afterAll(): Unit = { - joinEstimationTestData1.foreach { case (table, _, _) => sql(s"DROP TABLE IF EXISTS $table") } - joinEstimationTestData2.foreach { case (table, _, _) => sql(s"DROP TABLE IF EXISTS $table") } - super.afterAll() - } - - - test("estimate inner join") { - val innerJoinSql = - "select count(1) from join_est_test1 join join_est_test3 on key1 = key31" - // Update column stats from both sides. - val joinedColStat = ColumnStat(3, Some(1), Some(3), 0, 4, 4) - val colStats = Seq("key1" -> joinedColStat, "key31" -> joinedColStat) - validateEstimatedStats(innerJoinSql, colStats) - } - - test("update column stats for join keys and non-join keys") { - val innerJoinSql = - "select count(1) from join_est_test3 join join_est_test4 on key31 = key41 and key32 > key42" - // Update column stats for both join keys. - // Update non-join column stat if #outputRow / #sideRow < 1, otherwise keep it unchanged. - val joinedColStat = ColumnStat(2, Some(1), Some(2), 0, 4, 4) - val colStats = Seq("key31" -> joinedColStat, "key41" -> joinedColStat, - "key32" -> colStatForKey32.copy(distinctCount = 2), "key42" -> colStatForKey42) - validateEstimatedStats(innerJoinSql, colStats) - } - - test("estimate disjoint inner join") { - val innerJoinSql = - "select count(1) from join_est_test1 join join_est_test2 on key1 = key2" - // Empty column stats for both sides. 
- val emptyColStat = ColumnStat(0, None, None, 0, 4, 4) - val colStats = Seq("key1" -> emptyColStat, "key2" -> emptyColStat) - validateEstimatedStats(innerJoinSql, colStats) - } - - test("estimate cross join without equal conditions") { - val crossJoinSql = - "select count(1) from join_est_test1 cross join join_est_test2 on key1 < key2" - // Keep the column stat from both sides unchanged. - val colStats = Seq("key1" -> colStatForKey1, "key2" -> colStatForKey2) - validateEstimatedStats(crossJoinSql, colStats) - } - - test("estimate left outer join") { - val leftOuterJoinSql = - "select count(1) from join_est_test3 left join join_est_test4 on key31 = key41" - // Keep the column stat from left side unchanged. - val joinedColStat = ColumnStat(2, Some(1), Some(2), 0, 4, 4) - val colStats = Seq("key31" -> colStatForKey31, "key41" -> joinedColStat) - validateEstimatedStats(leftOuterJoinSql, colStats) - } - - test("estimate right outer join") { - val rightOuterJoinSql = - "select count(1) from join_est_test4 right join join_est_test3 on key41 = key31" - // Keep the column stat from right side unchanged. - val joinedColStat = ColumnStat(2, Some(1), Some(2), 0, 4, 4) - val colStats = Seq("key41" -> joinedColStat, "key31" -> colStatForKey31) - validateEstimatedStats(rightOuterJoinSql, colStats) - } - - test("estimate full outer join") { - val fullOuterJoinSql = - "select count(1) from join_est_test3 full join join_est_test4 on key31 = key41" - // Keep the column stat from both sides unchanged. - val colStats = Seq("key31" -> colStatForKey31, "key41" -> colStatForKey41) - validateEstimatedStats(fullOuterJoinSql, colStats) - } - - test("estimate left semi/anti join") { - val joinTypeStrings = Seq("left semi", "left anti") - joinTypeStrings.foreach { str => - val joinSql = - s"select count(1) from join_est_test3 $str join join_est_test4 on key31 = key41" - // For now we just propagate the statistics from left side for left semi/anti join. 
- val colStats = Seq("key31" -> colStatForKey31) - validateEstimatedStats(joinSql, colStats, Some(5)) - } - } - - private def validateEstimatedStats( - joinSql: String, - expectedColStats: Seq[(String, ColumnStat)], - rowCount: Option[Long] = None) : Unit = { - val logicalPlan = sql(joinSql).queryExecution.optimizedPlan - val joinNode = logicalPlan.collect { case join: Join => join }.head - val expectedRowCount = rowCount.getOrElse(sql(joinSql).collect().head.getLong(0)) - val nameToAttr = joinNode.output.map(a => (a.name, a)).toMap - val expectedAttrStats = - AttributeMap(expectedColStats.map(kv => nameToAttr(kv._1) -> kv._2)) - val expectedStats = Statistics( - sizeInBytes = expectedRowCount * getRowSize(joinNode.output, expectedAttrStats), - rowCount = Some(expectedRowCount), - attributeStats = expectedAttrStats, - isBroadcastable = false) - assert(joinNode.statistics == expectedStats) - } -} From e6d143349e939a5854fc655b358959ebcfa7042b Mon Sep 17 00:00:00 2001 From: wangzhenhua Date: Fri, 13 Jan 2017 11:48:55 +0800 Subject: [PATCH 03/12] fix comments --- .../plans/logical/statsEstimation/EstimationUtils.scala | 8 -------- .../plans/logical/statsEstimation/JoinEstimation.scala | 8 ++++---- .../catalyst/plans/logical/statsEstimation/Range.scala | 3 --- 3 files changed, 4 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala index 3d85c0f6ecb7..1b0b8797c760 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala @@ -71,11 +71,3 @@ object EstimationUtils { if (outputRowCount > 0) outputRowCount * sizePerRow else 1 } } - -/** Attribute Reference extractor */ -object ExtractAttr { - def unapply(exp: Expression): Option[AttributeReference] = exp match { - case ar: AttributeReference => Some(ar) - case _ => None - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala index 6f8587181549..1bbdbc3f4e90 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala @@ -41,7 +41,7 @@ object JoinEstimation extends Logging { case LeftSemi | LeftAnti => LeftSemiAntiEstimation(conf, join).doEstimate() case _ => - logDebug(s"Unsupported join type: ${join.joinType}") + logDebug(s"[CBO] Unsupported join type: ${join.joinType}") None } } @@ -281,12 +281,12 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging leftKeys: Seq[Expression], rightKeys: Seq[Expression]): Seq[(AttributeReference, AttributeReference)] = { leftKeys.zip(rightKeys).flatMap { - case (ExtractAttr(left), ExtractAttr(right)) => Some((left, right)) - case (left, right) => + case (lk: AttributeReference, rk: AttributeReference) => Some((lk, rk)) + case (lk, rk) => // Currently we don't deal with equal joins like key1 = key2 + 5. 
// Note: join keys from EqualNullSafe also fall into this case (Coalesce), consider to // support it in the future by using `nullCount` in column stats. - logDebug(s"Unsupported equi-join expression: left key: $left, right key: $right") + logDebug(s"[CBO] Unsupported equi-join expression: left key: $lk, right key: $rk") None } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala index 41b09fdbf277..9af55cd6ffb1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation import java.math.{BigDecimal => JDecimal} import java.sql.{Date, Timestamp} -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types.{BooleanType, DateType, TimestampType, _} @@ -95,8 +94,6 @@ object Range { val min1 = DateTimeUtils.fromJavaTimestamp(min.asInstanceOf[Timestamp]) val max1 = DateTimeUtils.fromJavaTimestamp(max.asInstanceOf[Timestamp]) NumericRange(new JDecimal(min1), new JDecimal(max1)) - case _ => - throw new AnalysisException(s"Type $dataType is not castable to numeric in estimation.") } } From 1f93a55cdcf8a1a9842eea5d1f1fc2c0ab678687 Mon Sep 17 00:00:00 2001 From: wangzhenhua Date: Fri, 13 Jan 2017 17:23:57 +0800 Subject: [PATCH 04/12] add tests for all data types and null column --- .../statsEstimation/JoinEstimationSuite.scala | 116 +++++++++++++++++- .../ProjectEstimationSuite.scala | 21 ++-- .../StatsEstimationTestBase.scala | 7 +- 3 files changed, 129 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala index 4c97d4ec52a4..83b9c91137b4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala @@ -17,9 +17,14 @@ package org.apache.spark.sql.catalyst.statsEstimation -import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeMap, EqualTo} +import java.sql.{Date, Timestamp} + +import scala.collection.mutable + +import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeMap, AttributeReference, EqualTo} import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, Statistics} +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, Project, Statistics} +import org.apache.spark.sql.types.{DateType, TimestampType, _} class JoinEstimationSuite extends StatsEstimationTestBase { @@ -199,4 +204,111 @@ class JoinEstimationSuite extends StatsEstimationTestBase { assert(join.stats(conf) == expectedStats) } } + + test("test join keys of different types") { + val dec1 = new java.math.BigDecimal("1.000000000000000000") + val dec2 = new java.math.BigDecimal("8.000000000000000000") + val d1 = Date.valueOf("2016-05-08") + val d2 = Date.valueOf("2016-05-09") + val t1 = Timestamp.valueOf("2016-05-08 00:00:01") + val t2 = Timestamp.valueOf("2016-05-09 00:00:02") + + /** Columns in a table with only 
one row */ + val columnInfo1 = mutable.LinkedHashMap[Attribute, ColumnStat]( + AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = 1, + min = Some(false), max = Some(false), nullCount = 0, avgLen = 1, maxLen = 1), + AttributeReference("cbyte", ByteType)() -> ColumnStat(distinctCount = 1, + min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 1, maxLen = 1), + AttributeReference("cshort", ShortType)() -> ColumnStat(distinctCount = 1, + min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 2, maxLen = 2), + AttributeReference("cint", IntegerType)() -> ColumnStat(distinctCount = 1, + min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 4, maxLen = 4), + AttributeReference("clong", LongType)() -> ColumnStat(distinctCount = 1, + min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 8, maxLen = 8), + AttributeReference("cdouble", DoubleType)() -> ColumnStat(distinctCount = 1, + min = Some(1.0), max = Some(1.0), nullCount = 0, avgLen = 8, maxLen = 8), + AttributeReference("cfloat", FloatType)() -> ColumnStat(distinctCount = 1, + min = Some(1.0), max = Some(1.0), nullCount = 0, avgLen = 4, maxLen = 4), + AttributeReference("cdecimal", DecimalType.SYSTEM_DEFAULT)() -> ColumnStat(distinctCount = 1, + min = Some(dec1), max = Some(dec1), nullCount = 0, avgLen = 16, maxLen = 16), + AttributeReference("cstring", StringType)() -> ColumnStat(distinctCount = 1, + min = None, max = None, nullCount = 0, avgLen = 3, maxLen = 3), + AttributeReference("cbinary", BinaryType)() -> ColumnStat(distinctCount = 1, + min = None, max = None, nullCount = 0, avgLen = 3, maxLen = 3), + AttributeReference("cdate", DateType)() -> ColumnStat(distinctCount = 1, + min = Some(d1), max = Some(d1), nullCount = 0, avgLen = 4, maxLen = 4), + AttributeReference("ctimestamp", TimestampType)() -> ColumnStat(distinctCount = 1, + min = Some(t1), max = Some(t1), nullCount = 0, avgLen = 8, maxLen = 8) + ) + + /** Columns in a table with two rows */ + val columnInfo2 = mutable.LinkedHashMap[Attribute, ColumnStat]( + AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = 2, + min = Some(false), max = Some(true), nullCount = 0, avgLen = 1, maxLen = 1), + AttributeReference("cbyte", ByteType)() -> ColumnStat(distinctCount = 2, + min = Some(1L), max = Some(2L), nullCount = 0, avgLen = 1, maxLen = 1), + AttributeReference("cshort", ShortType)() -> ColumnStat(distinctCount = 2, + min = Some(1L), max = Some(3L), nullCount = 0, avgLen = 2, maxLen = 2), + AttributeReference("cint", IntegerType)() -> ColumnStat(distinctCount = 2, + min = Some(1L), max = Some(4L), nullCount = 0, avgLen = 4, maxLen = 4), + AttributeReference("clong", LongType)() -> ColumnStat(distinctCount = 2, + min = Some(1L), max = Some(5L), nullCount = 0, avgLen = 8, maxLen = 8), + AttributeReference("cdouble", DoubleType)() -> ColumnStat(distinctCount = 2, + min = Some(1.0), max = Some(6.0), nullCount = 0, avgLen = 8, maxLen = 8), + AttributeReference("cfloat", FloatType)() -> ColumnStat(distinctCount = 2, + min = Some(1.0), max = Some(7.0), nullCount = 0, avgLen = 4, maxLen = 4), + AttributeReference("cdecimal", DecimalType.SYSTEM_DEFAULT)() -> ColumnStat(distinctCount = 2, + min = Some(dec1), max = Some(dec2), nullCount = 0, avgLen = 16, maxLen = 16), + AttributeReference("cstring", StringType)() -> ColumnStat(distinctCount = 2, + min = None, max = None, nullCount = 0, avgLen = 3, maxLen = 3), + AttributeReference("cbinary", BinaryType)() -> ColumnStat(distinctCount = 2, + min = None, max = None, nullCount = 0, avgLen = 3, maxLen 
= 3), + AttributeReference("cdate", DateType)() -> ColumnStat(distinctCount = 2, + min = Some(d1), max = Some(d2), nullCount = 0, avgLen = 4, maxLen = 4), + AttributeReference("ctimestamp", TimestampType)() -> ColumnStat(distinctCount = 2, + min = Some(t1), max = Some(t2), nullCount = 0, avgLen = 8, maxLen = 8) + ) + + val oneRowTable = StatsTestPlan( + outputList = columnInfo1.keys.toSeq, + rowCount = 1, + attributeStats = AttributeMap(columnInfo1.toSeq)) + val twoRowTable = StatsTestPlan( + outputList = columnInfo2.keys.toSeq, + rowCount = 2, + attributeStats = AttributeMap(columnInfo2.toSeq)) + val joinKeys = oneRowTable.output.zip(twoRowTable.output) + joinKeys.foreach { case (key1, key2) => + withClue(s"For data type ${key1.dataType}") { + // All values in oneRowTable is contained in twoRowTable, so column stats after join is + // equal to that of oneRowTable. + val join = Join(Project(Seq(key1), oneRowTable), Project(Seq(key2), twoRowTable), Inner, + Some(EqualTo(key1, key2))) + val expectedStats = Statistics( + sizeInBytes = 1 * (8 + 2 * getColSize(key1, columnInfo1(key1))), + rowCount = Some(1), + attributeStats = AttributeMap(Seq(key1 -> columnInfo1(key1), key2 -> columnInfo1(key1)))) + assert(join.stats(conf) == expectedStats) + } + } + } + + test("join with null column") { + val (nullColumn, nullColStat) = (attr("cnull"), + ColumnStat(distinctCount = 0, min = None, max = None, nullCount = 1, avgLen = 4, maxLen = 4)) + val nullTable = StatsTestPlan( + outputList = Seq(nullColumn), + rowCount = 1, + attributeStats = AttributeMap(Seq(nullColumn -> nullColStat))) + val join = Join(table1, nullTable, Inner, + Some(EqualTo(nameToAttr("key11"), nullColumn))) + val emptyColStat = ColumnStat(distinctCount = 0, min = None, max = None, nullCount = 0, + avgLen = 4, maxLen = 4) + val expectedStats = Statistics( + sizeInBytes = 1, + rowCount = Some(0), + attributeStats = AttributeMap(Seq(nameToAttr("key11") -> emptyColStat, + nameToAttr("key12") -> emptyColStat, nullColumn -> emptyColStat))) + assert(join.stats(conf) == expectedStats) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala index ae102a48451e..f408dc415358 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.statsEstimation import java.sql.{Date, Timestamp} -import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeMap, AttributeReference} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ @@ -95,12 +95,7 @@ class ProjectEstimationSuite extends StatsEstimationTestBase { AttributeReference("ctimestamp", TimestampType)() -> ColumnStat(distinctCount = 2, min = Some(t1), max = Some(t2), nullCount = 0, avgLen = 8, maxLen = 8) )) - val columnSizes = columnInfo.map { case (attr, colStat) => - (attr, attr.dataType match { - case StringType => colStat.avgLen + 8 + 4 - case _ => colStat.avgLen - }) - } + val columnSizes: Map[Attribute, Long] = columnInfo.map(kv => (kv._1, getColSize(kv._1, kv._2))) val child = StatsTestPlan( outputList = columnInfo.keys.toSeq, rowCount = 2, @@ -108,11 +103,13 @@ 
class ProjectEstimationSuite extends StatsEstimationTestBase { // Row with single column columnInfo.keys.foreach { attr => - checkProjectStats( - child = child, - projectAttrMap = AttributeMap(attr -> columnInfo(attr) :: Nil), - expectedSize = 2 * (8 + columnSizes(attr)), - expectedRowCount = 2) + withClue(s"For data type ${attr.dataType}") { + checkProjectStats( + child = child, + projectAttrMap = AttributeMap(attr -> columnInfo(attr) :: Nil), + expectedSize = 2 * (8 + columnSizes(attr)), + expectedRowCount = 2) + } } // Row with multiple columns diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala index a5fac4ba6f03..8563e4da56ab 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf} import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LeafNode, LogicalPlan, Statistics} -import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.types.{IntegerType, StringType} class StatsEstimationTestBase extends SparkFunSuite { @@ -29,6 +29,11 @@ class StatsEstimationTestBase extends SparkFunSuite { /** Enable stats estimation based on CBO. */ protected val conf = SimpleCatalystConf(caseSensitiveAnalysis = true, cboEnabled = true) + def getColSize(attribute: Attribute, colStat: ColumnStat): Long = attribute.dataType match { + case StringType => colStat.avgLen + 8 + 4 + case _ => colStat.avgLen + } + def attr(colName: String): AttributeReference = AttributeReference(colName, IntegerType)() /** Convert (column name, column stat) pairs to an AttributeMap based on plan output. 
*/ From 4d0eba2cf087317c03b9a2ab18edf0a9c803336e Mon Sep 17 00:00:00 2001 From: wangzhenhua Date: Sun, 22 Jan 2017 12:59:52 +0800 Subject: [PATCH 05/12] rebase: getOutputSize --- .../plans/logical/statsEstimation/JoinEstimation.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala index 1bbdbc3f4e90..3f4a62ade56b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala @@ -118,7 +118,7 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging val outputAttrStats = AttributeMap(outputStats.toSeq) Some(Statistics( - sizeInBytes = getOutputSize(join.output, outputAttrStats, outputRows), + sizeInBytes = getOutputSize(join.output, outputRows, outputAttrStats), rowCount = Some(outputRows), attributeStats = outputAttrStats, isBroadcastable = false)) @@ -131,7 +131,7 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging val outputAttrStats = getOutputMap(inputAttrStats, join.output) val outputRows = leftStats.rowCount.get * rightStats.rowCount.get Some(Statistics( - sizeInBytes = getOutputSize(join.output, outputAttrStats, outputRows), + sizeInBytes = getOutputSize(join.output, outputRows, outputAttrStats), rowCount = Some(outputRows), attributeStats = outputAttrStats, isBroadcastable = false)) @@ -303,7 +303,7 @@ case class LeftSemiAntiEstimation(conf: CatalystConf, join: Join) { val outputAttrStats = getOutputMap(leftStats.attributeStats, join.output) val outputRows = leftStats.rowCount.get Some(Statistics( - sizeInBytes = getOutputSize(join.output, outputAttrStats, outputRows), + sizeInBytes = getOutputSize(join.output, outputRows, outputAttrStats), rowCount = Some(outputRows), attributeStats = outputAttrStats, isBroadcastable = false)) From 25b4367d136e56def8d8736c4d39f049341b8aae Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Sun, 12 Feb 2017 09:33:07 -0500 Subject: [PATCH 06/12] change full outer join computation --- .../plans/logical/statsEstimation/JoinEstimation.scala | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala index 3f4a62ade56b..052362e48228 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala @@ -78,12 +78,8 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging // All rows from right side should be in the result. rightRows.max(innerRows) case FullOuter => - // Simulate full outer join as obtaining the number of elements in the union of two - // finite sets: A \cup B = A + B - A \cap B => A FOJ B = A + B - A IJ B. - // But the "inner join" part can be much larger than A \cap B, making the simulated - // result much smaller. To prevent this, we choose the larger one between the simulated - // part and the inner part. 
- (leftRows + rightRows - innerRows).max(innerRows) + // T(A FOJ B) = T(A LOJ B) + T(A ROJ B) - T(A IJ B) + leftRows.max(innerRows) + rightRows.max(innerRows) - innerRows case _ => // Don't change for inner or cross join innerRows From d040045a8823f3e84b916237ec8b157bcdb86a8d Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Sun, 12 Feb 2017 10:04:56 -0500 Subject: [PATCH 07/12] improve doc --- .../plans/logical/statsEstimation/JoinEstimation.scala | 10 ++++++++-- .../catalyst/plans/logical/statsEstimation/Range.scala | 5 ++++- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala index 052362e48228..7a8a37d7aa27 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala @@ -186,7 +186,13 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging } } - /** Update column stats for output attributes. */ + /** + * Update column stats for output attributes. + * 1. For empty output, update all column stats to be empty. + * 2. For cartesian product, all values are preserved, so there's no need to change column stats. + * 3. For other cases, a) update max/min of join keys based on their intersected range. b) update + * distinct count of other attributes based on output rows after join. + */ private def updateAttrStats( outputRows: BigInt, attributes: Seq[Attribute], @@ -199,7 +205,7 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging // empty output attributes.foreach(a => outputAttrStats.put(a, emptyColumnStat(a.dataType))) } else if (outputRows == leftRows * rightRows) { - // We do estimation like cartesian product and propagate the original column stats + // Cartesian product, just propagate the original column stats attributes.foreach(a => outputAttrStats.put(a, oldAttrStats(a))) } else { val leftRatio = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala index 9af55cd6ffb1..dd08c11a0bda 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala @@ -56,7 +56,10 @@ object Range { n1.min.compareTo(n2.max) <= 0 && n1.max.compareTo(n2.min) >= 0 } - /** This is only for two overlapped ranges. */ + /** + * Intersected results of two ranges. This is only for two overlapped ranges. + * The outputs are the intersected min/max values of the two columns based on their data types. + */ def intersect( r1: Range, r2: Range, From 05efa810598ffdef262272bb2f7ab906cf8bab3d Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Mon, 13 Feb 2017 21:58:55 -0800 Subject: [PATCH 08/12] 1. remove empty column stats, 2. deal with outer joins when inner part is empty. 
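For the disjoint-keys case the inner part is estimated as empty
(innerJoinedRows == 0), and the row counts produced by the outer-join
branches reduce to the rule below. A minimal illustrative sketch only, not
the patch's exact code (the helper name is made up):

    import org.apache.spark.sql.catalyst.plans._

    // Row counts when the join key ranges do not intersect, so T(A IJ B) = 0.
    def disjointJoinRows(joinType: JoinType, leftRows: BigInt, rightRows: BigInt): BigInt =
      joinType match {
        case LeftOuter  => leftRows               // all left rows survive, right side is null-padded
        case RightOuter => rightRows              // symmetric to LeftOuter
        case FullOuter  => leftRows + rightRows   // T(A LOJ B) + T(A ROJ B) - T(A IJ B)
        case _          => BigInt(0)              // a disjoint inner join is empty
      }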
--- .../statsEstimation/JoinEstimation.scala | 157 +++++++++--------- .../statsEstimation/JoinEstimationSuite.scala | 15 +- 2 files changed, 85 insertions(+), 87 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala index 7a8a37d7aa27..dacd3a32b88d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala @@ -61,55 +61,78 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) => // 1. Compute join selectivity - val joinKeyPairs = extractJoinKeys(leftKeys, rightKeys) - val selectivity = joinSelectivity(joinKeyPairs, leftStats, rightStats) + val joinKeyPairs = extractJoinKeysWithColStats(leftKeys, rightKeys) + val selectivity = joinSelectivity(joinKeyPairs) // 2. Estimate the number of output rows val leftRows = leftStats.rowCount.get val rightRows = rightStats.rowCount.get - val innerRows = ceil(BigDecimal(leftRows * rightRows) * selectivity) + val innerJoinedRows = ceil(BigDecimal(leftRows * rightRows) * selectivity) // Make sure outputRows won't be too small based on join type. val outputRows = joinType match { case LeftOuter => // All rows from left side should be in the result. - leftRows.max(innerRows) + leftRows.max(innerJoinedRows) case RightOuter => // All rows from right side should be in the result. - rightRows.max(innerRows) + rightRows.max(innerJoinedRows) case FullOuter => // T(A FOJ B) = T(A LOJ B) + T(A ROJ B) - T(A IJ B) - leftRows.max(innerRows) + rightRows.max(innerRows) - innerRows + leftRows.max(innerJoinedRows) + rightRows.max(innerJoinedRows) - innerJoinedRows case _ => // Don't change for inner or cross join - innerRows + innerJoinedRows } // 3. Update statistics based on the output of join - val intersectedStats = if (selectivity == 0) { - AttributeMap[ColumnStat](Nil) - } else { - updateIntersectedStats(joinKeyPairs, leftStats, rightStats) - } val inputAttrStats = AttributeMap( leftStats.attributeStats.toSeq ++ rightStats.attributeStats.toSeq) + val joinKeyStats = if (innerJoinedRows == 0) { + val leftKeys = joinKeyPairs.map(_._1) + val rightKeys = joinKeyPairs.map(_._2) + joinType match { + // For outer joins, if the inner join part is empty, the number of output rows is the + // same as that of the outer side. And column stats of join keys from the outer side + // keep unchanged, while column stats of join keys from the other side should be updated + // based on added null values. 
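+          // For example, in A LOJ B with disjoint keys, every one of the T(A) output rows
+          // has nulls for B's join keys, so those keys get ndv = 0 and nullCount = T(A).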
+ case LeftOuter => + AttributeMap[ColumnStat](leftKeys.map(k => (k, inputAttrStats(k))) ++ + rightKeys.map(k => (k, nullColumnStat(k.dataType, leftRows)))) + case RightOuter => + AttributeMap[ColumnStat](rightKeys.map(k => (k, inputAttrStats(k))) ++ + leftKeys.map(k => (k, nullColumnStat(k.dataType, rightRows)))) + case FullOuter => + AttributeMap[ColumnStat](leftKeys.map { k => + val oriColStat = inputAttrStats(k) + (k, oriColStat.copy(distinctCount = oriColStat.distinctCount + rightRows)) + } ++ rightKeys.map { k => + val oriColStat = inputAttrStats(k) + (k, oriColStat.copy(distinctCount = oriColStat.distinctCount + leftRows)) + }) + case _ => + AttributeMap[ColumnStat](Nil) + } + } else { + getIntersectedStats(joinKeyPairs) + } val attributesWithStat = join.output.filter(a => inputAttrStats.contains(a)) val (fromLeft, fromRight) = attributesWithStat.partition(join.left.outputSet.contains(_)) val outputStats: Map[Attribute, ColumnStat] = join.joinType match { case LeftOuter => // Don't update column stats for attributes from left side. fromLeft.map(a => (a, inputAttrStats(a))).toMap ++ - updateAttrStats(outputRows, fromRight, inputAttrStats, intersectedStats) + updateAttrStats(outputRows, fromRight, inputAttrStats, joinKeyStats) case RightOuter => // Don't update column stats for attributes from right side. - updateAttrStats(outputRows, fromLeft, inputAttrStats, intersectedStats) ++ + updateAttrStats(outputRows, fromLeft, inputAttrStats, joinKeyStats) ++ fromRight.map(a => (a, inputAttrStats(a))).toMap case FullOuter => // Don't update column stats for attributes from both sides. attributesWithStat.map(a => (a, inputAttrStats(a))).toMap case _ => // Update column stats from both sides for inner or cross join. - updateAttrStats(outputRows, attributesWithStat, inputAttrStats, intersectedStats) + updateAttrStats(outputRows, attributesWithStat, inputAttrStats, joinKeyStats) } val outputAttrStats = AttributeMap(outputStats.toSeq) @@ -146,30 +169,23 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging * conservative strategy to take only the largest max(V(A.ki), V(B.ki)) as the denominator. 
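   * For example, if T(A) = 1000 with V(A.k) = 100 and T(B) = 500 with V(B.k) = 50,
   * the selectivity is 1 / max(100, 50) = 0.01, giving an estimated
   * 1000 * 500 * 0.01 = 5000 output rows for the inner part.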
*/ // scalastyle:on - def joinSelectivity( - joinKeyPairs: Seq[(AttributeReference, AttributeReference)], - leftStats: Statistics, - rightStats: Statistics): BigDecimal = { - + def joinSelectivity(joinKeyPairs: Seq[(AttributeReference, AttributeReference)]): BigDecimal = { var ndvDenom: BigInt = -1 var i = 0 while(i < joinKeyPairs.length && ndvDenom != 0) { val (leftKey, rightKey) = joinKeyPairs(i) - // Do estimation if we have enough statistics - if (columnStatsExist((leftStats, leftKey), (rightStats, rightKey))) { - // Check if the two sides are disjoint - val leftKeyStats = leftStats.attributeStats(leftKey) - val rightKeyStats = rightStats.attributeStats(rightKey) - val lRange = Range(leftKeyStats.min, leftKeyStats.max, leftKey.dataType) - val rRange = Range(rightKeyStats.min, rightKeyStats.max, rightKey.dataType) - if (Range.isIntersected(lRange, rRange)) { - // Get the largest ndv among pairs of join keys - val maxNdv = leftKeyStats.distinctCount.max(rightKeyStats.distinctCount) - if (maxNdv > ndvDenom) ndvDenom = maxNdv - } else { - // Set ndvDenom to zero to indicate that this join should have no output - ndvDenom = 0 - } + // Check if the two sides are disjoint + val leftKeyStats = leftStats.attributeStats(leftKey) + val rightKeyStats = rightStats.attributeStats(rightKey) + val lRange = Range(leftKeyStats.min, leftKeyStats.max, leftKey.dataType) + val rRange = Range(rightKeyStats.min, rightKeyStats.max, rightKey.dataType) + if (Range.isIntersected(lRange, rRange)) { + // Get the largest ndv among pairs of join keys + val maxNdv = leftKeyStats.distinctCount.max(rightKeyStats.distinctCount) + if (maxNdv > ndvDenom) ndvDenom = maxNdv + } else { + // Set ndvDenom to zero to indicate that this join should have no output + ndvDenom = 0 } i += 1 } @@ -187,8 +203,8 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging } /** - * Update column stats for output attributes. - * 1. For empty output, update all column stats to be empty. + * Propagate or update column stats for output attributes. + * 1. For empty output, we don't need to keep any column stats. * 2. For cartesian product, all values are preserved, so there's no need to change column stats. * 3. For other cases, a) update max/min of join keys based on their intersected range. b) update * distinct count of other attributes based on output rows after join. @@ -201,13 +217,10 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging val outputAttrStats = new mutable.HashMap[Attribute, ColumnStat]() val leftRows = leftStats.rowCount.get val rightRows = rightStats.rowCount.get - if (outputRows == 0) { - // empty output - attributes.foreach(a => outputAttrStats.put(a, emptyColumnStat(a.dataType))) - } else if (outputRows == leftRows * rightRows) { + if (outputRows == leftRows * rightRows) { // Cartesian product, just propagate the original column stats attributes.foreach(a => outputAttrStats.put(a, oldAttrStats(a))) - } else { + } else if (outputRows != 0) { val leftRatio = if (leftRows != 0) BigDecimal(outputRows) / BigDecimal(leftRows) else BigDecimal(0) val rightRatio = @@ -237,37 +250,33 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging AttributeMap(outputAttrStats.toSeq) } - /** Update intersected column stats for join keys. 
*/ - private def updateIntersectedStats( - joinKeyPairs: Seq[(AttributeReference, AttributeReference)], - leftStats: Statistics, - rightStats: Statistics): AttributeMap[ColumnStat] = { + /** Get intersected column stats for join keys. */ + private def getIntersectedStats(joinKeyPairs: Seq[(AttributeReference, AttributeReference)]) + : AttributeMap[ColumnStat] = { + val intersectedStats = new mutable.HashMap[Attribute, ColumnStat]() joinKeyPairs.foreach { case (leftKey, rightKey) => - // Do estimation if we have enough statistics - if (columnStatsExist((leftStats, leftKey), (rightStats, rightKey))) { - // Check if the two sides are disjoint - val leftKeyStats = leftStats.attributeStats(leftKey) - val rightKeyStats = rightStats.attributeStats(rightKey) - val lRange = Range(leftKeyStats.min, leftKeyStats.max, leftKey.dataType) - val rRange = Range(rightKeyStats.min, rightKeyStats.max, rightKey.dataType) - if (Range.isIntersected(lRange, rRange)) { - // Update intersected column stats - val minNdv = leftKeyStats.distinctCount.min(rightKeyStats.distinctCount) - val (newMin1, newMax1, newMin2, newMax2) = - Range.intersect(lRange, rRange, leftKey.dataType, rightKey.dataType) - intersectedStats.put(leftKey, intersectedColumnStat(leftKeyStats, minNdv, - newMin1, newMax1)) - intersectedStats.put(rightKey, intersectedColumnStat(rightKeyStats, minNdv, - newMin2, newMax2)) - } + // Check if the two sides are disjoint + val leftKeyStats = leftStats.attributeStats(leftKey) + val rightKeyStats = rightStats.attributeStats(rightKey) + val lRange = Range(leftKeyStats.min, leftKeyStats.max, leftKey.dataType) + val rRange = Range(rightKeyStats.min, rightKeyStats.max, rightKey.dataType) + if (Range.isIntersected(lRange, rRange)) { + // Update intersected column stats + val minNdv = leftKeyStats.distinctCount.min(rightKeyStats.distinctCount) + val (newMin1, newMax1, newMin2, newMax2) = + Range.intersect(lRange, rRange, leftKey.dataType, rightKey.dataType) + intersectedStats.put(leftKey, intersectedColumnStat(leftKeyStats, minNdv, + newMin1, newMax1)) + intersectedStats.put(rightKey, intersectedColumnStat(rightKeyStats, minNdv, + newMin2, newMax2)) } } AttributeMap(intersectedStats.toSeq) } - private def emptyColumnStat(dataType: DataType): ColumnStat = { - ColumnStat(distinctCount = 0, min = None, max = None, nullCount = 0, + private def nullColumnStat(dataType: DataType, rowCount: BigInt): ColumnStat = { + ColumnStat(distinctCount = 0, min = None, max = None, nullCount = rowCount, avgLen = dataType.defaultSize, maxLen = dataType.defaultSize) } @@ -279,17 +288,15 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging origin.copy(distinctCount = newDistinctCount, min = newMin, max = newMax, nullCount = 0) } - private def extractJoinKeys( + private def extractJoinKeysWithColStats( leftKeys: Seq[Expression], rightKeys: Seq[Expression]): Seq[(AttributeReference, AttributeReference)] = { - leftKeys.zip(rightKeys).flatMap { - case (lk: AttributeReference, rk: AttributeReference) => Some((lk, rk)) - case (lk, rk) => - // Currently we don't deal with equal joins like key1 = key2 + 5. - // Note: join keys from EqualNullSafe also fall into this case (Coalesce), consider to - // support it in the future by using `nullCount` in column stats. - logDebug(s"[CBO] Unsupported equi-join expression: left key: $lk, right key: $rk") - None + leftKeys.zip(rightKeys).collect { + // Currently we don't deal with equal joins like key1 = key2 + 5. 
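+      // (Such keys are not bare AttributeReferences, so the partial function below never
+      // matches the pair and `collect` silently drops it.)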
+ // Note: join keys from EqualNullSafe also fall into this case (Coalesce), consider to + // support it in the future by using `nullCount` in column stats. + case (lk: AttributeReference, rk: AttributeReference) + if columnStatsExist((leftStats, lk), (rightStats, rk)) => (lk, rk) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala index 83b9c91137b4..f8b318a9823f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala @@ -86,15 +86,10 @@ class JoinEstimationSuite extends StatsEstimationTestBase { val join = Join(table1, table2, Inner, Some( And(EqualTo(nameToAttr("key11"), nameToAttr("key21")), EqualTo(nameToAttr("key12"), nameToAttr("key22"))))) - // Empty column stats for all output columns. - val emptyColStat = ColumnStat(distinctCount = 0, min = None, max = None, nullCount = 0, - avgLen = 4, maxLen = 4) - val expectedStats = Statistics( sizeInBytes = 1, rowCount = Some(0), - attributeStats = AttributeMap( - Seq("key11", "key12", "key21", "key22").map(c => (nameToAttr(c), emptyColStat)))) + attributeStats = AttributeMap(Nil)) assert(join.stats(conf) == expectedStats) } @@ -300,15 +295,11 @@ class JoinEstimationSuite extends StatsEstimationTestBase { outputList = Seq(nullColumn), rowCount = 1, attributeStats = AttributeMap(Seq(nullColumn -> nullColStat))) - val join = Join(table1, nullTable, Inner, - Some(EqualTo(nameToAttr("key11"), nullColumn))) - val emptyColStat = ColumnStat(distinctCount = 0, min = None, max = None, nullCount = 0, - avgLen = 4, maxLen = 4) + val join = Join(table1, nullTable, Inner, Some(EqualTo(nameToAttr("key11"), nullColumn))) val expectedStats = Statistics( sizeInBytes = 1, rowCount = Some(0), - attributeStats = AttributeMap(Seq(nameToAttr("key11") -> emptyColStat, - nameToAttr("key12") -> emptyColStat, nullColumn -> emptyColStat))) + attributeStats = AttributeMap(Nil)) assert(join.stats(conf) == expectedStats) } } From 331e5d02a24535184eb6a37ab134287bb14b776f Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Tue, 14 Feb 2017 13:43:00 -0800 Subject: [PATCH 09/12] other comments --- .../plans/logical/statsEstimation/Range.scala | 3 +- .../statsEstimation/JoinEstimationSuite.scala | 120 +++++++++--------- .../StatsEstimationTestBase.scala | 1 + 3 files changed, 63 insertions(+), 61 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala index dd08c11a0bda..bc39c05cd96e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala @@ -48,7 +48,8 @@ object Range { def isIntersected(r1: Range, r2: Range): Boolean = (r1, r2) match { case (_, _: DefaultRange) | (_: DefaultRange, _) => - // Skip overlapping check for binary/string types + // The DefaultRange represents string/binary types which do not have max/min stats, + // we assume they are intersected to be conservative on estimation true case (_, _: NullRange) | (_: NullRange, _) => false diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala index f8b318a9823f..474595b9940c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala @@ -31,17 +31,15 @@ class JoinEstimationSuite extends StatsEstimationTestBase { /** Set up tables and its columns for testing */ private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( - attr("key11") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), nullCount = 0, + attr("key-1-5") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), nullCount = 0, avgLen = 4, maxLen = 4), - attr("key12") -> ColumnStat(distinctCount = 5, min = Some(5), max = Some(9), nullCount = 0, + attr("key-5-9") -> ColumnStat(distinctCount = 5, min = Some(5), max = Some(9), nullCount = 0, avgLen = 4, maxLen = 4), - attr("key21") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0, + attr("key-1-2") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0, avgLen = 4, maxLen = 4), - attr("key22") -> ColumnStat(distinctCount = 3, min = Some(2), max = Some(4), nullCount = 0, + attr("key-2-4") -> ColumnStat(distinctCount = 3, min = Some(2), max = Some(4), nullCount = 0, avgLen = 4, maxLen = 4), - attr("key31") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0, - avgLen = 4, maxLen = 4), - attr("key32") -> ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), nullCount = 0, + attr("key-2-3") -> ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), nullCount = 0, avgLen = 4, maxLen = 4) )) @@ -49,43 +47,43 @@ class JoinEstimationSuite extends StatsEstimationTestBase { private val nameToColInfo: Map[String, (Attribute, ColumnStat)] = columnInfo.map(kv => kv._1.name -> kv) - // Suppose table1 (key11 int, key12 int) has 5 records: (1, 9), (2, 8), (3, 7), (4, 6), (5, 5) + // Suppose table1 (key-1-5 int, key-5-9 int) has 5 records: (1, 9), (2, 8), (3, 7), (4, 6), (5, 5) private val table1 = StatsTestPlan( - outputList = Seq("key11", "key12").map(nameToAttr), + outputList = Seq("key-1-5", "key-5-9").map(nameToAttr), rowCount = 5, - attributeStats = AttributeMap(Seq("key11", "key12").map(nameToColInfo))) + attributeStats = AttributeMap(Seq("key-1-5", "key-5-9").map(nameToColInfo))) - // Suppose table2 (key21 int, key22 int) has 3 records: (1, 2), (2, 3), (2, 4) + // Suppose table2 (key-1-2 int, key-2-4 int) has 3 records: (1, 2), (2, 3), (2, 4) private val table2 = StatsTestPlan( - outputList = Seq("key21", "key22").map(nameToAttr), + outputList = Seq("key-1-2", "key-2-4").map(nameToAttr), rowCount = 3, - attributeStats = AttributeMap(Seq("key21", "key22").map(nameToColInfo))) + attributeStats = AttributeMap(Seq("key-1-2", "key-2-4").map(nameToColInfo))) - // Suppose table3 (key31 int, key32 int) has 2 records: (1, 2), (2, 3) + // Suppose table3 (key-1-2 int, key-2-3 int) has 2 records: (1, 2), (2, 3) private val table3 = StatsTestPlan( - outputList = Seq("key31", "key32").map(nameToAttr), + outputList = Seq("key-1-2", "key-2-3").map(nameToAttr), rowCount = 2, - attributeStats = AttributeMap(Seq("key31", "key32").map(nameToColInfo))) + attributeStats = AttributeMap(Seq("key-1-2", "key-2-3").map(nameToColInfo))) test("cross join") { - // table1 (key11 int, key12 
int): (1, 9), (2, 8), (3, 7), (4, 6), (5, 5) - // table2 (key21 int, key22 int): (1, 2), (2, 3), (2, 4) + // table1 (key-1-5 int, key-5-9 int): (1, 9), (2, 8), (3, 7), (4, 6), (5, 5) + // table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4) val join = Join(table1, table2, Cross, None) val expectedStats = Statistics( sizeInBytes = 5 * 3 * (8 + 4 * 4), rowCount = Some(5 * 3), // Keep the column stat from both sides unchanged. - attributeStats = AttributeMap(Seq("key11", "key12", "key21", "key22").map(nameToColInfo))) + attributeStats = AttributeMap( + Seq("key-1-5", "key-5-9", "key-1-2", "key-2-4").map(nameToColInfo))) assert(join.stats(conf) == expectedStats) } test("disjoint inner join") { - // table1 (key11 int, key12 int): (1, 9), (2, 8), (3, 7), (4, 6), (5, 5) - // table2 (key21 int, key22 int): (1, 2), (2, 3), (2, 4) - // key12 and key22 are disjoint - val join = Join(table1, table2, Inner, Some( - And(EqualTo(nameToAttr("key11"), nameToAttr("key21")), - EqualTo(nameToAttr("key12"), nameToAttr("key22"))))) + // table1 (key-1-5 int, key-5-9 int): (1, 9), (2, 8), (3, 7), (4, 6), (5, 5) + // table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4) + // key-5-9 and key-2-4 are disjoint + val join = Join(table1, table2, Inner, + Some(EqualTo(nameToAttr("key-5-9"), nameToAttr("key-2-4")))) val expectedStats = Statistics( sizeInBytes = 1, rowCount = Some(0), @@ -94,31 +92,32 @@ class JoinEstimationSuite extends StatsEstimationTestBase { } test("inner join") { - // table1 (key11 int, key12 int): (1, 9), (2, 8), (3, 7), (4, 6), (5, 5) - // table2 (key21 int, key22 int): (1, 2), (2, 3), (2, 4) - val join = Join(table1, table2, Inner, Some(EqualTo(nameToAttr("key11"), nameToAttr("key21")))) - // Update column stats for equi-join keys (key11 and key21). + // table1 (key-1-5 int, key-5-9 int): (1, 9), (2, 8), (3, 7), (4, 6), (5, 5) + // table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4) + val join = Join(table1, table2, Inner, + Some(EqualTo(nameToAttr("key-1-5"), nameToAttr("key-1-2")))) + // Update column stats for equi-join keys (key-1-5 and key-1-2). val joinedColStat = ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0, avgLen = 4, maxLen = 4) - // Update column stat for other column if #outputRow / #sideRow < 1 (key12), or keep it - // unchanged (key22). - val colStatForKey12 = nameToColInfo("key12")._2.copy(distinctCount = 5 * 3 / 5) + // Update column stat for other column if #outputRow / #sideRow < 1 (key-5-9), or keep it + // unchanged (key-2-4). 
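+    // Here #outputRow / #leftRow = 3 / 5 < 1, so key-5-9's ndv shrinks from 5 to
+    // 5 * 3 / 5 = 3, while the right side has 3 / 3 = 1 and key-2-4 keeps ndv = 3.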
+ val colStatForkey59 = nameToColInfo("key-5-9")._2.copy(distinctCount = 5 * 3 / 5) val expectedStats = Statistics( sizeInBytes = 3 * (8 + 4 * 4), rowCount = Some(3), attributeStats = AttributeMap( - Seq(nameToAttr("key11") -> joinedColStat, nameToAttr("key21") -> joinedColStat, - nameToAttr("key12") -> colStatForKey12, nameToColInfo("key22")))) + Seq(nameToAttr("key-1-5") -> joinedColStat, nameToAttr("key-1-2") -> joinedColStat, + nameToAttr("key-5-9") -> colStatForkey59, nameToColInfo("key-2-4")))) assert(join.stats(conf) == expectedStats) } test("inner join with multiple equi-join keys") { - // table2 (key21 int, key22 int): (1, 2), (2, 3), (2, 4) - // table3 (key31 int, key32 int): (1, 2), (2, 3) + // table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4) + // table3 (key-1-2 int, key-2-3 int): (1, 2), (2, 3) val join = Join(table2, table3, Inner, Some( - And(EqualTo(nameToAttr("key21"), nameToAttr("key31")), - EqualTo(nameToAttr("key22"), nameToAttr("key32"))))) + And(EqualTo(nameToAttr("key-1-2"), nameToAttr("key-1-2")), + EqualTo(nameToAttr("key-2-4"), nameToAttr("key-2-3"))))) // Update column stats for join keys. val joinedColStat1 = ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0, @@ -130,16 +129,16 @@ class JoinEstimationSuite extends StatsEstimationTestBase { sizeInBytes = 2 * (8 + 4 * 4), rowCount = Some(2), attributeStats = AttributeMap( - Seq(nameToAttr("key21") -> joinedColStat1, nameToAttr("key31") -> joinedColStat1, - nameToAttr("key22") -> joinedColStat2, nameToAttr("key32") -> joinedColStat2))) + Seq(nameToAttr("key-1-2") -> joinedColStat1, nameToAttr("key-1-2") -> joinedColStat1, + nameToAttr("key-2-4") -> joinedColStat2, nameToAttr("key-2-3") -> joinedColStat2))) assert(join.stats(conf) == expectedStats) } test("left outer join") { - // table2 (key21 int, key22 int): (1, 2), (2, 3), (2, 4) - // table3 (key31 int, key32 int): (1, 2), (2, 3) + // table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4) + // table3 (key-1-2 int, key-2-3 int): (1, 2), (2, 3) val join = Join(table3, table2, LeftOuter, - Some(EqualTo(nameToAttr("key32"), nameToAttr("key22")))) + Some(EqualTo(nameToAttr("key-2-3"), nameToAttr("key-2-4")))) val joinedColStat = ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), nullCount = 0, avgLen = 4, maxLen = 4) @@ -148,16 +147,16 @@ class JoinEstimationSuite extends StatsEstimationTestBase { rowCount = Some(2), // Keep the column stat from left side unchanged. attributeStats = AttributeMap( - Seq(nameToColInfo("key31"), nameToColInfo("key32"), - nameToColInfo("key21"), nameToAttr("key22") -> joinedColStat))) + Seq(nameToColInfo("key-1-2"), nameToColInfo("key-2-3"), + nameToColInfo("key-1-2"), nameToAttr("key-2-4") -> joinedColStat))) assert(join.stats(conf) == expectedStats) } test("right outer join") { - // table2 (key21 int, key22 int): (1, 2), (2, 3), (2, 4) - // table3 (key31 int, key32 int): (1, 2), (2, 3) + // table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4) + // table3 (key-1-2 int, key-2-3 int): (1, 2), (2, 3) val join = Join(table2, table3, RightOuter, - Some(EqualTo(nameToAttr("key22"), nameToAttr("key32")))) + Some(EqualTo(nameToAttr("key-2-4"), nameToAttr("key-2-3")))) val joinedColStat = ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), nullCount = 0, avgLen = 4, maxLen = 4) @@ -166,36 +165,37 @@ class JoinEstimationSuite extends StatsEstimationTestBase { rowCount = Some(2), // Keep the column stat from right side unchanged. 
attributeStats = AttributeMap( - Seq(nameToColInfo("key21"), nameToAttr("key22") -> joinedColStat, - nameToColInfo("key31"), nameToColInfo("key32")))) + Seq(nameToColInfo("key-1-2"), nameToAttr("key-2-4") -> joinedColStat, + nameToColInfo("key-1-2"), nameToColInfo("key-2-3")))) assert(join.stats(conf) == expectedStats) } test("full outer join") { - // table2 (key21 int, key22 int): (1, 2), (2, 3), (2, 4) - // table3 (key31 int, key32 int): (1, 2), (2, 3) + // table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4) + // table3 (key-1-2 int, key-2-3 int): (1, 2), (2, 3) val join = Join(table2, table3, FullOuter, - Some(EqualTo(nameToAttr("key22"), nameToAttr("key32")))) + Some(EqualTo(nameToAttr("key-2-4"), nameToAttr("key-2-3")))) val expectedStats = Statistics( sizeInBytes = 3 * (8 + 4 * 4), rowCount = Some(3), // Keep the column stat from both sides unchanged. - attributeStats = AttributeMap(Seq(nameToColInfo("key21"), nameToColInfo("key22"), - nameToColInfo("key31"), nameToColInfo("key32")))) + attributeStats = AttributeMap(Seq(nameToColInfo("key-1-2"), nameToColInfo("key-2-4"), + nameToColInfo("key-1-2"), nameToColInfo("key-2-3")))) assert(join.stats(conf) == expectedStats) } test("left semi/anti join") { - // table2 (key21 int, key22 int): (1, 2), (2, 3), (2, 4) - // table3 (key31 int, key32 int): (1, 2), (2, 3) + // table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4) + // table3 (key-1-2 int, key-2-3 int): (1, 2), (2, 3) Seq(LeftSemi, LeftAnti).foreach { jt => - val join = Join(table2, table3, jt, Some(EqualTo(nameToAttr("key22"), nameToAttr("key32")))) + val join = Join(table2, table3, jt, + Some(EqualTo(nameToAttr("key-2-4"), nameToAttr("key-2-3")))) // For now we just propagate the statistics from left side for left semi/anti join. val expectedStats = Statistics( sizeInBytes = 3 * (8 + 4 * 2), rowCount = Some(3), - attributeStats = AttributeMap(Seq(nameToColInfo("key21"), nameToColInfo("key22")))) + attributeStats = AttributeMap(Seq(nameToColInfo("key-1-2"), nameToColInfo("key-2-4")))) assert(join.stats(conf) == expectedStats) } } @@ -295,7 +295,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase { outputList = Seq(nullColumn), rowCount = 1, attributeStats = AttributeMap(Seq(nullColumn -> nullColStat))) - val join = Join(table1, nullTable, Inner, Some(EqualTo(nameToAttr("key11"), nullColumn))) + val join = Join(table1, nullTable, Inner, Some(EqualTo(nameToAttr("key-1-5"), nullColumn))) val expectedStats = Statistics( sizeInBytes = 1, rowCount = Some(0), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala index 8563e4da56ab..c56b41ce3763 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala @@ -30,6 +30,7 @@ class StatsEstimationTestBase extends SparkFunSuite { protected val conf = SimpleCatalystConf(caseSensitiveAnalysis = true, cboEnabled = true) def getColSize(attribute: Attribute, colStat: ColumnStat): Long = attribute.dataType match { + // For UTF8String: base + offset + numBytes case StringType => colStat.avgLen + 8 + 4 case _ => colStat.avgLen } From e8930d2067d1bc81b07c52f2c951b254f633d8bf Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Tue, 14 Feb 2017 15:10:25 -0800 Subject: [PATCH 10/12] add test cases for 
disjoint outer joins --- .../statsEstimation/EstimationUtils.scala | 9 +- .../statsEstimation/JoinEstimation.scala | 87 +++++++++---------- .../statsEstimation/JoinEstimationSuite.scala | 51 +++++++++++ 3 files changed, 98 insertions(+), 49 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala index 1b0b8797c760..4d18b28be866 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala @@ -20,9 +20,9 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation import scala.math.BigDecimal.RoundingMode import org.apache.spark.sql.catalyst.CatalystConf -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Expression} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan, Statistics} -import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.types.{DataType, StringType} object EstimationUtils { @@ -38,6 +38,11 @@ object EstimationUtils { } } + def nullColumnStat(dataType: DataType, rowCount: BigInt): ColumnStat = { + ColumnStat(distinctCount = 0, min = None, max = None, nullCount = rowCount, + avgLen = dataType.defaultSize, maxLen = dataType.defaultSize) + } + def ceil(bigDecimal: BigDecimal): BigInt = bigDecimal.setScale(0, RoundingMode.CEILING).toBigInt() /** Get column stats for output attributes. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala index dacd3a32b88d..398a07207667 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.CatalystConf @@ -26,7 +27,6 @@ import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ -import org.apache.spark.sql.types.DataType object JoinEstimation extends Logging { @@ -88,54 +88,52 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging // 3. 
Update statistics based on the output of join val inputAttrStats = AttributeMap( leftStats.attributeStats.toSeq ++ rightStats.attributeStats.toSeq) - val joinKeyStats = if (innerJoinedRows == 0) { - val leftKeys = joinKeyPairs.map(_._1) - val rightKeys = joinKeyPairs.map(_._2) + val attributesWithStat = join.output.filter(a => inputAttrStats.contains(a)) + val (fromLeft, fromRight) = attributesWithStat.partition(join.left.outputSet.contains(_)) + + val outputStats: Seq[(Attribute, ColumnStat)] = if (innerJoinedRows == 0) { joinType match { // For outer joins, if the inner join part is empty, the number of output rows is the // same as that of the outer side. And column stats of join keys from the outer side // keep unchanged, while column stats of join keys from the other side should be updated // based on added null values. case LeftOuter => - AttributeMap[ColumnStat](leftKeys.map(k => (k, inputAttrStats(k))) ++ - rightKeys.map(k => (k, nullColumnStat(k.dataType, leftRows)))) + fromLeft.map(a => (a, inputAttrStats(a))) ++ + fromRight.map(a => (a, nullColumnStat(a.dataType, leftRows))) case RightOuter => - AttributeMap[ColumnStat](rightKeys.map(k => (k, inputAttrStats(k))) ++ - leftKeys.map(k => (k, nullColumnStat(k.dataType, rightRows)))) + fromRight.map(a => (a, inputAttrStats(a))) ++ + fromLeft.map(a => (a, nullColumnStat(a.dataType, rightRows))) case FullOuter => - AttributeMap[ColumnStat](leftKeys.map { k => - val oriColStat = inputAttrStats(k) - (k, oriColStat.copy(distinctCount = oriColStat.distinctCount + rightRows)) - } ++ rightKeys.map { k => - val oriColStat = inputAttrStats(k) - (k, oriColStat.copy(distinctCount = oriColStat.distinctCount + leftRows)) - }) + fromLeft.map { a => + val oriColStat = inputAttrStats(a) + (a, oriColStat.copy(nullCount = oriColStat.nullCount + rightRows)) + } ++ fromRight.map { a => + val oriColStat = inputAttrStats(a) + (a, oriColStat.copy(nullCount = oriColStat.nullCount + leftRows)) + } case _ => - AttributeMap[ColumnStat](Nil) + // For inner join, since the output is empty, we don't need to keep column stats. + Nil } } else { - getIntersectedStats(joinKeyPairs) - } - val attributesWithStat = join.output.filter(a => inputAttrStats.contains(a)) - val (fromLeft, fromRight) = attributesWithStat.partition(join.left.outputSet.contains(_)) - val outputStats: Map[Attribute, ColumnStat] = join.joinType match { - case LeftOuter => - // Don't update column stats for attributes from left side. - fromLeft.map(a => (a, inputAttrStats(a))).toMap ++ - updateAttrStats(outputRows, fromRight, inputAttrStats, joinKeyStats) - case RightOuter => - // Don't update column stats for attributes from right side. - updateAttrStats(outputRows, fromLeft, inputAttrStats, joinKeyStats) ++ - fromRight.map(a => (a, inputAttrStats(a))).toMap - case FullOuter => - // Don't update column stats for attributes from both sides. - attributesWithStat.map(a => (a, inputAttrStats(a))).toMap - case _ => - // Update column stats from both sides for inner or cross join. - updateAttrStats(outputRows, attributesWithStat, inputAttrStats, joinKeyStats) + val joinKeyStats = getIntersectedStats(joinKeyPairs) + join.joinType match { + // For outer joins, don't update column stats from the outer side. 
+ case LeftOuter => + fromLeft.map(a => (a, inputAttrStats(a))) ++ + updateAttrStats(outputRows, fromRight, inputAttrStats, joinKeyStats) + case RightOuter => + updateAttrStats(outputRows, fromLeft, inputAttrStats, joinKeyStats) ++ + fromRight.map(a => (a, inputAttrStats(a))) + case FullOuter => + attributesWithStat.map(a => (a, inputAttrStats(a))) + case _ => + // Update column stats from both sides for inner or cross join. + updateAttrStats(outputRows, attributesWithStat, inputAttrStats, joinKeyStats) + } } - val outputAttrStats = AttributeMap(outputStats.toSeq) + val outputAttrStats = AttributeMap(outputStats) Some(Statistics( sizeInBytes = getOutputSize(join.output, outputRows, outputAttrStats), rowCount = Some(outputRows), @@ -213,13 +211,13 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging outputRows: BigInt, attributes: Seq[Attribute], oldAttrStats: AttributeMap[ColumnStat], - joinKeyStats: AttributeMap[ColumnStat]): AttributeMap[ColumnStat] = { - val outputAttrStats = new mutable.HashMap[Attribute, ColumnStat]() + joinKeyStats: AttributeMap[ColumnStat]): Seq[(Attribute, ColumnStat)] = { + val outputAttrStats = new ArrayBuffer[(Attribute, ColumnStat)]() val leftRows = leftStats.rowCount.get val rightRows = rightStats.rowCount.get if (outputRows == leftRows * rightRows) { // Cartesian product, just propagate the original column stats - attributes.foreach(a => outputAttrStats.put(a, oldAttrStats(a))) + attributes.foreach(a => outputAttrStats += a -> oldAttrStats(a)) } else if (outputRows != 0) { val leftRatio = if (leftRows != 0) BigDecimal(outputRows) / BigDecimal(leftRows) else BigDecimal(0) @@ -228,7 +226,7 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging attributes.foreach { a => // check if this attribute is a join key if (joinKeyStats.contains(a)) { - outputAttrStats.put(a, joinKeyStats(a)) + outputAttrStats += a -> joinKeyStats(a) } else { val oldCS = oldAttrStats(a) val oldNdv = oldCS.distinctCount @@ -243,11 +241,11 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging oldNdv } // TODO: support nullCount updates for specific outer joins - outputAttrStats.put(a, oldCS.copy(distinctCount = newNdv)) + outputAttrStats += a -> oldCS.copy(distinctCount = newNdv) } } } - AttributeMap(outputAttrStats.toSeq) + outputAttrStats } /** Get intersected column stats for join keys. 
*/ @@ -275,11 +273,6 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging AttributeMap(intersectedStats.toSeq) } - private def nullColumnStat(dataType: DataType, rowCount: BigInt): ColumnStat = { - ColumnStat(distinctCount = 0, min = None, max = None, nullCount = rowCount, - avgLen = dataType.defaultSize, maxLen = dataType.defaultSize) - } - private def intersectedColumnStat( origin: ColumnStat, newDistinctCount: BigInt, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala index 474595b9940c..a596c01a5f6c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala @@ -24,6 +24,7 @@ import scala.collection.mutable import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeMap, AttributeReference, EqualTo} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, Project, Statistics} +import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ import org.apache.spark.sql.types.{DateType, TimestampType, _} @@ -91,6 +92,56 @@ class JoinEstimationSuite extends StatsEstimationTestBase { assert(join.stats(conf) == expectedStats) } + test("disjoint left outer join") { + // table1 (key-1-5 int, key-5-9 int): (1, 9), (2, 8), (3, 7), (4, 6), (5, 5) + // table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4) + // key-5-9 and key-2-4 are disjoint + val join = Join(table1, table2, LeftOuter, + Some(EqualTo(nameToAttr("key-5-9"), nameToAttr("key-2-4")))) + val expectedStats = Statistics( + sizeInBytes = 5 * (8 + 4 * 4), + rowCount = Some(5), + attributeStats = AttributeMap(Seq("key-1-5", "key-5-9").map(nameToColInfo) ++ + // Null count for right side columns = left row count + Seq(nameToAttr("key-1-2") -> nullColumnStat(nameToAttr("key-1-2").dataType, 5), + nameToAttr("key-2-4") -> nullColumnStat(nameToAttr("key-2-4").dataType, 5)))) + assert(join.stats(conf) == expectedStats) + } + + test("disjoint right outer join") { + // table1 (key-1-5 int, key-5-9 int): (1, 9), (2, 8), (3, 7), (4, 6), (5, 5) + // table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4) + // key-5-9 and key-2-4 are disjoint + val join = Join(table1, table2, RightOuter, + Some(EqualTo(nameToAttr("key-5-9"), nameToAttr("key-2-4")))) + val expectedStats = Statistics( + sizeInBytes = 3 * (8 + 4 * 4), + rowCount = Some(3), + attributeStats = AttributeMap(Seq("key-1-2", "key-2-4").map(nameToColInfo) ++ + // Null count for left side columns = right row count + Seq(nameToAttr("key-1-5") -> nullColumnStat(nameToAttr("key-1-5").dataType, 3), + nameToAttr("key-5-9") -> nullColumnStat(nameToAttr("key-5-9").dataType, 3)))) + assert(join.stats(conf) == expectedStats) + } + + test("disjoint full outer join") { + // table1 (key-1-5 int, key-5-9 int): (1, 9), (2, 8), (3, 7), (4, 6), (5, 5) + // table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4) + // key-5-9 and key-2-4 are disjoint + val join = Join(table1, table2, FullOuter, + Some(EqualTo(nameToAttr("key-5-9"), nameToAttr("key-2-4")))) + val expectedStats = Statistics( + sizeInBytes = (5 + 3) * (8 + 4 * 4), + rowCount = Some(5 + 3), + attributeStats = AttributeMap( + // Update null count in column stats. 
+ Seq(nameToAttr("key-1-5") -> columnInfo(nameToAttr("key-1-5")).copy(nullCount = 3), + nameToAttr("key-5-9") -> columnInfo(nameToAttr("key-5-9")).copy(nullCount = 3), + nameToAttr("key-1-2") -> columnInfo(nameToAttr("key-1-2")).copy(nullCount = 5), + nameToAttr("key-2-4") -> columnInfo(nameToAttr("key-2-4")).copy(nullCount = 5)))) + assert(join.stats(conf) == expectedStats) + } + test("inner join") { // table1 (key-1-5 int, key-5-9 int): (1, 9), (2, 8), (3, 7), (4, 6), (5, 5) // table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4) From 981de6eb655ec7fae76d82312eb673496b8d7a83 Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Tue, 14 Feb 2017 18:39:27 -0800 Subject: [PATCH 11/12] more comments --- .../statsEstimation/JoinEstimation.scala | 64 +++++------ .../plans/logical/statsEstimation/Range.scala | 15 +-- .../statsEstimation/JoinEstimationSuite.scala | 107 +++++++----------- 3 files changed, 71 insertions(+), 115 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala index 398a07207667..3e7377b5f242 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala @@ -91,7 +91,10 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging val attributesWithStat = join.output.filter(a => inputAttrStats.contains(a)) val (fromLeft, fromRight) = attributesWithStat.partition(join.left.outputSet.contains(_)) - val outputStats: Seq[(Attribute, ColumnStat)] = if (innerJoinedRows == 0) { + val outputStats: Seq[(Attribute, ColumnStat)] = if (outputRows == 0) { + // The output is empty, we don't need to keep column stats. + Nil + } else if (innerJoinedRows == 0) { joinType match { // For outer joins, if the inner join part is empty, the number of output rows is the // same as that of the outer side. And column stats of join keys from the outer side @@ -111,9 +114,6 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging val oriColStat = inputAttrStats(a) (a, oriColStat.copy(nullCount = oriColStat.nullCount + leftRows)) } - case _ => - // For inner join, since the output is empty, we don't need to keep column stats. - Nil } } else { val joinKeyStats = getIntersectedStats(joinKeyPairs) @@ -126,7 +126,7 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging updateAttrStats(outputRows, fromLeft, inputAttrStats, joinKeyStats) ++ fromRight.map(a => (a, inputAttrStats(a))) case FullOuter => - attributesWithStat.map(a => (a, inputAttrStats(a))) + inputAttrStats.toSeq case _ => // Update column stats from both sides for inner or cross join. 
updateAttrStats(outputRows, attributesWithStat, inputAttrStats, joinKeyStats) @@ -145,12 +145,11 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging val inputAttrStats = AttributeMap( leftStats.attributeStats.toSeq ++ rightStats.attributeStats.toSeq) // Propagate the original column stats - val outputAttrStats = getOutputMap(inputAttrStats, join.output) val outputRows = leftStats.rowCount.get * rightStats.rowCount.get Some(Statistics( - sizeInBytes = getOutputSize(join.output, outputRows, outputAttrStats), + sizeInBytes = getOutputSize(join.output, outputRows, inputAttrStats), rowCount = Some(outputRows), - attributeStats = outputAttrStats, + attributeStats = inputAttrStats, isBroadcastable = false)) } @@ -202,9 +201,8 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging /** * Propagate or update column stats for output attributes. - * 1. For empty output, we don't need to keep any column stats. - * 2. For cartesian product, all values are preserved, so there's no need to change column stats. - * 3. For other cases, a) update max/min of join keys based on their intersected range. b) update + * 1. For cartesian product, all values are preserved, so there's no need to change column stats. + * 2. For other cases, a) update max/min of join keys based on their intersected range. b) update * distinct count of other attributes based on output rows after join. */ private def updateAttrStats( @@ -218,7 +216,7 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging if (outputRows == leftRows * rightRows) { // Cartesian product, just propagate the original column stats attributes.foreach(a => outputAttrStats += a -> oldAttrStats(a)) - } else if (outputRows != 0) { + } else { val leftRatio = if (leftRows != 0) BigDecimal(outputRows) / BigDecimal(leftRows) else BigDecimal(0) val rightRatio = @@ -228,8 +226,8 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging if (joinKeyStats.contains(a)) { outputAttrStats += a -> joinKeyStats(a) } else { - val oldCS = oldAttrStats(a) - val oldNdv = oldCS.distinctCount + val oldColStat = oldAttrStats(a) + val oldNdv = oldColStat.distinctCount // We only change (scale down) the number of distinct values if the number of rows // decreases after join, because join won't produce new values even if the number of // rows increases. 
@@ -241,7 +239,7 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging oldNdv } // TODO: support nullCount updates for specific outer joins - outputAttrStats += a -> oldCS.copy(distinctCount = newNdv) + outputAttrStats += a -> oldColStat.copy(distinctCount = newNdv) } } } @@ -254,33 +252,26 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging val intersectedStats = new mutable.HashMap[Attribute, ColumnStat]() joinKeyPairs.foreach { case (leftKey, rightKey) => - // Check if the two sides are disjoint val leftKeyStats = leftStats.attributeStats(leftKey) val rightKeyStats = rightStats.attributeStats(rightKey) val lRange = Range(leftKeyStats.min, leftKeyStats.max, leftKey.dataType) val rRange = Range(rightKeyStats.min, rightKeyStats.max, rightKey.dataType) - if (Range.isIntersected(lRange, rRange)) { - // Update intersected column stats - val minNdv = leftKeyStats.distinctCount.min(rightKeyStats.distinctCount) - val (newMin1, newMax1, newMin2, newMax2) = - Range.intersect(lRange, rRange, leftKey.dataType, rightKey.dataType) - intersectedStats.put(leftKey, intersectedColumnStat(leftKeyStats, minNdv, - newMin1, newMax1)) - intersectedStats.put(rightKey, intersectedColumnStat(rightKeyStats, minNdv, - newMin2, newMax2)) - } + // When we reach here, join selectivity is not zero, so each pair of join keys should be + // intersected. + assert(Range.isIntersected(lRange, rRange)) + + // Update intersected column stats + assert(leftKey.dataType.sameType(rightKey.dataType)) + val minNdv = leftKeyStats.distinctCount.min(rightKeyStats.distinctCount) + val (newMin, newMax) = Range.intersect(lRange, rRange, leftKey.dataType) + intersectedStats.put(leftKey, + leftKeyStats.copy(distinctCount = minNdv, min = newMin, max = newMax, nullCount = 0)) + intersectedStats.put(rightKey, + rightKeyStats.copy(distinctCount = minNdv, min = newMin, max = newMax, nullCount = 0)) } AttributeMap(intersectedStats.toSeq) } - private def intersectedColumnStat( - origin: ColumnStat, - newDistinctCount: BigInt, - newMin: Option[Any], - newMax: Option[Any]): ColumnStat = { - origin.copy(distinctCount = newDistinctCount, min = newMin, max = newMax, nullCount = 0) - } - private def extractJoinKeysWithColStats( leftKeys: Seq[Expression], rightKeys: Seq[Expression]): Seq[(AttributeReference, AttributeReference)] = { @@ -302,12 +293,11 @@ case class LeftSemiAntiEstimation(conf: CatalystConf, join: Join) { if (rowCountsExist(conf, join.left)) { val leftStats = join.left.stats(conf) // Propagate the original column stats for cartesian product - val outputAttrStats = getOutputMap(leftStats.attributeStats, join.output) val outputRows = leftStats.rowCount.get Some(Statistics( - sizeInBytes = getOutputSize(join.output, outputRows, outputAttrStats), + sizeInBytes = getOutputSize(join.output, outputRows, leftStats.attributeStats), rowCount = Some(outputRows), - attributeStats = outputAttrStats, + attributeStats = leftStats.attributeStats, isBroadcastable = false)) } else { None diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala index bc39c05cd96e..5aa6b9353bc4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala @@ -59,22 +59,17 @@ object Range { /** * 
Intersected results of two ranges. This is only for two overlapped ranges. - * The outputs are the intersected min/max values of the two columns based on their data types. + * The outputs are the intersected min/max values. */ - def intersect( - r1: Range, - r2: Range, - dt1: DataType, - dt2: DataType): (Option[Any], Option[Any], Option[Any], Option[Any]) = { + def intersect(r1: Range, r2: Range, dt: DataType): (Option[Any], Option[Any]) = { (r1, r2) match { case (_, _: DefaultRange) | (_: DefaultRange, _) => // binary/string types don't support intersecting. - (None, None, None, None) + (None, None) case (n1: NumericRange, n2: NumericRange) => val newRange = NumericRange(n1.min.max(n2.min), n1.max.min(n2.max)) - val (newMin1, newMax1) = fromNumericRange(newRange, dt1) - val (newMin2, newMax2) = fromNumericRange(newRange, dt2) - (Some(newMin1), Some(newMax1), Some(newMin2), Some(newMax2)) + val (newMin, newMax) = fromNumericRange(newRange, dt) + (Some(newMin), Some(newMax)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala index a596c01a5f6c..f62df842fa50 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala @@ -252,83 +252,54 @@ class JoinEstimationSuite extends StatsEstimationTestBase { } test("test join keys of different types") { - val dec1 = new java.math.BigDecimal("1.000000000000000000") - val dec2 = new java.math.BigDecimal("8.000000000000000000") - val d1 = Date.valueOf("2016-05-08") - val d2 = Date.valueOf("2016-05-09") - val t1 = Timestamp.valueOf("2016-05-08 00:00:01") - val t2 = Timestamp.valueOf("2016-05-09 00:00:02") - /** Columns in a table with only one row */ - val columnInfo1 = mutable.LinkedHashMap[Attribute, ColumnStat]( - AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = 1, - min = Some(false), max = Some(false), nullCount = 0, avgLen = 1, maxLen = 1), - AttributeReference("cbyte", ByteType)() -> ColumnStat(distinctCount = 1, - min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 1, maxLen = 1), - AttributeReference("cshort", ShortType)() -> ColumnStat(distinctCount = 1, - min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 2, maxLen = 2), - AttributeReference("cint", IntegerType)() -> ColumnStat(distinctCount = 1, - min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 4, maxLen = 4), - AttributeReference("clong", LongType)() -> ColumnStat(distinctCount = 1, - min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 8, maxLen = 8), - AttributeReference("cdouble", DoubleType)() -> ColumnStat(distinctCount = 1, - min = Some(1.0), max = Some(1.0), nullCount = 0, avgLen = 8, maxLen = 8), - AttributeReference("cfloat", FloatType)() -> ColumnStat(distinctCount = 1, - min = Some(1.0), max = Some(1.0), nullCount = 0, avgLen = 4, maxLen = 4), - AttributeReference("cdecimal", DecimalType.SYSTEM_DEFAULT)() -> ColumnStat(distinctCount = 1, - min = Some(dec1), max = Some(dec1), nullCount = 0, avgLen = 16, maxLen = 16), - AttributeReference("cstring", StringType)() -> ColumnStat(distinctCount = 1, - min = None, max = None, nullCount = 0, avgLen = 3, maxLen = 3), - AttributeReference("cbinary", BinaryType)() -> ColumnStat(distinctCount = 1, - min = None, max = None, nullCount = 0, avgLen = 3, maxLen = 3), - 
AttributeReference("cdate", DateType)() -> ColumnStat(distinctCount = 1, - min = Some(d1), max = Some(d1), nullCount = 0, avgLen = 4, maxLen = 4), - AttributeReference("ctimestamp", TimestampType)() -> ColumnStat(distinctCount = 1, - min = Some(t1), max = Some(t1), nullCount = 0, avgLen = 8, maxLen = 8) - ) - - /** Columns in a table with two rows */ - val columnInfo2 = mutable.LinkedHashMap[Attribute, ColumnStat]( - AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = 2, - min = Some(false), max = Some(true), nullCount = 0, avgLen = 1, maxLen = 1), - AttributeReference("cbyte", ByteType)() -> ColumnStat(distinctCount = 2, - min = Some(1L), max = Some(2L), nullCount = 0, avgLen = 1, maxLen = 1), - AttributeReference("cshort", ShortType)() -> ColumnStat(distinctCount = 2, - min = Some(1L), max = Some(3L), nullCount = 0, avgLen = 2, maxLen = 2), - AttributeReference("cint", IntegerType)() -> ColumnStat(distinctCount = 2, - min = Some(1L), max = Some(4L), nullCount = 0, avgLen = 4, maxLen = 4), - AttributeReference("clong", LongType)() -> ColumnStat(distinctCount = 2, - min = Some(1L), max = Some(5L), nullCount = 0, avgLen = 8, maxLen = 8), - AttributeReference("cdouble", DoubleType)() -> ColumnStat(distinctCount = 2, - min = Some(1.0), max = Some(6.0), nullCount = 0, avgLen = 8, maxLen = 8), - AttributeReference("cfloat", FloatType)() -> ColumnStat(distinctCount = 2, - min = Some(1.0), max = Some(7.0), nullCount = 0, avgLen = 4, maxLen = 4), - AttributeReference("cdecimal", DecimalType.SYSTEM_DEFAULT)() -> ColumnStat(distinctCount = 2, - min = Some(dec1), max = Some(dec2), nullCount = 0, avgLen = 16, maxLen = 16), - AttributeReference("cstring", StringType)() -> ColumnStat(distinctCount = 2, - min = None, max = None, nullCount = 0, avgLen = 3, maxLen = 3), - AttributeReference("cbinary", BinaryType)() -> ColumnStat(distinctCount = 2, - min = None, max = None, nullCount = 0, avgLen = 3, maxLen = 3), - AttributeReference("cdate", DateType)() -> ColumnStat(distinctCount = 2, - min = Some(d1), max = Some(d2), nullCount = 0, avgLen = 4, maxLen = 4), - AttributeReference("ctimestamp", TimestampType)() -> ColumnStat(distinctCount = 2, - min = Some(t1), max = Some(t2), nullCount = 0, avgLen = 8, maxLen = 8) - ) + def genColumnData: mutable.LinkedHashMap[Attribute, ColumnStat] = { + val dec = new java.math.BigDecimal("1.000000000000000000") + val date = Date.valueOf("2016-05-08") + val timestamp = Timestamp.valueOf("2016-05-08 00:00:01") + mutable.LinkedHashMap[Attribute, ColumnStat]( + AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = 1, + min = Some(false), max = Some(false), nullCount = 0, avgLen = 1, maxLen = 1), + AttributeReference("cbyte", ByteType)() -> ColumnStat(distinctCount = 1, + min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 1, maxLen = 1), + AttributeReference("cshort", ShortType)() -> ColumnStat(distinctCount = 1, + min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 2, maxLen = 2), + AttributeReference("cint", IntegerType)() -> ColumnStat(distinctCount = 1, + min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 4, maxLen = 4), + AttributeReference("clong", LongType)() -> ColumnStat(distinctCount = 1, + min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 8, maxLen = 8), + AttributeReference("cdouble", DoubleType)() -> ColumnStat(distinctCount = 1, + min = Some(1.0), max = Some(1.0), nullCount = 0, avgLen = 8, maxLen = 8), + AttributeReference("cfloat", FloatType)() -> ColumnStat(distinctCount = 1, + min = Some(1.0), max = 
Some(1.0), nullCount = 0, avgLen = 4, maxLen = 4), + AttributeReference("cdec", DecimalType.SYSTEM_DEFAULT)() -> ColumnStat(distinctCount = 1, + min = Some(dec), max = Some(dec), nullCount = 0, avgLen = 16, maxLen = 16), + AttributeReference("cstring", StringType)() -> ColumnStat(distinctCount = 1, + min = None, max = None, nullCount = 0, avgLen = 3, maxLen = 3), + AttributeReference("cbinary", BinaryType)() -> ColumnStat(distinctCount = 1, + min = None, max = None, nullCount = 0, avgLen = 3, maxLen = 3), + AttributeReference("cdate", DateType)() -> ColumnStat(distinctCount = 1, + min = Some(date), max = Some(date), nullCount = 0, avgLen = 4, maxLen = 4), + AttributeReference("ctimestamp", TimestampType)() -> ColumnStat(distinctCount = 1, + min = Some(timestamp), max = Some(timestamp), nullCount = 0, avgLen = 8, maxLen = 8) + ) + } - val oneRowTable = StatsTestPlan( + val columnInfo1 = genColumnData + val columnInfo2 = genColumnData + val table1 = StatsTestPlan( outputList = columnInfo1.keys.toSeq, rowCount = 1, attributeStats = AttributeMap(columnInfo1.toSeq)) - val twoRowTable = StatsTestPlan( + val table2 = StatsTestPlan( outputList = columnInfo2.keys.toSeq, - rowCount = 2, + rowCount = 1, attributeStats = AttributeMap(columnInfo2.toSeq)) - val joinKeys = oneRowTable.output.zip(twoRowTable.output) + val joinKeys = table1.output.zip(table2.output) joinKeys.foreach { case (key1, key2) => withClue(s"For data type ${key1.dataType}") { - // All values in oneRowTable is contained in twoRowTable, so column stats after join is - // equal to that of oneRowTable. - val join = Join(Project(Seq(key1), oneRowTable), Project(Seq(key2), twoRowTable), Inner, + // All values in two tables are the same, so column stats after join are also the same. + val join = Join(Project(Seq(key1), table1), Project(Seq(key2), table2), Inner, Some(EqualTo(key1, key2))) val expectedStats = Statistics( sizeInBytes = 1 * (8 + 2 * getColSize(key1, columnInfo1(key1))), From 8182123f09328ddffabcc1d180c0309f550489f8 Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Tue, 14 Feb 2017 18:41:35 -0800 Subject: [PATCH 12/12] fix error --- .../catalyst/plans/logical/statsEstimation/JoinEstimation.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala index 3e7377b5f242..982a5a8bb89b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala @@ -114,6 +114,7 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging val oriColStat = inputAttrStats(a) (a, oriColStat.copy(nullCount = oriColStat.nullCount + leftRows)) } + case _ => Nil } } else { val joinKeyStats = getIntersectedStats(joinKeyPairs)
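
The cardinality model implemented by InnerOuterEstimation above, in simplified standalone form. This is a rough sketch with illustrative names only (ColStat, estimateInnerRows, scaleNdv are not the Catalyst API), and it assumes a single numeric equi-join key; the real code handles multiple key pairs and works on LogicalPlan statistics and AttributeMap[ColumnStat].

object JoinEstimationSketch {
  import scala.math.BigDecimal.RoundingMode

  // Simplified stand-in for ColumnStat: distinct count plus a numeric min/max range.
  case class ColStat(ndv: BigInt, min: Option[Double], max: Option[Double])

  // Inner equi-join cardinality: T(A JOIN B ON A.k = B.k) = T(A) * T(B) / max(V(A.k), V(B.k)),
  // under the containment assumption that the side with fewer distinct values is a
  // subset of the other. Disjoint key ranges short-circuit to an empty result.
  def estimateInnerRows(
      leftRows: BigInt,
      rightRows: BigInt,
      leftKey: ColStat,
      rightKey: ColStat): BigInt = {
    val disjoint = for {
      lmin <- leftKey.min; lmax <- leftKey.max
      rmin <- rightKey.min; rmax <- rightKey.max
    } yield lmax < rmin || rmax < lmin
    if (disjoint.contains(true)) {
      BigInt(0) // no overlap between key ranges: the inner join part is empty
    } else {
      val maxNdv = leftKey.ndv.max(rightKey.ndv)
      if (maxNdv == 0) BigInt(0) else leftRows * rightRows / maxNdv
    }
  }

  // NDV scale-down for non-key columns: only shrink the distinct count when the join
  // reduces the row count, because a join never produces values that were not already
  // present, even when it multiplies rows.
  def scaleNdv(oldNdv: BigInt, inputRows: BigInt, outputRows: BigInt): BigInt =
    if (inputRows == 0 || outputRows >= inputRows) oldNdv
    else ceil(BigDecimal(oldNdv) * BigDecimal(outputRows) / BigDecimal(inputRows))

  private def ceil(d: BigDecimal): BigInt =
    d.setScale(0, RoundingMode.CEILING).toBigInt()
}

With the disjoint test data added in PATCH 10 (key-5-9 in [5, 9] versus key-2-4 in [2, 4]), the key ranges do not intersect, so the inner part is estimated at 0 rows; the outer-join branches then return the outer side's row count and attach all-null column stats (distinctCount = 0, nullCount = outer row count, as nullColumnStat builds them) for the inner side's attributes, which is exactly what the three disjoint outer join tests assert.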