diff --git a/core/src/main/scala/com/intel/oap/execution/ColumnarSortMergeJoinExec.scala b/core/src/main/scala/com/intel/oap/execution/ColumnarSortMergeJoinExec.scala index 3b20b4fe7..f1f34d33a 100644 --- a/core/src/main/scala/com/intel/oap/execution/ColumnarSortMergeJoinExec.scala +++ b/core/src/main/scala/com/intel/oap/execution/ColumnarSortMergeJoinExec.scala @@ -315,7 +315,7 @@ case class ColumnarSortMergeJoinExec( //do not call prebuild so we could skip the c++ codegen //val triggerBuildSignature = getCodeGenSignature - try { + /*try { ColumnarSortMergeJoin.precheck( leftKeys, rightKeys, @@ -332,7 +332,7 @@ case class ColumnarSortMergeJoinExec( } catch { case e: Throwable => throw e - } + }*/ /***********************************************************/ def getCodeGenSignature: String = diff --git a/core/src/main/scala/com/intel/oap/expression/ColumnarSortMergeJoin.scala b/core/src/main/scala/com/intel/oap/expression/ColumnarSortMergeJoin.scala index b01dd688b..8f4bdb1eb 100644 --- a/core/src/main/scala/com/intel/oap/expression/ColumnarSortMergeJoin.scala +++ b/core/src/main/scala/com/intel/oap/expression/ColumnarSortMergeJoin.scala @@ -73,7 +73,11 @@ class ColumnarSortMergeJoin( prepareTime: SQLMetric, totaltime_sortmergejoin: SQLMetric, totalOutputNumRows: SQLMetric, - sparkConf: SparkConf) + sparkConf: SparkConf, + buildProjector: ColumnarProjection, + buildKeyProjectOrdinalList: List[Int], + streamProjector: ColumnarProjection, + streamKeyProjectOrdinalList: List[Int]) extends Logging { ColumnarPluginConfig.getConf(sparkConf) var probe_iterator: BatchIterator = _ @@ -105,13 +109,28 @@ class ColumnarSortMergeJoin( } build_cb = realbuildIter.next() val beforeBuild = System.nanoTime() - val build_rb = ConverterUtils.createArrowRecordBatch(build_cb) + // handle projection + val projectedBuildKeyCols: List[ArrowWritableColumnVector] = if (buildProjector != null) { + val builderOrdinalList = buildProjector.getOrdinalList + val builderAttributes = buildProjector.output + val builderProjectCols = builderOrdinalList.map(i => { + build_cb.column(i).asInstanceOf[ArrowWritableColumnVector] + }) + buildProjector.evaluate(build_cb.numRows, builderProjectCols.map(_.getValueVector())) + } else { + List[ArrowWritableColumnVector]() + } + val buildCols = (0 until build_cb.numCols).toList.map(i => + build_cb.column(i).asInstanceOf[ArrowWritableColumnVector]) ::: projectedBuildKeyCols + val build_rb = + ConverterUtils.createArrowRecordBatch(build_cb.numRows, buildCols.map(_.getValueVector)) (0 until build_cb.numCols).toList.foreach(i => build_cb.column(i).asInstanceOf[ArrowWritableColumnVector].retain()) inputBatchHolder += build_cb prober.evaluate(build_rb) prepareTime += NANOSECONDS.toMillis(System.nanoTime() - beforeBuild) ConverterUtils.releaseArrowRecordBatch(build_rb) + projectedBuildKeyCols.foreach(v => v.close) } if (build_cb != null) { build_cb = null @@ -149,9 +168,31 @@ class ColumnarSortMergeJoin( last_cb = cb val beforeJoin = System.nanoTime() val stream_rb: ArrowRecordBatch = ConverterUtils.createArrowRecordBatch(cb) - val output_rb = probe_iterator.process(stream_input_arrow_schema, stream_rb) - - ConverterUtils.releaseArrowRecordBatch(stream_rb) + val output_rb = if (cb.numRows > 0) { + val projectedStreamKeyCols: List[ArrowWritableColumnVector] = + if (streamProjector != null) { + val streamOrdinalList = streamProjector.getOrdinalList + val streamAttributes = streamProjector.output + val streamProjectCols = streamOrdinalList.map(i => { + cb.column(i).asInstanceOf[ArrowWritableColumnVector] + }) + streamProjector.evaluate(cb.numRows, streamProjectCols.map(_.getValueVector())) + } else { + List[ArrowWritableColumnVector]() + } + val streamCols = (0 until cb.numCols).toList.map(i => + cb.column(i).asInstanceOf[ArrowWritableColumnVector]) ::: projectedStreamKeyCols + val stream_rb: ArrowRecordBatch = + ConverterUtils.createArrowRecordBatch(cb.numRows, streamCols.map(_.getValueVector)) + + val res = probe_iterator.process(stream_input_arrow_schema, stream_rb) + + ConverterUtils.releaseArrowRecordBatch(stream_rb) + projectedStreamKeyCols.foreach(v => v.close) + res + } else { + null + } joinTime += NANOSECONDS.toMillis(System.nanoTime() - beforeJoin) if (output_rb == null) { val resultColumnVectors = @@ -188,6 +229,8 @@ object ColumnarSortMergeJoin extends Logging { var output_arrow_schema: Schema = _ var condition_probe_expr: ExpressionTree = _ var prober: ExpressionEvaluator = _ + var buildKeyProjectOrdinalList: List[Int] = _ + var streamKeyProjectOrdinalList: List[Int] = _ def init( leftKeys: Seq[Expression], @@ -234,16 +277,16 @@ object ColumnarSortMergeJoin extends Logging { val lkeyFieldList: List[Field] = leftKeys.toList.zipWithIndex.map { case (expr, i) => { //TODO(): fix this workaround - if (expr.isInstanceOf[AttributeReference] && expr.asInstanceOf[AttributeReference].name == "none") { - return - } + //if (expr.isInstanceOf[AttributeReference] && expr.asInstanceOf[AttributeReference].name == "none") { + // return + //} val (nativeNode, returnType) = ConverterUtils.getColumnarFuncNode(expr) if (s"${nativeNode.toProtobuf}".contains("none#")) { throw new UnsupportedOperationException( s"Unsupport to generate native expression from replaceable expression.") } if (s"${nativeNode.toProtobuf}".contains("fnNode")) { - throw new UnsupportedOperationException(s"join key with expression is not supported.") + //throw new UnsupportedOperationException(s"join key with expression is not supported.") lkeyProjectOrdinalList += i Field.nullable(s"${expr}", returnType) } else { @@ -261,7 +304,7 @@ object ColumnarSortMergeJoin extends Logging { case (expr, i) => { val (nativeNode, returnType) = ConverterUtils.getColumnarFuncNode(expr) if (s"${nativeNode.toProtobuf}".contains("fnNode")) { - throw new UnsupportedOperationException(s"join key with expression is not supported.") + //throw new UnsupportedOperationException(s"join key with expression is not supported.") rkeyProjectOrdinalList += i Field.nullable(s"${expr}", returnType) } else { @@ -289,15 +332,19 @@ object ColumnarSortMergeJoin extends Logging { case _ => BuildLeft } - val ( + var ( build_key_field_list, stream_key_field_list, build_input_field_list, stream_input_field_list) = buildSide match { case BuildLeft => + buildKeyProjectOrdinalList = lkeyProjectOrdinalList.toList + streamKeyProjectOrdinalList = rkeyProjectOrdinalList.toList (lkeyFieldList, rkeyFieldList, l_input_field_list, r_input_field_list) case BuildRight => + buildKeyProjectOrdinalList = rkeyProjectOrdinalList.toList + streamKeyProjectOrdinalList = lkeyProjectOrdinalList.toList (rkeyFieldList, lkeyFieldList, r_input_field_list, l_input_field_list) } @@ -362,6 +409,18 @@ object ColumnarSortMergeJoin extends Logging { (build_input_field_list, stream_output_field_list ::: build_output_field_list) } } + // we need to add projectedKeyOutput into input_field_list here + if (buildKeyProjectOrdinalList.nonEmpty) { + build_input_field_list = + build_input_field_list ::: buildKeyProjectOrdinalList.map(i => build_key_field_list(i)) + } + + if (streamKeyProjectOrdinalList.nonEmpty) { + stream_input_field_list = + stream_input_field_list ::: streamKeyProjectOrdinalList.map(i => stream_key_field_list(i)) + } + build_input_arrow_schema = new Schema(build_input_field_list.asJava) + stream_input_arrow_schema = new Schema(stream_input_field_list.asJava) val conditionArrowSchema = new Schema(conditionInputFieldList.asJava) output_arrow_schema = new Schema(conditionOutputFieldList.asJava) @@ -438,6 +497,7 @@ object ColumnarSortMergeJoin extends Logging { totaltime_sortmergejoin: SQLMetric, numOutputRows: SQLMetric, sparkConf: SparkConf): Unit = synchronized { + logInfo("precheck") init( leftKeys, rightKeys, @@ -467,7 +527,7 @@ object ColumnarSortMergeJoin extends Logging { totaltime_sortmergejoin: SQLMetric, numOutputRows: SQLMetric, sparkConf: SparkConf): String = synchronized { - + logInfo("prebuild") init( leftKeys, rightKeys, @@ -522,6 +582,47 @@ object ColumnarSortMergeJoin extends Logging { numOutputRows, sparkConf) + val buildSide: BuildSide = joinType match { + case LeftSemi => + BuildRight + case LeftOuter => + BuildRight + case LeftAnti => + BuildRight + case j: ExistenceJoin => + BuildRight + case LeftExistence(_) => + BuildRight + case _ => + BuildLeft + } + + val (buildProjector, streamProjector) = + // create gandiva project to pre-process + buildSide match { + case BuildLeft => + ( + (if (buildKeyProjectOrdinalList.nonEmpty) + ColumnarProjection + .create(left.output, buildKeyProjectOrdinalList.map(i => leftKeys(i))) + else null), + (if (streamKeyProjectOrdinalList.nonEmpty) + ColumnarProjection + .create(right.output, streamKeyProjectOrdinalList.map(i => rightKeys(i))) + else null)) + + case BuildRight => + ( + (if (buildKeyProjectOrdinalList.nonEmpty) + ColumnarProjection + .create(right.output, buildKeyProjectOrdinalList.map(i => rightKeys(i))) + else null), + (if (streamKeyProjectOrdinalList.nonEmpty) + ColumnarProjection + .create(left.output, streamKeyProjectOrdinalList.map(i => leftKeys(i))) + else null)) + } + prober = new ExpressionEvaluator(listJars.toList.asJava) prober.build( build_input_arrow_schema, @@ -529,7 +630,7 @@ object ColumnarSortMergeJoin extends Logging { output_arrow_schema, true) - columnarSortMergeJoin = new ColumnarSortMergeJoin( + new ColumnarSortMergeJoin( prober, stream_input_arrow_schema, output_arrow_schema, @@ -545,8 +646,12 @@ object ColumnarSortMergeJoin extends Logging { prepareTime, totaltime_sortmergejoin, numOutputRows, - sparkConf) - columnarSortMergeJoin + sparkConf, + buildProjector, + buildKeyProjectOrdinalList, + streamProjector, + streamKeyProjectOrdinalList) + } def close(): Unit = { @@ -677,5 +782,4 @@ object ColumnarSortMergeJoin extends Logging { condition_expression_node_list, new ArrowType.Int(32, true) /*dummy ret type, won't be used*/ ) } - }