Commit 5348287
Add InputIteratorTransformer to decouple ReadRel and iterator index
ulysses-you committed Nov 28, 2023
1 parent e6dd56e commit 5348287
Showing 31 changed files with 376 additions and 590 deletions.
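
Before this change, an operator whose input was just an iterator received a null child context and had to allocate an iterator index and build its own ReadRel (via LocalFilesBuilder); after it, an InputIteratorTransformer node sits at the pipeline boundary and produces the ReadRel itself, so downstream operators can always consume childCtx.root. Below is a minimal, self-contained sketch of that before/after shape; the types used (RelNode, TransformContext, SubstraitContext, TransformSupport) are simplified hypothetical stand-ins, not the real Gluten classes.

// Illustrative sketch only: simplified stand-ins for Gluten's real types, meant to show
// how InputIteratorTransformer decouples downstream operators from iterator-index bookkeeping.
object DecouplingSketch {
  sealed trait RelNode
  final case class ReadRel(iteratorIndex: Long) extends RelNode
  final case class FilterRel(input: RelNode) extends RelNode

  final case class TransformContext(root: RelNode)

  final class SubstraitContext {
    private var iter = -1L
    def nextIteratorIndex: Long = { iter += 1; iter }
  }

  trait TransformSupport { def doTransform(ctx: SubstraitContext): TransformContext }

  // Before: an operator whose child was "just an iterator" got a null child and
  // had to allocate the iterator index and build the ReadRel itself.
  final class OldFilterTransformer(child: TransformSupport /* may be null */) extends TransformSupport {
    def doTransform(ctx: SubstraitContext): TransformContext = {
      val input =
        if (child != null) child.doTransform(ctx).root
        else ReadRel(ctx.nextIteratorIndex) // per-operator ReadRel special case
      TransformContext(FilterRel(input))
    }
  }

  // After: an InputIteratorTransformer owns the ReadRel at the stage boundary...
  final class InputIteratorTransformerSketch extends TransformSupport {
    def doTransform(ctx: SubstraitContext): TransformContext =
      TransformContext(ReadRel(ctx.nextIteratorIndex))
  }

  // ...so downstream operators always have a transformable child and just use childCtx.root.
  final class NewFilterTransformer(child: TransformSupport) extends TransformSupport {
    def doTransform(ctx: SubstraitContext): TransformContext =
      TransformContext(FilterRel(child.doTransform(ctx).root))
  }

  def main(args: Array[String]): Unit = {
    val plan = new NewFilterTransformer(new InputIteratorTransformerSketch)
    println(plan.doTransform(new SubstraitContext)) // TransformContext(FilterRel(ReadRel(0)))
  }
}
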
@@ -38,6 +38,21 @@ class CHMetricsApi extends MetricsApi with Logging with LogLevelUtil {
MetricsUtil.updateNativeMetrics(child, relMap, joinParamsMap, aggParamsMap)
}

override def genInputIteratorTransformerMetrics(
sparkContext: SparkContext): Map[String, SQLMetric] = {
Map(
"iterReadTime" -> SQLMetrics.createTimingMetric(
sparkContext,
"time of reading from iterator"),
"outputVectors" -> SQLMetrics.createMetric(sparkContext, "number of output vectors")
)
}

override def genInputIteratorTransformerMetricsUpdater(
metrics: Map[String, SQLMetric]): MetricsUpdater = {
InputIteratorMetricsUpdater(metrics)
}

override def genBatchScanTransformerMetrics(sparkContext: SparkContext): Map[String, SQLMetric] =
Map(
"inputRows" -> SQLMetrics.createMetric(sparkContext, "number of input rows"),
@@ -163,8 +178,6 @@ class CHMetricsApi extends MetricsApi with Logging with LogLevelUtil {
SQLMetrics.createTimingMetric(sparkContext, "time of aggregating"),
"postProjectTime" ->
SQLMetrics.createTimingMetric(sparkContext, "time of postProjection"),
"iterReadTime" ->
SQLMetrics.createTimingMetric(sparkContext, "time of reading from iterator"),
"totalTime" -> SQLMetrics.createTimingMetric(sparkContext, "total time")
)

@@ -312,12 +325,8 @@ class CHMetricsApi extends MetricsApi with Logging with LogLevelUtil {
"extraTime" -> SQLMetrics.createTimingMetric(sparkContext, "extra operators time"),
"inputWaitTime" -> SQLMetrics.createTimingMetric(sparkContext, "time of waiting for data"),
"outputWaitTime" -> SQLMetrics.createTimingMetric(sparkContext, "time of waiting for output"),
"streamIterReadTime" ->
SQLMetrics.createTimingMetric(sparkContext, "time of stream side read"),
"streamPreProjectionTime" ->
SQLMetrics.createTimingMetric(sparkContext, "time of stream side preProjection"),
"buildIterReadTime" ->
SQLMetrics.createTimingMetric(sparkContext, "time of build side read"),
"buildPreProjectionTime" ->
SQLMetrics.createTimingMetric(sparkContext, "time of build side preProjection"),
"postProjectTime" ->
@@ -18,15 +18,10 @@ package io.glutenproject.execution

import io.glutenproject.extension.ValidationResult
import io.glutenproject.substrait.SubstraitContext
import io.glutenproject.substrait.rel.RelBuilder

import org.apache.spark.sql.catalyst.expressions.{And, Attribute, Expression}
import org.apache.spark.sql.catalyst.expressions.{And, Expression}
import org.apache.spark.sql.execution.SparkPlan

import java.util

import scala.collection.JavaConverters._

case class CHFilterExecTransformer(condition: Expression, child: SparkPlan)
extends FilterExecTransformerBase(condition, child) {

@@ -46,13 +41,8 @@ case class CHFilterExecTransformer(condition: Expression, child: SparkPlan)
}

override def doTransform(context: SubstraitContext): TransformContext = {
val childCtx = child.asInstanceOf[TransformSupport].doTransform(context)
val leftCondition = getLeftCondition
val childCtx = child match {
case c: TransformSupport =>
c.doTransform(context)
case _ =>
throw new IllegalStateException(s"child ${child.nodeName} doesn't support transform.");
}

val operatorId = context.nextOperatorId(this.nodeName)
if (leftCondition == null) {
@@ -63,34 +53,15 @@ case class CHFilterExecTransformer(condition: Expression, child: SparkPlan)
TransformContext(childCtx.inputAttributes, output, childCtx.root)
}

val currRel = if (childCtx != null) {
getRelNode(
context,
leftCondition,
child.output,
operatorId,
childCtx.root,
validation = false)
} else {
// This means the input is just an iterator, so a ReadRel will be created as the child.
// Prepare the input schema.
val attrList = new util.ArrayList[Attribute](child.output.asJava)
getRelNode(
context,
leftCondition,
child.output,
operatorId,
RelBuilder.makeReadRel(attrList, context, operatorId),
validation = false)
}
val currRel = getRelNode(
context,
leftCondition,
child.output,
operatorId,
childCtx.root,
validation = false)
assert(currRel != null, "Filter rel should be valid.")
val inputAttributes = if (childCtx != null) {
// Use the outputAttributes of child context as inputAttributes.
childCtx.outputAttributes
} else {
child.output
}
TransformContext(inputAttributes, output, currRel)
TransformContext(childCtx.outputAttributes, output, currRel)
}

private def getLeftCondition: Expression = {
@@ -80,88 +80,77 @@ case class CHHashAggregateExecTransformer(
}

override def doTransform(context: SubstraitContext): TransformContext = {
val childCtx = child match {
case c: TransformSupport =>
c.doTransform(context)
case _ =>
null
}
val childCtx = child.asInstanceOf[TransformSupport].doTransform(context)

val aggParams = new AggregationParams
val operatorId = context.nextOperatorId(this.nodeName)

val (relNode, inputAttributes, outputAttributes) = if (childCtx != null) {
// The final HashAggregateExecTransformer and partial HashAggregateExecTransformer
// are in the one WholeStageTransformer.
if (modes.isEmpty || !modes.contains(Partial)) {
(
getAggRel(context, operatorId, aggParams, childCtx.root),
childCtx.outputAttributes,
output)
} else {
(
getAggRel(context, operatorId, aggParams, childCtx.root),
childCtx.outputAttributes,
aggregateResultAttributes)
}
} else {
// This means the input is just an iterator, so a ReadRel will be created as the child.
// Prepare the input schema.
// Notes: Currently, ClickHouse backend uses the output attributes of
// aggregateResultAttributes as Shuffle output,
// which is different from Velox backend.
aggParams.isReadRel = true
val typeList = new util.ArrayList[TypeNode]()
val nameList = new util.ArrayList[String]()
val (inputAttrs, outputAttrs) = {
if (modes.isEmpty) {
// When there is no aggregate function, it does not need
// to handle outputs according to the AggregateMode
for (attr <- child.output) {
typeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
nameList.add(ConverterUtils.genColumnNameWithExprId(attr))
nameList.addAll(ConverterUtils.collectStructFieldNames(attr.dataType))
}
(child.output, output)
} else if (!modes.contains(Partial)) {
// non-partial mode
var resultAttrIndex = 0
for (attr <- aggregateResultAttributes) {
val colName = getIntermediateAggregateResultColumnName(
resultAttrIndex,
aggregateResultAttributes,
groupingExpressions,
aggregateExpressions)
nameList.add(colName)
val (dataType, nullable) =
getIntermediateAggregateResultType(attr, aggregateExpressions)
nameList.addAll(ConverterUtils.collectStructFieldNames(dataType))
typeList.add(ConverterUtils.getTypeNode(dataType, nullable))
resultAttrIndex += 1
}
(aggregateResultAttributes, output)
val (relNode, inputAttributes, outputAttributes) =
if (!child.isInstanceOf[InputIteratorTransformer]) {
// The final HashAggregateExecTransformer and partial HashAggregateExecTransformer
// are in the one WholeStageTransformer.
if (modes.isEmpty || !modes.contains(Partial)) {
(
getAggRel(context, operatorId, aggParams, childCtx.root),
childCtx.outputAttributes,
output)
} else {
// partial mode
for (attr <- child.output) {
typeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
nameList.add(ConverterUtils.genColumnNameWithExprId(attr))
nameList.addAll(ConverterUtils.collectStructFieldNames(attr.dataType))
}

(child.output, aggregateResultAttributes)
(
getAggRel(context, operatorId, aggParams, childCtx.root),
childCtx.outputAttributes,
aggregateResultAttributes)
}
}
} else {
// This means the input is just an iterator, so a ReadRel will be created as the child.
// Prepare the input schema.
// Notes: Currently, ClickHouse backend uses the output attributes of
// aggregateResultAttributes as Shuffle output,
// which is different from Velox backend.
val typeList = new util.ArrayList[TypeNode]()
val nameList = new util.ArrayList[String]()
val (inputAttrs, outputAttrs) = {
if (modes.isEmpty) {
// When there is no aggregate function, it does not need
// to handle outputs according to the AggregateMode
for (attr <- child.output) {
typeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
nameList.add(ConverterUtils.genColumnNameWithExprId(attr))
nameList.addAll(ConverterUtils.collectStructFieldNames(attr.dataType))
}
(child.output, output)
} else if (!modes.contains(Partial)) {
// non-partial mode
var resultAttrIndex = 0
for (attr <- aggregateResultAttributes) {
val colName = getIntermediateAggregateResultColumnName(
resultAttrIndex,
aggregateResultAttributes,
groupingExpressions,
aggregateExpressions)
nameList.add(colName)
val (dataType, nullable) =
getIntermediateAggregateResultType(attr, aggregateExpressions)
nameList.addAll(ConverterUtils.collectStructFieldNames(dataType))
typeList.add(ConverterUtils.getTypeNode(dataType, nullable))
resultAttrIndex += 1
}
(aggregateResultAttributes, output)
} else {
// partial mode
for (attr <- child.output) {
typeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
nameList.add(ConverterUtils.genColumnNameWithExprId(attr))
nameList.addAll(ConverterUtils.collectStructFieldNames(attr.dataType))
}

// The iterator index will be added in the path of LocalFiles.
val iteratorIndex: Long = context.nextIteratorIndex
val inputIter = LocalFilesBuilder.makeLocalFiles(
ConverterUtils.ITERATOR_PREFIX.concat(iteratorIndex.toString))
context.setIteratorNode(iteratorIndex, inputIter)
val readRel =
RelBuilder.makeReadRel(typeList, nameList, null, iteratorIndex, context, operatorId)
(child.output, aggregateResultAttributes)
}
}

(getAggRel(context, operatorId, aggParams, readRel), inputAttrs, outputAttrs)
}
val readRel =
RelBuilder.makeReadRelForInputIterator(typeList, nameList, context, operatorId)
(getAggRel(context, operatorId, aggParams, readRel), inputAttrs, outputAttrs)
}
TransformContext(inputAttributes, outputAttributes, relNode)
}

@@ -34,15 +34,6 @@ class HashAggregateMetricsUpdater(val metrics: Map[String, SQLMetric])
var currentIdx = operatorMetrics.metricsList.size() - 1
var totalTime = 0L

// read rel
if (aggregationParams.isReadRel) {
metrics("iterReadTime") +=
(operatorMetrics.metricsList.get(currentIdx).time / 1000L).toLong
metrics("outputVectors") += operatorMetrics.metricsList.get(currentIdx).outputVectors
totalTime += operatorMetrics.metricsList.get(currentIdx).time
currentIdx -= 1
}

// pre projection
if (aggregationParams.preProjectionNeeded) {
metrics("preProjectTime") +=
@@ -40,15 +40,6 @@ class HashJoinMetricsUpdater(val metrics: Map[String, SQLMetric])
currentIdx -= 1
}

// build side read rel
if (joinParams.isBuildReadRel) {
val buildSideRealRel = operatorMetrics.metricsList.get(currentIdx)
metrics("buildIterReadTime") += (buildSideRealRel.time / 1000L).toLong
metrics("outputVectors") += buildSideRealRel.outputVectors
totalTime += buildSideRealRel.time
currentIdx -= 1
}

// stream side pre projection
if (joinParams.streamPreProjectionNeeded) {
metrics("streamPreProjectionTime") +=
@@ -58,25 +49,15 @@ class HashJoinMetricsUpdater(val metrics: Map[String, SQLMetric])
currentIdx -= 1
}

// stream side read rel
if (joinParams.isStreamedReadRel) {
metrics("streamIterReadTime") +=
(operatorMetrics.metricsList.get(currentIdx).time / 1000L).toLong
metrics("outputVectors") += operatorMetrics.metricsList.get(currentIdx).outputVectors
totalTime += operatorMetrics.metricsList.get(currentIdx).time

// update fillingRightJoinSideTime
MetricsUtil
.getAllProcessorList(operatorMetrics.metricsList.get(currentIdx))
.foreach(
processor => {
if (processor.name.equalsIgnoreCase("FillingRightJoinSide")) {
metrics("fillingRightJoinSideTime") += (processor.time / 1000L).toLong
}
})

currentIdx -= 1
}
// update fillingRightJoinSideTime
MetricsUtil
.getAllProcessorList(operatorMetrics.metricsList.get(currentIdx))
.foreach(
processor => {
if (processor.name.equalsIgnoreCase("FillingRightJoinSide")) {
metrics("fillingRightJoinSideTime") += (processor.time / 1000L).toLong
}
})

// joining
val joinMetricsData = operatorMetrics.metricsList.get(currentIdx)
@@ -0,0 +1,32 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.glutenproject.metrics

import org.apache.spark.sql.execution.metric.SQLMetric

case class InputIteratorMetricsUpdater(metrics: Map[String, SQLMetric]) extends MetricsUpdater {
override def updateNativeMetrics(opMetrics: IOperatorMetrics): Unit = {
if (opMetrics != null) {
val operatorMetrics = opMetrics.asInstanceOf[OperatorMetrics]
if (!operatorMetrics.metricsList.isEmpty) {
val metricsData = operatorMetrics.metricsList.get(0)
metrics("iterReadTime") += (metricsData.time / 1000L).toLong
metrics("outputVectors") += metricsData.outputVectors
}
}
}
}
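
For context, a small self-contained sketch of the contract between the genInputIteratorTransformerMetrics map added in CHMetricsApi above and this updater: the keys created there ("iterReadTime", "outputVectors") are the keys the updater folds native values into, with the same time / 1000L scaling. LongMetric and NativeMetric below are hypothetical stand-ins for Spark's SQLMetric and Gluten's native operator metrics, used only to keep the sketch runnable.

// Sketch only: hypothetical stand-ins for SQLMetric and the native operator metrics,
// showing the key contract between genInputIteratorTransformerMetrics and this updater.
object InputIteratorMetricsSketch {
  final class LongMetric {
    private var v = 0L
    def +=(x: Long): Unit = v += x
    def value: Long = v
  }
  final case class NativeMetric(time: Long, outputVectors: Long)

  def main(args: Array[String]): Unit = {
    // Mirrors the map built by genInputIteratorTransformerMetrics (the real code creates SQLMetrics).
    val metrics = Map(
      "iterReadTime" -> new LongMetric,
      "outputVectors" -> new LongMetric
    )
    // Mirrors InputIteratorMetricsUpdater.updateNativeMetrics for a single native entry.
    val native = NativeMetric(time = 2500L, outputVectors = 4L)
    metrics("iterReadTime") += native.time / 1000L // same scaling as the updater above
    metrics("outputVectors") += native.outputVectors
    println(s"iterReadTime=${metrics("iterReadTime").value}, outputVectors=${metrics("outputVectors").value}")
  }
}
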
@@ -20,7 +20,7 @@ import io.glutenproject.expression.{ConverterUtils, ExpressionConverter}
import io.glutenproject.substrait.SubstraitContext
import io.glutenproject.substrait.expression.ExpressionNode
import io.glutenproject.substrait.plan.{PlanBuilder, PlanNode}
import io.glutenproject.substrait.rel.{LocalFilesBuilder, RelBuilder}
import io.glutenproject.substrait.rel.RelBuilder

import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BoundReference, Expression}

@@ -31,17 +31,10 @@ object PlanNodesUtil {
def genProjectionsPlanNode(key: Expression, output: Seq[Attribute]): PlanNode = {
val context = new SubstraitContext

// input
val iteratorIndex: Long = context.nextIteratorIndex
var operatorId = context.nextOperatorId("ClickHouseBuildSideRelationReadIter")
val inputIter = LocalFilesBuilder.makeLocalFiles(
ConverterUtils.ITERATOR_PREFIX.concat(iteratorIndex.toString))
context.setIteratorNode(iteratorIndex, inputIter)

val typeList = ConverterUtils.collectAttributeTypeNodes(output)
val nameList = ConverterUtils.collectAttributeNamesWithExprId(output)
val readRel =
RelBuilder.makeReadRel(typeList, nameList, null, iteratorIndex, context, operatorId)
val readRel = RelBuilder.makeReadRelForInputIterator(typeList, nameList, context, operatorId)

// replace attribute with BoundReference according to the output
val newBoundRefKey = key.transformDown {