Skip to content

Commit

Permalink
Add scala udf support and a unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
PHILO-HE committed May 23, 2022
1 parent fa7a444 commit 72a5fe8
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -399,14 +399,24 @@ object ColumnarExpressionConverter extends Logging {
attributeSeq,
convertBoundRefToAttrRef = convertBoundRefToAttrRef),
expr)
case expr if (UDF.isSupportedUDF(expr.prettyName)) =>
// Scala UDF.
case expr: ScalaUDF if ColumnarUDF.isSupportedUDF(expr.udfName.get) =>
val children = expr.children.map { expr =>
replaceWithColumnarExpression(
expr,
attributeSeq,
convertBoundRefToAttrRef = convertBoundRefToAttrRef)
}
UDF.create(children, expr)
ColumnarUDF.create(children, expr)
// Hive UDF.
case expr if (ColumnarUDF.isSupportedUDF(expr.prettyName)) =>
val children = expr.children.map { expr =>
replaceWithColumnarExpression(
expr,
attributeSeq,
convertBoundRefToAttrRef = convertBoundRefToAttrRef)
}
ColumnarUDF.create(children, expr)
case expr =>
throw new UnsupportedOperationException(
s" --> ${expr.getClass} | ${expr} is not currently supported.")
Expand Down Expand Up @@ -468,7 +478,9 @@ object ColumnarExpressionConverter extends Logging {
containsSubquery(sr.srcExpr) ||
containsSubquery(sr.searchExpr) ||
containsSubquery(sr.replaceExpr)
case expr if (UDF.isSupportedUDF(expr.prettyName)) =>
case expr: ScalaUDF if ColumnarUDF.isSupportedUDF(expr.udfName.get) =>
expr.children.map(containsSubquery).exists(_ == true)
case expr if (ColumnarUDF.isSupportedUDF(expr.prettyName)) =>
expr.children.map(containsSubquery).exists(_ == true)
case expr =>
throw new UnsupportedOperationException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,20 +83,32 @@ case class ColumnarURLDecoder(input: Expression) extends Expression with Columna
}
}

object UDF {
object ColumnarUDF {
// Names of the UDFs that have a columnar implementation. Each name is the one
// given when registering the row based function in spark, e.g.,
// CREATE TEMPORARY FUNCTION UrlDecoder AS 'com.intel.test.URLDecoderNew';
//
// This must be a real collection: writing {"UrlDecoder"} is a block expression
// that evaluates to a plain String, which turns the membership test below into
// a substring match (so "" or "Url" would be reported as supported).
val supportList: Set[String] = Set("UrlDecoder")

/**
 * Tells whether a UDF registered under `name` has a columnar replacement.
 *
 * @param name the registered UDF name; may be null
 * @return true only when `name` exactly equals an entry of `supportList`;
 *         false for null or any other name
 */
def isSupportedUDF(name: String): Boolean =
  name != null && supportList.contains(name)

def create(children: Seq[Expression], original: Expression): Expression = {
original.prettyName match {
// Hive UDF.
case "UrlDecoder" =>
ColumnarURLDecoder(children.head)
// Scala UDF.
case "scalaudf" =>
original.asInstanceOf[ScalaUDF].udfName.get match {
case "UrlDecoder" =>
ColumnarURLDecoder(children.head)
case other =>
throw new UnsupportedOperationException(s"not currently supported: $other.")
}
case other =>
throw new UnsupportedOperationException(s"not currently supported: $other.")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3629,6 +3629,16 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
df.select(map(map_entries($"m"), lit(1))),
Row(Map(Seq(Row(1, "a")) -> 1)))
}

test("Columnar UDF") {
  // Register a row-based Scala UDF under a supported name. Its identity body
  // is not expected to actually run: it should be replaced by the columnar UDF
  // at runtime, which is why the expected value below is the URL-decoded text
  // rather than the raw input.
  spark.udf.register("UrlDecoder", (input: String) => input)
  val expectedRows = Seq(Row("AaBb#", null))
  checkAnswer(sql("select UrlDecoder('AaBb%23'), UrlDecoder(null)"), expectedRows)
}
}

object DataFrameFunctionsSuite {
Expand Down

0 comments on commit 72a5fe8

Please sign in to comment.