Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,10 @@ import org.scalatest.Matchers._
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction
import org.apache.spark.sql.catalyst.expressions.{ExpressionEvalHelper, ExpressionInfo, Literal}
import org.apache.spark.sql.catalyst.expressions.{ExpressionEvalHelper, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.hive.HiveSessionCatalog
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SQLTestUtils
Expand All @@ -43,6 +42,14 @@ class ObjectHashAggregateSuite

import testImplicits._

protected override def beforeAll(): Unit = {
  // Run the stacked setup (TestHiveSingleton / SQLTestUtils) first; skipping
  // `super.beforeAll()` would leave the shared Hive test session uninitialized.
  super.beforeAll()
  // Register a Hive UDAF without partial-aggregation support so the tests can
  // exercise ObjectHashAggregateExec's fallback path via plain SQL invocation.
  sql(s"CREATE TEMPORARY FUNCTION hive_max AS '${classOf[GenericUDAFMax].getName}'")
}

protected override def afterAll(): Unit = {
  // Drop the temporary function registered in `beforeAll`, but always let the
  // stacked teardown (`super.afterAll()`) run even if the DROP itself fails.
  try {
    // No interpolation needed here, so the `s` interpolator is dropped.
    sql("DROP TEMPORARY FUNCTION IF EXISTS hive_max")
  } finally {
    super.afterAll()
  }
}

test("typed_count without grouping keys") {
val df = Seq((1: Integer, 2), (null, 2), (3: Integer, 4)).toDF("a", "b")

Expand Down Expand Up @@ -199,10 +206,7 @@ class ObjectHashAggregateSuite
val typed = percentile_approx($"c0", 0.5)

// A Hive UDAF without partial aggregation support
val withoutPartial = {
registerHiveFunction("hive_max", classOf[GenericUDAFMax])
function("hive_max", $"c1")
}
val withoutPartial = function("hive_max", $"c1")

// A Spark SQL native aggregate function with partial aggregation support that can be executed
// by the Tungsten `HashAggregateExec`
Expand Down Expand Up @@ -420,13 +424,6 @@ class ObjectHashAggregateSuite
}
}

/**
 * Registers `clazz` as a temporary function named `functionName` by going
 * directly through the Hive-aware session catalog rather than issuing a
 * `CREATE TEMPORARY FUNCTION` SQL statement.
 *
 * NOTE(review): the cast assumes the active session catalog is a
 * `HiveSessionCatalog`; this only holds in a Hive-enabled test session — it
 * would throw `ClassCastException` otherwise.
 */
private def registerHiveFunction(functionName: String, clazz: Class[_]): Unit = {
val sessionCatalog = spark.sessionState.catalog.asInstanceOf[HiveSessionCatalog]
// Build the function from the class name, then install it under the given
// temp-function name; `ignoreIfExists = false` makes duplicate registration fail fast.
val builder = sessionCatalog.makeFunctionBuilder(functionName, clazz.getName)
val info = new ExpressionInfo(clazz.getName, functionName)
sessionCatalog.createTempFunction(functionName, info, builder, ignoreIfExists = false)
}

/**
 * Builds a [[Column]] that invokes the registered function `name` on the
 * given argument columns (non-distinct), resolved later by the analyzer.
 */
private def function(name: String, args: Column*): Column = {
  val argExprs = args.map(_.expr)
  val call = UnresolvedFunction(FunctionIdentifier(name), argExprs, isDistinct = false)
  Column(call)
}
Expand Down