Skip to content
This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

Commit

Permalink
[NSE-955] implement concat_ws (#963)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhouyuan authored Jun 14, 2022
1 parent 737bba4 commit e9dfc2d
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,65 @@ import org.apache.spark.sql.types._

import scala.collection.mutable.ListBuffer

class ColumnarConcatWs(exps: Seq[Expression], original: Expression)
extends ConcatWs(exps: Seq[Expression])
with ColumnarExpression
with Logging {

buildCheck()

def buildCheck(): Unit = {
exps.foreach(expr =>
if (expr.dataType != StringType) {
throw new UnsupportedOperationException(
s"${expr.dataType} is not supported in ColumnarConcatWS")
})
}

override def supportColumnarCodegen(args: java.lang.Object): Boolean = {
false
}

override def doColumnarCodeGen(args: java.lang.Object): (TreeNode, ArrowType) = {
val iter: Iterator[Expression] = exps.iterator
val exp = iter.next() // spliter
val exp1 = iter.next()
val iterFaster: Iterator[Expression] = exps.iterator
iterFaster.next()
iterFaster.next()
iterFaster.next()

val (split_node, expType): (TreeNode, ArrowType) =
exp.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
val (exp1_node, exp1Type): (TreeNode, ArrowType) =
exp1.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)

val resultType = new ArrowType.Utf8()
val funcNode = TreeBuilder.makeFunction("concat",
Lists.newArrayList(exp1_node, split_node, rightNode(args, exps, split_node, iter, iterFaster)), resultType)
(funcNode, expType)
}

def rightNode(args: java.lang.Object, exps: Seq[Expression], split_node: TreeNode,
iter: Iterator[Expression], iterFaster: Iterator[Expression]): TreeNode = {
if (!iterFaster.hasNext) {
// When iter reaches the last but one expression
val (exp_node, expType): (TreeNode, ArrowType) =
exps.last.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
exp_node
} else {
val exp = iter.next()
iterFaster.next()
val (exp_node, expType): (TreeNode, ArrowType) =
exp.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
val resultType = new ArrowType.Utf8()
val funcNode = TreeBuilder.makeFunction("concat",
Lists.newArrayList(exp_node, split_node, rightNode(args, exps, split_node, iter, iterFaster)), resultType)
funcNode
}
}
}

class ColumnarConcat(exps: Seq[Expression], original: Expression)
extends Concat(exps: Seq[Expression])
with ColumnarExpression
Expand All @@ -44,6 +103,10 @@ class ColumnarConcat(exps: Seq[Expression], original: Expression)
})
}

override def supportColumnarCodegen(args: java.lang.Object): Boolean = {
false
}

override def doColumnarCodeGen(args: java.lang.Object): (TreeNode, ArrowType) = {
val iter: Iterator[Expression] = exps.iterator
val exp = iter.next()
Expand Down Expand Up @@ -85,6 +148,8 @@ object ColumnarConcatOperator {
def create(exps: Seq[Expression], original: Expression): Expression = original match {
case c: Concat =>
new ColumnarConcat(exps, original)
case cws: ConcatWs =>
new ColumnarConcatWs(exps, original)
case other =>
throw new UnsupportedOperationException(s"not currently supported: $other.")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,16 @@ object ColumnarExpressionConverter extends Logging {
convertBoundRefToAttrRef = convertBoundRefToAttrRef)
}
ColumnarConcatOperator.create(exps, expr)
case cws: ConcatWs =>
check_if_no_calculation = false
logInfo(s"${expr.getClass} ${expr} is supported, no_cal is $check_if_no_calculation.")
val exps = cws.children.map { expr =>
replaceWithColumnarExpression(
expr,
attributeSeq,
convertBoundRefToAttrRef = convertBoundRefToAttrRef)
}
ColumnarConcatOperator.create(exps, expr)
case r: Round =>
check_if_no_calculation = false
logInfo(s"${expr.getClass} ${expr} is supported, no_cal is $check_if_no_calculation.")
Expand Down

0 comments on commit e9dfc2d

Please sign in to comment.