Skip to content
This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

[NSE-955] implement concat_ws #963

Merged
merged 4 commits into from
Jun 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,65 @@ import org.apache.spark.sql.types._

import scala.collection.mutable.ListBuffer

class ColumnarConcatWs(exps: Seq[Expression], original: Expression)
extends ConcatWs(exps: Seq[Expression])
with ColumnarExpression
with Logging {

buildCheck()

def buildCheck(): Unit = {
exps.foreach(expr =>
if (expr.dataType != StringType) {
throw new UnsupportedOperationException(
s"${expr.dataType} is not supported in ColumnarConcatWS")
})
}

override def supportColumnarCodegen(args: java.lang.Object): Boolean = {
false
}

override def doColumnarCodeGen(args: java.lang.Object): (TreeNode, ArrowType) = {
val iter: Iterator[Expression] = exps.iterator
val exp = iter.next() // spliter
val exp1 = iter.next()
val iterFaster: Iterator[Expression] = exps.iterator
iterFaster.next()
iterFaster.next()
iterFaster.next()

val (split_node, expType): (TreeNode, ArrowType) =
exp.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
val (exp1_node, exp1Type): (TreeNode, ArrowType) =
exp1.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)

val resultType = new ArrowType.Utf8()
val funcNode = TreeBuilder.makeFunction("concat",
Lists.newArrayList(exp1_node, split_node, rightNode(args, exps, split_node, iter, iterFaster)), resultType)
(funcNode, expType)
}

def rightNode(args: java.lang.Object, exps: Seq[Expression], split_node: TreeNode,
iter: Iterator[Expression], iterFaster: Iterator[Expression]): TreeNode = {
if (!iterFaster.hasNext) {
// When iter reaches the last but one expression
val (exp_node, expType): (TreeNode, ArrowType) =
exps.last.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
exp_node
} else {
val exp = iter.next()
iterFaster.next()
val (exp_node, expType): (TreeNode, ArrowType) =
exp.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
val resultType = new ArrowType.Utf8()
val funcNode = TreeBuilder.makeFunction("concat",
Lists.newArrayList(exp_node, split_node, rightNode(args, exps, split_node, iter, iterFaster)), resultType)
funcNode
}
}
}

class ColumnarConcat(exps: Seq[Expression], original: Expression)
extends Concat(exps: Seq[Expression])
with ColumnarExpression
Expand All @@ -44,6 +103,10 @@ class ColumnarConcat(exps: Seq[Expression], original: Expression)
})
}

override def supportColumnarCodegen(args: java.lang.Object): Boolean = {
false
}

override def doColumnarCodeGen(args: java.lang.Object): (TreeNode, ArrowType) = {
val iter: Iterator[Expression] = exps.iterator
val exp = iter.next()
Expand Down Expand Up @@ -85,6 +148,8 @@ object ColumnarConcatOperator {
def create(exps: Seq[Expression], original: Expression): Expression = original match {
case c: Concat =>
new ColumnarConcat(exps, original)
case cws: ConcatWs =>
new ColumnarConcatWs(exps, original)
case other =>
throw new UnsupportedOperationException(s"not currently supported: $other.")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,16 @@ object ColumnarExpressionConverter extends Logging {
convertBoundRefToAttrRef = convertBoundRefToAttrRef)
}
ColumnarConcatOperator.create(exps, expr)
case cws: ConcatWs =>
check_if_no_calculation = false
logInfo(s"${expr.getClass} ${expr} is supported, no_cal is $check_if_no_calculation.")
val exps = cws.children.map { expr =>
replaceWithColumnarExpression(
expr,
attributeSeq,
convertBoundRefToAttrRef = convertBoundRefToAttrRef)
}
ColumnarConcatOperator.create(exps, expr)
case r: Round =>
check_if_no_calculation = false
logInfo(s"${expr.getClass} ${expr} is supported, no_cal is $check_if_no_calculation.")
Expand Down