This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

[NSE-38]ColumnarSMJ: support expression as join keys #39

Merged · 1 commit · Jan 18, 2021
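This patch removes the "join key with expression is not supported." restriction in ColumnarSortMergeJoin. When a join key contains a function node rather than a bare column, its ordinal is recorded, the key expression is pre-evaluated with a Gandiva ColumnarProjection, the resulting key columns are appended to the build- and stream-side record batches, and the build/stream input Arrow schemas are extended with the projected key fields, so the native sort-merge join only ever sees materialized key columns. The call to ColumnarSortMergeJoin.precheck in ColumnarSortMergeJoinExec is commented out so the C++ codegen pre-validation is skipped.

A hypothetical query this keeps on the columnar path (table and column names are illustrative, not from the patch):

    import org.apache.spark.sql.SparkSession

    object ExpressionKeyJoinExample extends App {
      val spark = SparkSession.builder.appName("smj-expr-key").master("local[1]").getOrCreate()
      import spark.implicits._
      Seq((1, "x")).toDF("a", "p").createOrReplaceTempView("l")
      Seq((2, "y")).toDF("b", "q").createOrReplaceTempView("r")
      // Before this patch, the expression key `l.a + 1` made ColumnarSortMergeJoin
      // throw "join key with expression is not supported."
      val joined = spark.sql("SELECT * FROM l JOIN r ON l.a + 1 = r.b")
      joined.explain()
    }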
@@ -315,7 +315,7 @@ case class ColumnarSortMergeJoinExec(
     //do not call prebuild so we could skip the c++ codegen
     //val triggerBuildSignature = getCodeGenSignature

-    try {
+    /*try {
       ColumnarSortMergeJoin.precheck(
         leftKeys,
         rightKeys,
@@ -332,7 +332,7 @@ case class ColumnarSortMergeJoinExec(
     } catch {
       case e: Throwable =>
         throw e
-    }
+    }*/

     /***********************************************************/
     def getCodeGenSignature: String =
core/src/main/scala/com/intel/oap/expression/ColumnarSortMergeJoin.scala (120 additions, 16 deletions)
@@ -73,7 +73,11 @@ class ColumnarSortMergeJoin(
     prepareTime: SQLMetric,
     totaltime_sortmergejoin: SQLMetric,
     totalOutputNumRows: SQLMetric,
-    sparkConf: SparkConf)
+    sparkConf: SparkConf,
+    buildProjector: ColumnarProjection,
+    buildKeyProjectOrdinalList: List[Int],
+    streamProjector: ColumnarProjection,
+    streamKeyProjectOrdinalList: List[Int])
     extends Logging {
   ColumnarPluginConfig.getConf(sparkConf)
   var probe_iterator: BatchIterator = _
@@ -105,13 +109,28 @@ class ColumnarSortMergeJoin(
       }
       build_cb = realbuildIter.next()
       val beforeBuild = System.nanoTime()
-      val build_rb = ConverterUtils.createArrowRecordBatch(build_cb)
+      // handle projection
+      val projectedBuildKeyCols: List[ArrowWritableColumnVector] = if (buildProjector != null) {
+        val builderOrdinalList = buildProjector.getOrdinalList
+        val builderAttributes = buildProjector.output
+        val builderProjectCols = builderOrdinalList.map(i => {
+          build_cb.column(i).asInstanceOf[ArrowWritableColumnVector]
+        })
+        buildProjector.evaluate(build_cb.numRows, builderProjectCols.map(_.getValueVector()))
+      } else {
+        List[ArrowWritableColumnVector]()
+      }
+      val buildCols = (0 until build_cb.numCols).toList.map(i =>
+        build_cb.column(i).asInstanceOf[ArrowWritableColumnVector]) ::: projectedBuildKeyCols
+      val build_rb =
+        ConverterUtils.createArrowRecordBatch(build_cb.numRows, buildCols.map(_.getValueVector))
       (0 until build_cb.numCols).toList.foreach(i =>
         build_cb.column(i).asInstanceOf[ArrowWritableColumnVector].retain())
       inputBatchHolder += build_cb
       prober.evaluate(build_rb)
       prepareTime += NANOSECONDS.toMillis(System.nanoTime() - beforeBuild)
       ConverterUtils.releaseArrowRecordBatch(build_rb)
+      projectedBuildKeyCols.foreach(v => v.close)
     }
     if (build_cb != null) {
       build_cb = null
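A minimal sketch (plain Scala, no Arrow types; column names are invented) of the batch layout this hunk builds: the projected key columns are appended after the original columns with :::, so a projected key k lands at ordinal numCols + k, which is exactly where the extended build input schema created in init() expects it.

    object BuildBatchLayoutSketch extends App {
      val originalCols = List("col0", "col1", "col2")   // the columns of build_cb
      val projectedKeyCols = List("(a + 1)")            // output of buildProjector
      val buildCols = originalCols ::: projectedKeyCols // same ::: as in the patch
      println(buildCols.zipWithIndex)
      // List((col0,0), (col1,1), (col2,2), ((a + 1),3))
    }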
@@ -149,9 +168,31 @@ class ColumnarSortMergeJoin(
       last_cb = cb
       val beforeJoin = System.nanoTime()
-      val stream_rb: ArrowRecordBatch = ConverterUtils.createArrowRecordBatch(cb)
-      val output_rb = probe_iterator.process(stream_input_arrow_schema, stream_rb)
-
-      ConverterUtils.releaseArrowRecordBatch(stream_rb)
+      val output_rb = if (cb.numRows > 0) {
+        val projectedStreamKeyCols: List[ArrowWritableColumnVector] =
+          if (streamProjector != null) {
+            val streamOrdinalList = streamProjector.getOrdinalList
+            val streamAttributes = streamProjector.output
+            val streamProjectCols = streamOrdinalList.map(i => {
+              cb.column(i).asInstanceOf[ArrowWritableColumnVector]
+            })
+            streamProjector.evaluate(cb.numRows, streamProjectCols.map(_.getValueVector()))
+          } else {
+            List[ArrowWritableColumnVector]()
+          }
+        val streamCols = (0 until cb.numCols).toList.map(i =>
+          cb.column(i).asInstanceOf[ArrowWritableColumnVector]) ::: projectedStreamKeyCols
+        val stream_rb: ArrowRecordBatch =
+          ConverterUtils.createArrowRecordBatch(cb.numRows, streamCols.map(_.getValueVector))
+
+        val res = probe_iterator.process(stream_input_arrow_schema, stream_rb)
+
+        ConverterUtils.releaseArrowRecordBatch(stream_rb)
+        projectedStreamKeyCols.foreach(v => v.close)
+        res
+      } else {
+        null
+      }
       joinTime += NANOSECONDS.toMillis(System.nanoTime() - beforeJoin)
       if (output_rb == null) {
         val resultColumnVectors =
@@ -188,6 +229,8 @@ object ColumnarSortMergeJoin extends Logging {
   var output_arrow_schema: Schema = _
   var condition_probe_expr: ExpressionTree = _
   var prober: ExpressionEvaluator = _
+  var buildKeyProjectOrdinalList: List[Int] = _
+  var streamKeyProjectOrdinalList: List[Int] = _

   def init(
       leftKeys: Seq[Expression],
@@ -234,16 +277,16 @@ object ColumnarSortMergeJoin extends Logging {
     val lkeyFieldList: List[Field] = leftKeys.toList.zipWithIndex.map {
       case (expr, i) => {
         //TODO(): fix this workaround
-        if (expr.isInstanceOf[AttributeReference] && expr.asInstanceOf[AttributeReference].name == "none") {
-          return
-        }
+        //if (expr.isInstanceOf[AttributeReference] && expr.asInstanceOf[AttributeReference].name == "none") {
+        //  return
+        //}
         val (nativeNode, returnType) = ConverterUtils.getColumnarFuncNode(expr)
         if (s"${nativeNode.toProtobuf}".contains("none#")) {
           throw new UnsupportedOperationException(
             s"Unsupport to generate native expression from replaceable expression.")
         }
         if (s"${nativeNode.toProtobuf}".contains("fnNode")) {
-          throw new UnsupportedOperationException(s"join key with expression is not supported.")
+          //throw new UnsupportedOperationException(s"join key with expression is not supported.")
           lkeyProjectOrdinalList += i
           Field.nullable(s"${expr}", returnType)
         } else {
@@ -261,7 +304,7 @@ object ColumnarSortMergeJoin extends Logging {
       case (expr, i) => {
         val (nativeNode, returnType) = ConverterUtils.getColumnarFuncNode(expr)
         if (s"${nativeNode.toProtobuf}".contains("fnNode")) {
-          throw new UnsupportedOperationException(s"join key with expression is not supported.")
+          //throw new UnsupportedOperationException(s"join key with expression is not supported.")
           rkeyProjectOrdinalList += i
           Field.nullable(s"${expr}", returnType)
         } else {
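The detection above deserves a note: a bare column key serializes to a Gandiva field reference, while any computed key serializes with a function node, so checking the protobuf dump for "fnNode" flags exactly the keys that need pre-projection. A rough illustration (the dump strings are invented; only the contains check mirrors the patch):

    object FnNodeDetectionSketch extends App {
      // invented stand-ins for nativeNode.toProtobuf output
      val bareColumnDump = "fieldNode { field { name: \"b\" } }"
      val expressionDump = "fnNode { functionName: \"add\" }"
      def needsProjection(dump: String): Boolean = dump.contains("fnNode")
      println(needsProjection(bareColumnDump)) // false
      println(needsProjection(expressionDump)) // true
    }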
@@ -289,15 +332,19 @@ object ColumnarSortMergeJoin extends Logging {
       case _ =>
         BuildLeft
     }
-    val (
+    var (
       build_key_field_list,
       stream_key_field_list,
       build_input_field_list,
       stream_input_field_list) = buildSide match {
       case BuildLeft =>
+        buildKeyProjectOrdinalList = lkeyProjectOrdinalList.toList
+        streamKeyProjectOrdinalList = rkeyProjectOrdinalList.toList
         (lkeyFieldList, rkeyFieldList, l_input_field_list, r_input_field_list)

       case BuildRight =>
+        buildKeyProjectOrdinalList = rkeyProjectOrdinalList.toList
+        streamKeyProjectOrdinalList = lkeyProjectOrdinalList.toList
         (rkeyFieldList, lkeyFieldList, r_input_field_list, l_input_field_list)

     }
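A small self-contained sketch of the swap above (assuming, for illustration only, that left key 0 is an expression and all right keys are bare columns): whichever side becomes the build side contributes its ordinals as buildKeyProjectOrdinalList, and the other side's as streamKeyProjectOrdinalList.

    object BuildSideSwapSketch extends App {
      sealed trait BuildSide
      case object BuildLeft extends BuildSide
      case object BuildRight extends BuildSide

      val lkeyProjectOrdinalList = List(0)         // left key 0 needs projection
      val rkeyProjectOrdinalList = List.empty[Int] // right keys are bare columns

      def ordinals(side: BuildSide): (List[Int], List[Int]) = side match {
        case BuildLeft  => (lkeyProjectOrdinalList, rkeyProjectOrdinalList)
        case BuildRight => (rkeyProjectOrdinalList, lkeyProjectOrdinalList)
      }
      println(ordinals(BuildLeft)) // (List(0),List())
    }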
@@ -362,6 +409,18 @@ object ColumnarSortMergeJoin extends Logging {
         (build_input_field_list, stream_output_field_list ::: build_output_field_list)
       }
     }
+    // we need to add projectedKeyOutput into input_field_list here
+    if (buildKeyProjectOrdinalList.nonEmpty) {
+      build_input_field_list =
+        build_input_field_list ::: buildKeyProjectOrdinalList.map(i => build_key_field_list(i))
+    }
+
+    if (streamKeyProjectOrdinalList.nonEmpty) {
+      stream_input_field_list =
+        stream_input_field_list ::: streamKeyProjectOrdinalList.map(i => stream_key_field_list(i))
+    }
+    build_input_arrow_schema = new Schema(build_input_field_list.asJava)
+    stream_input_arrow_schema = new Schema(stream_input_field_list.asJava)

     val conditionArrowSchema = new Schema(conditionInputFieldList.asJava)
     output_arrow_schema = new Schema(conditionOutputFieldList.asJava)
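The same extension, sketched directly with Arrow's Java POJO API (field names invented): the projected key field, named after the expression as in Field.nullable(s"${expr}", returnType), is appended to the input field list before the schema reaches the native evaluator.

    import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema}
    import scala.collection.JavaConverters._

    object SchemaExtensionSketch extends App {
      val inputFields = List(
        Field.nullable("a", new ArrowType.Int(32, true)),
        Field.nullable("payload", new ArrowType.Utf8()))
      val projectedKeyFields = List(
        Field.nullable("(a + 1)", new ArrowType.Int(32, true)))
      val buildInputSchema = new Schema((inputFields ::: projectedKeyFields).asJava)
      println(buildInputSchema)
    }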
@@ -438,6 +497,7 @@ object ColumnarSortMergeJoin extends Logging {
       totaltime_sortmergejoin: SQLMetric,
       numOutputRows: SQLMetric,
       sparkConf: SparkConf): Unit = synchronized {
+    logInfo("precheck")
     init(
       leftKeys,
       rightKeys,
@@ -467,7 +527,7 @@ object ColumnarSortMergeJoin extends Logging {
       totaltime_sortmergejoin: SQLMetric,
       numOutputRows: SQLMetric,
       sparkConf: SparkConf): String = synchronized {
-
+    logInfo("prebuild")
     init(
       leftKeys,
       rightKeys,
@@ -522,14 +582,55 @@ object ColumnarSortMergeJoin extends Logging {
       numOutputRows,
       sparkConf)

+    val buildSide: BuildSide = joinType match {
+      case LeftSemi =>
+        BuildRight
+      case LeftOuter =>
+        BuildRight
+      case LeftAnti =>
+        BuildRight
+      case j: ExistenceJoin =>
+        BuildRight
+      case LeftExistence(_) =>
+        BuildRight
+      case _ =>
+        BuildLeft
+    }
+
+    val (buildProjector, streamProjector) =
+      // create gandiva project to pre-process
+      buildSide match {
+        case BuildLeft =>
+          (
+            (if (buildKeyProjectOrdinalList.nonEmpty)
+               ColumnarProjection
+                 .create(left.output, buildKeyProjectOrdinalList.map(i => leftKeys(i)))
+             else null),
+            (if (streamKeyProjectOrdinalList.nonEmpty)
+               ColumnarProjection
+                 .create(right.output, streamKeyProjectOrdinalList.map(i => rightKeys(i)))
+             else null))
+
+        case BuildRight =>
+          (
+            (if (buildKeyProjectOrdinalList.nonEmpty)
+               ColumnarProjection
+                 .create(right.output, buildKeyProjectOrdinalList.map(i => rightKeys(i)))
+             else null),
+            (if (streamKeyProjectOrdinalList.nonEmpty)
+               ColumnarProjection
+                 .create(left.output, streamKeyProjectOrdinalList.map(i => leftKeys(i)))
+             else null))
+      }
+
     prober = new ExpressionEvaluator(listJars.toList.asJava)
     prober.build(
       build_input_arrow_schema,
       Lists.newArrayList(condition_probe_expr),
       output_arrow_schema,
       true)

-    columnarSortMergeJoin = new ColumnarSortMergeJoin(
+    new ColumnarSortMergeJoin(
       prober,
       stream_input_arrow_schema,
       output_arrow_schema,
@@ -545,8 +646,12 @@ object ColumnarSortMergeJoin extends Logging {
       prepareTime,
       totaltime_sortmergejoin,
       numOutputRows,
-      sparkConf)
-    columnarSortMergeJoin
+      sparkConf,
+      buildProjector,
+      buildKeyProjectOrdinalList,
+      streamProjector,
+      streamKeyProjectOrdinalList)
+
   }

   def close(): Unit = {
@@ -677,5 +782,4 @@ object ColumnarSortMergeJoin extends Logging {
       condition_expression_node_list,
       new ArrowType.Int(32, true) /*dummy ret type, won't be used*/ )
   }
-
 }