diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarExpressionConverter.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarExpressionConverter.scala index e4b49eb51..9206be146 100644 --- a/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarExpressionConverter.scala +++ b/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarExpressionConverter.scala @@ -399,6 +399,29 @@ object ColumnarExpressionConverter extends Logging { attributeSeq, convertBoundRefToAttrRef = convertBoundRefToAttrRef), expr) + // Scala UDF. + case expr: ScalaUDF if (expr.udfName match { + case Some(name) => + ColumnarUDF.isSupportedUDF(name) + case None => + false + }) => + val children = expr.children.map { expr => + replaceWithColumnarExpression( + expr, + attributeSeq, + convertBoundRefToAttrRef = convertBoundRefToAttrRef) + } + ColumnarUDF.create(children, expr) + // Hive UDF. + case expr if (ColumnarUDF.isSupportedUDF(expr.prettyName)) => + val children = expr.children.map { expr => + replaceWithColumnarExpression( + expr, + attributeSeq, + convertBoundRefToAttrRef = convertBoundRefToAttrRef) + } + ColumnarUDF.create(children, expr) case expr => throw new UnsupportedOperationException( s" --> ${expr.getClass} | ${expr} is not currently supported.") @@ -460,6 +483,15 @@ object ColumnarExpressionConverter extends Logging { containsSubquery(sr.srcExpr) || containsSubquery(sr.searchExpr) || containsSubquery(sr.replaceExpr) + case expr: ScalaUDF if (expr.udfName match { + case Some(name) => + ColumnarUDF.isSupportedUDF(name) + case None => + false + }) => + expr.children.map(containsSubquery).exists(_ == true) + case expr if (ColumnarUDF.isSupportedUDF(expr.prettyName)) => + expr.children.map(containsSubquery).exists(_ == true) case expr => throw new UnsupportedOperationException( s" --> ${expr.getClass} | ${expr} is not currently supported.") diff --git 
a/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarUDF.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarUDF.scala
new file mode 100644
index 000000000..880e425e8
--- /dev/null
+++ b/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarUDF.scala
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.intel.oap.expression
+
+import java.util.Locale
+
+import com.google.common.collect.Lists
+import org.apache.arrow.gandiva.expression.{TreeBuilder, TreeNode}
+import org.apache.arrow.vector.types.pojo.ArrowType
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types.StringType
+
+
+/**
+ * Columnar counterpart of the row-based "urldecoder" UDF.
+ *
+ * Takes one string child and produces a string. Row-based evaluation and Spark code
+ * generation are not available on this node: both throw UnsupportedOperationException.
+ *
+ * @param input the single child expression; construction fails unless it is StringType
+ */
+case class ColumnarURLDecoder(input: Expression) extends Expression with ColumnarExpression {
+  def nullable: Boolean = {
+    true
+  }
+
+  def children: Seq[Expression] = {
+    Seq(input)
+  }
+
+  def dataType: DataType = {
+    StringType
+  }
+
+  def eval(input: InternalRow): Any = {
+    throw new UnsupportedOperationException("Should not trigger eval!")
+  }
+
+  def child: Expression = {
+    input
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    throw new UnsupportedOperationException("Should not trigger code gen!")
+  }
+
+  protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): ColumnarURLDecoder = {
+    copy(input = newChildren.head)
+  }
+
+  // Validate the input type eagerly, at construction time.
+  buildCheck
+
+  /** Throws UnsupportedOperationException unless the input is of StringType. */
+  def buildCheck: Unit = {
+    val supportedTypes = List(StringType)
+    if (!supportedTypes.contains(input.dataType)) {
+      throw new UnsupportedOperationException("Only StringType input is supported!")
+    }
+  }
+
+  override def supportColumnarCodegen(args: java.lang.Object): Boolean = {
+    false
+  }
+
+  /**
+   * Generates the input's Gandiva node and wraps it in a "url_decoder" function node
+   * whose result type is Utf8.
+   */
+  override def doColumnarCodeGen(args: Object): (TreeNode, ArrowType) = {
+    val (inputNode, _): (TreeNode, ArrowType) =
+      input.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
+    val resultType = new ArrowType.Utf8()
+    val funcNode =
+      TreeBuilder.makeFunction(
+        "url_decoder",
+        Lists.newArrayList(inputNode),
+        resultType)
+    (funcNode, resultType)
+  }
+}
+
+object ColumnarUDF {
+  // Keep the supported UDF name. The name is specified in registering the
+  // row based function in spark, e.g.,
+  // CREATE TEMPORARY FUNCTION UrlDecoder AS 'com.intel.test.URLDecoderNew';
+  val supportList = List("urldecoder")
+
+  /**
+   * Returns true when `name` equals an entry of `supportList`, ignoring case.
+   * A null name is reported as unsupported.
+   */
+  def isSupportedUDF(name: String): Boolean = {
+    name != null && supportList.exists(_.equalsIgnoreCase(name))
+  }
+
+  /**
+   * Creates the columnar expression that replaces `original`.
+   *
+   * @param children the children to attach to the columnar expression
+   * @param original the expression being replaced; a ScalaUDF is looked up by its
+   *                 udfName, any other expression by its prettyName
+   * @throws UnsupportedOperationException if the UDF is unnamed, is not supported, or
+   *                                       has no child
+   */
+  def create(children: Seq[Expression], original: Expression): Expression = {
+    val udfName = original match {
+      // Scala UDF.
+      case udf: ScalaUDF =>
+        udf.udfName.getOrElse(
+          throw new UnsupportedOperationException(s"not currently supported: $udf."))
+      // Hive UDF.
+      case other =>
+        other.prettyName
+    }
+    udfName.toLowerCase(Locale.ROOT) match {
+      case "urldecoder" =>
+        val child = children.headOption.getOrElse(
+          throw new UnsupportedOperationException(s"$udfName requires an argument."))
+        ColumnarURLDecoder(child)
+      case other =>
+        throw new UnsupportedOperationException(s"not currently supported: $other.")
+    }
+  }
+}
diff --git a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 356b68798..6ef4943f5 100644
--- a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -603,6 +603,16 @@ class DataFrameSuite extends QueryTest
     )
   }
 
+  test("Columnar UDF") {
+    // Register a scala UDF. The scala UDF code will not be actually used. It
+    // will be replaced by columnar UDF at runtime.
+    spark.udf.register("UrlDecoder", (s : String) => s)
+    checkAnswer(
+      sql("select UrlDecoder('AaBb%23'), UrlDecoder(null)"),
+      Seq(Row("AaBb#", null))
+    )
+  }
+
   test("callUDF without Hive Support") {
     val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value")
     df.sparkSession.udf.register("simpleUDF", (v: Int) => v * v)