From 899fa4066a22dfc12cbbadfc165351a2573e2db1 Mon Sep 17 00:00:00 2001
From: jiaweihu
Date: Wed, 2 Feb 2022 11:00:39 +0100
Subject: [PATCH] Replace MaxWeightEstimation with Spark SQL code, also removing related tests

---
 .../spark/index/MaxWeightEstimation.scala     | 63 -------------------
 .../spark/index/OTreeDataAnalyzer.scala       |  6 +-
 .../spark/index/MaxWeightEstimationTest.scala | 47 --------------
 3 files changed, 2 insertions(+), 114 deletions(-)
 delete mode 100644 src/main/scala/io/qbeast/spark/index/MaxWeightEstimation.scala
 delete mode 100644 src/test/scala/io/qbeast/spark/index/MaxWeightEstimationTest.scala

diff --git a/src/main/scala/io/qbeast/spark/index/MaxWeightEstimation.scala b/src/main/scala/io/qbeast/spark/index/MaxWeightEstimation.scala
deleted file mode 100644
index af76409c4..000000000
--- a/src/main/scala/io/qbeast/spark/index/MaxWeightEstimation.scala
+++ /dev/null
@@ -1,63 +0,0 @@
-/*
- * Copyright 2021 Qbeast Analytics, S.L.
- */
-package io.qbeast.spark.index
-
-import io.qbeast.core.model.Weight
-import org.apache.spark.sql.{Encoder, Encoders}
-import org.apache.spark.sql.expressions.Aggregator
-
-/**
- * Aggregation object that estimates MaxWeight on a DataFrame
- */
-object MaxWeightEstimation
-    extends Aggregator[NormalizedWeight, NormalizedWeight, NormalizedWeight] {
-
-  import io.qbeast.core.model.NormalizedWeight
-
-  /**
-   * Zero value for this aggregation
-   * @return Normalized value of minimum Weight
-   */
-  override def zero: NormalizedWeight = Weight.MinValue.fraction
-
-  /**
-   * Combine two values to produce a new value.
-   * For performance, the function may modify `buffer`
-   * and return it instead of constructing a new object
-   * @param buffer intermediate value for reduction
-   * @param weight input from aggregation
-   * @return merge of normalized weights buffer and weight
-   */
-  override def reduce(buffer: NormalizedWeight, weight: NormalizedWeight): Double = {
-    NormalizedWeight.merge(buffer, weight)
-  }
-
-  /**
-   * Merges two intermediate values
-   * @param w1 intermediate result
-   * @param w2 intermediate result
-   * @return merge of intermediate result w1 and w2
-   */
-  override def merge(w1: NormalizedWeight, w2: NormalizedWeight): Double = {
-    NormalizedWeight.merge(w1, w2)
-  }
-
-  /**
-   * Transforms the output of the reduction
-   * @param reduction final buffer
-   * @return the final output result
-   */
-  override def finish(reduction: NormalizedWeight): NormalizedWeight = reduction
-
-  /**
-   * Specifies the Encoder for the intermediate value type
-   */
-  override def bufferEncoder: Encoder[NormalizedWeight] = Encoders.scalaDouble
-
-  /**
-   * Specifies the Encoder for the final output value type
-   * @return the Encoder for the output
-   */
-  override def outputEncoder: Encoder[NormalizedWeight] = Encoders.scalaDouble
-}
diff --git a/src/main/scala/io/qbeast/spark/index/OTreeDataAnalyzer.scala b/src/main/scala/io/qbeast/spark/index/OTreeDataAnalyzer.scala
index a220ad15d..6ccae6d8b 100644
--- a/src/main/scala/io/qbeast/spark/index/OTreeDataAnalyzer.scala
+++ b/src/main/scala/io/qbeast/spark/index/OTreeDataAnalyzer.scala
@@ -9,8 +9,7 @@ import io.qbeast.core.transform.{ColumnStats, Transformer}
 import io.qbeast.spark.index.QbeastColumns.{cubeToReplicateColumnName, weightColumnName}
 import io.qbeast.spark.internal.QbeastFunctions.qbeastHash
 import org.apache.spark.qbeast.config.CUBE_WEIGHTS_BUFFER_CAPACITY
-import org.apache.spark.sql.expressions.UserDefinedFunction
-import org.apache.spark.sql.functions.{col, udaf}
+import org.apache.spark.sql.functions.{col, lit, sum}
 import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
 
 /**
@@ -37,7 +36,6 @@ object DoublePassOTreeDataAnalyzer extends OTreeDataAnalyzer with Serializable {
   /**
    * Estimates MaxWeight on DataFrame
    */
-  private val maxWeightEstimation: UserDefinedFunction = udaf(MaxWeightEstimation)
 
   private[index] def calculateRevisionChanges(
       columnStats: Seq[ColumnStats],
@@ -89,7 +87,7 @@
       // These column names are the ones specified in case class CubeNormalizedWeight
       partitionedEstimatedCubeWeights
         .groupBy("cubeBytes")
-        .agg(maxWeightEstimation(col("normalizedWeight")))
+        .agg(lit(1) / sum(lit(1.0) / col("normalizedWeight")))
         .map { row =>
           val bytes = row.getAs[Array[Byte]](0)
           val estimatedWeight = row.getAs[Double](1)
diff --git a/src/test/scala/io/qbeast/spark/index/MaxWeightEstimationTest.scala b/src/test/scala/io/qbeast/spark/index/MaxWeightEstimationTest.scala
deleted file mode 100644
index b5b212b09..000000000
--- a/src/test/scala/io/qbeast/spark/index/MaxWeightEstimationTest.scala
+++ /dev/null
@@ -1,47 +0,0 @@
-package io.qbeast.spark.index
-
-import io.qbeast.core.model.Weight
-import io.qbeast.spark.QbeastIntegrationTestSpec
-import org.apache.spark.sql.functions.{col, udaf}
-import org.scalatest.flatspec.AnyFlatSpec
-import org.scalatest.matchers.should.Matchers
-
-import scala.util.Random
-
-class MaxWeightEstimationTest extends AnyFlatSpec with Matchers with QbeastIntegrationTestSpec {
-
-  import io.qbeast.core.model.NormalizedWeight
-
-  "MaxWeight merge" should "compute fraction correctly" in {
-    val weightA = NormalizedWeight(Weight.MaxValue)
-    val weightB = NormalizedWeight(Weight.MaxValue)
-    MaxWeightEstimation.merge(weightA, weightB) shouldBe 0.5
-  }
-
-  "MaxWeight reduce" should "compute fraction correctly" in {
-    val weightA = NormalizedWeight(Weight.MaxValue)
-    val weightB = NormalizedWeight(Weight.MaxValue)
-    MaxWeightEstimation.reduce(weightA, weightB) shouldBe 0.5
-  }
-
-  "MaxWeight finish" should "return same fraction" in {
-    val finalWeight = NormalizedWeight(Weight(Random.nextInt()))
-    MaxWeightEstimation.finish(finalWeight) shouldBe finalWeight
-  }
-
-  "MaxWeight zero" should "be minimum positive value" in {
-    MaxWeightEstimation.zero shouldBe 0.0
-  }
-
-  "MaxWeight" should "merge weights correctly on DataFrame" in withSpark { spark =>
-    import spark.implicits._
-    val weightA = 1.0
-    val weightB = 1.0
-    val weightC = 0.5
-
-    val df = Seq(weightA, weightB, weightC).toDF("weight")
-    val maxWeight = udaf(MaxWeightEstimation)
-    df.agg(maxWeight(col("weight"))).first().getDouble(0) shouldBe 0.25
-  }
-
-}
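
Note (editor's sketch, not part of the patch): the new aggregation relies on the fact that
folding NormalizedWeight.merge over a group, where merge(w1, w2) = 1 / (1/w1 + 1/w2),
collapses to 1 / (1/w1 + ... + 1/wn), which is exactly what
lit(1) / sum(lit(1.0) / col("normalizedWeight")) computes. The following minimal, self-contained
Scala sketch illustrates that equivalence; the object name and local values are illustrative
only, reusing the 1.0, 1.0, 0.5 inputs and the 0.25 expectation from the deleted DataFrame test.

// Illustrative sketch only; assumes a local Spark installation, not part of the Qbeast codebase.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, lit, sum}

object MaxWeightSqlSketch extends App {
  val spark = SparkSession.builder().master("local[*]").appName("sketch").getOrCreate()
  import spark.implicits._

  val weights = Seq(1.0, 1.0, 0.5)

  // Reference value, merged pairwise the way MaxWeightEstimation did:
  // merge(w1, w2) = 1 / (1/w1 + 1/w2)  =>  1 / (1/1.0 + 1/1.0 + 1/0.5) = 0.25
  val expected = 1.0 / weights.map(1.0 / _).sum

  // Same value computed with built-in Spark SQL functions, mirroring the patched
  // OTreeDataAnalyzer aggregation.
  val actual = weights
    .toDF("normalizedWeight")
    .agg(lit(1) / sum(lit(1.0) / col("normalizedWeight")))
    .first()
    .getDouble(0)

  assert(math.abs(actual - expected) < 1e-12, s"$actual != $expected")
  spark.stop()
}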