diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index ab575e90c927a..cf0e300150ea6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.datasources.parquet.{DefaultSource => Parq
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.{BaseRelation, HadoopFsRelation}
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types.{AtomicType, DataType}
 
 object RDDConversions {
   def productToRowRdd[A <: Product](data: RDD[A], outputTypes: Seq[DataType]): RDD[InternalRow] = {
@@ -179,7 +179,7 @@ private[sql] case class DataSourceScan(
     case r: HadoopFsRelation if r.fileFormat.isInstanceOf[ParquetSource] &&
       SQLContext.getActive().get.conf.getConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED) &&
       SQLContext.getActive().get.conf.getConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED) =>
-      true
+      output.forall(_.dataType.isInstanceOf[AtomicType])
     case _ => false
   }
 
@@ -232,55 +232,29 @@ private[sql] case class DataSourceScan(
   // Support codegen so that we can avoid the UnsafeRow conversion in all cases. Codegen
   // never requires UnsafeRow as input.
   override protected def doProduce(ctx: CodegenContext): String = {
-    val columnarBatchClz = "org.apache.spark.sql.execution.vectorized.ColumnarBatch"
-    val columnVectorClz = "org.apache.spark.sql.execution.vectorized.ColumnVector"
     val input = ctx.freshName("input")
-    val idx = ctx.freshName("batchIdx")
-    val rowidx = ctx.freshName("rowIdx")
-    val batch = ctx.freshName("batch")
-    // PhysicalRDD always just has one input
-    ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];")
-    ctx.addMutableState(columnarBatchClz, batch, s"$batch = null;")
-    ctx.addMutableState("int", idx, s"$idx = 0;")
-    val colVars = output.indices.map(i => ctx.freshName("colInstance" + i))
-    val columnAssigns = colVars.zipWithIndex.map { case (name, i) =>
-      ctx.addMutableState(columnVectorClz, name, s"$name = null;")
-      s"$name = ${batch}.column($i);" }
-
-    val row = ctx.freshName("row")
     val numOutputRows = metricTerm(ctx, "numOutputRows")
-    // The input RDD can either return (all) ColumnarBatches or InternalRows. We determine this
-    // by looking at the first value of the RDD and then calling the function which will process
-    // the remaining. It is faster to return batches.
-    // TODO: The abstractions between this class and SqlNewHadoopRDD makes it difficult to know
-    // here which path to use. Fix this.
+    // PhysicalRDD always just has one input
+    ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];")
 
-    val exprRows =
-      output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, x._1.nullable))
-    ctx.INPUT_ROW = row
-    ctx.currentVars = null
-    val columnsRowInput = exprRows.map(_.gen(ctx))
-    val inputRow = if (outputUnsafeRows) row else null
-    val scanRows = ctx.freshName("processRows")
-    ctx.addNewFunction(scanRows,
-      s"""
-       | private void $scanRows(InternalRow $row) throws java.io.IOException {
-       |   boolean firstRow = true;
-       |   while (!shouldStop() && (firstRow || $input.hasNext())) {
-       |     if (firstRow) {
-       |       firstRow = false;
-       |     } else {
-       |       $row = (InternalRow) $input.next();
-       |     }
-       |     $numOutputRows.add(1);
-       |     ${consume(ctx, columnsRowInput, inputRow).trim}
-       |   }
-       | }""".stripMargin)
-    // Timers for how long we spent inside the scan. We can only maintain this when using batches,
-    // otherwise the overhead is too high.
+    // The input RDD can either return (all) ColumnarBatches or InternalRows. We can tell based
+    // on the file format and schema.
     if (canProcessBatches()) {
+      val columnarBatchClz = "org.apache.spark.sql.execution.vectorized.ColumnarBatch"
+      val columnVectorClz = "org.apache.spark.sql.execution.vectorized.ColumnVector"
+      val idx = ctx.freshName("batchIdx")
+      val batch = ctx.freshName("batch")
+      val rowidx = ctx.freshName("rowIdx")
+      ctx.addMutableState(columnarBatchClz, batch, s"$batch = null;")
+      ctx.addMutableState("int", idx, s"$idx = 0;")
+      val colVars = output.indices.map(i => ctx.freshName("colInstance" + i))
+      val columnAssigns = colVars.zipWithIndex.map { case (name, i) =>
+        ctx.addMutableState(columnVectorClz, name, s"$name = null;")
+        s"$name = ${batch}.column($i);" }
+      // Timers for how long we spent inside the scan. We can only maintain this when using batches,
+      // otherwise the overhead is too high.
       val scanTimeMetric = metricTerm(ctx, "scanTime")
       val getBatchStart = ctx.freshName("scanStart")
       val scanTimeTotalNs = ctx.freshName("scanTime")
@@ -318,25 +292,32 @@ private[sql] case class DataSourceScan(
        |   }
        | }""".stripMargin)
 
-      val value = ctx.freshName("value")
       s"""
        | if ($batch != null) {
        |   $scanBatches();
-       | } else if ($input.hasNext()) {
-       |   Object $value = $input.next();
-       |   if ($value instanceof $columnarBatchClz) {
-       |     $batch = ($columnarBatchClz)$value;
+       | } else {
+       |   long $getBatchStart = System.nanoTime();
+       |   if ($input.hasNext()) {
+       |     $batch = ($columnarBatchClz)$input.next();
+       |     $scanTimeTotalNs += System.nanoTime() - $getBatchStart;
        |     $scanBatches();
-       |   } else {
-       |     $scanRows((InternalRow) $value);
       |   }
        | }
      """.stripMargin
     } else {
+      val row = ctx.freshName("row")
+      val exprRows =
+        output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, x._1.nullable))
+      ctx.INPUT_ROW = row
+      ctx.currentVars = null
+      val columnsRowInput = exprRows.map(_.gen(ctx))
+      val inputRow = if (outputUnsafeRows) row else null
       s"""
-       |if ($input.hasNext()) {
-       |  $scanRows((InternalRow) $input.next());
-       |}
+       | while (!shouldStop() && $input.hasNext()) {
+       |   InternalRow $row = (InternalRow) $input.next();
+       |   $numOutputRows.add(1);
+       |   ${consume(ctx, columnsRowInput, inputRow).trim}
+       | }
      """.stripMargin
     }
   }
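
For context: the gist of the patch is that DataSourceScan now picks between the ColumnarBatch path and the InternalRow path at code-generation time, gated on canProcessBatches(), instead of inspecting the first value returned by the RDD at runtime, and canProcessBatches() only returns true when every output attribute has an AtomicType. The sketch below is a hedged, standalone illustration of that gate, not Spark's code: CanProcessBatchesSketch and isAtomicLike are hypothetical names, and the AtomicType test is approximated by rejecting complex (nested) types rather than referencing AtomicType directly.

```scala
// Standalone sketch (assumed names, spark-sql on the classpath). It mirrors the idea of the
// patched canProcessBatches(): the vectorized Parquet batch path is only eligible when both
// config flags are enabled and the whole output schema is flat/atomic.
import org.apache.spark.sql.types._

object CanProcessBatchesSketch {
  // Rough stand-in for `dataType.isInstanceOf[AtomicType]`: treat complex container types
  // (arrays, maps, structs) as non-atomic, since they cannot be returned in a ColumnarBatch.
  private def isAtomicLike(dt: DataType): Boolean = dt match {
    case _: ArrayType | _: MapType | _: StructType => false
    case _ => true
  }

  // Hypothetical mirror of the patched condition (the real code reads the two SQLConf flags
  // PARQUET_VECTORIZED_READER_ENABLED and WHOLESTAGE_CODEGEN_ENABLED from the active SQLContext).
  def canProcessBatches(
      schema: StructType,
      vectorizedParquetReader: Boolean,
      wholeStageCodegen: Boolean): Boolean = {
    vectorizedParquetReader && wholeStageCodegen &&
      schema.fields.forall(f => isAtomicLike(f.dataType))
  }

  def main(args: Array[String]): Unit = {
    val flat = StructType(Seq(StructField("id", LongType), StructField("name", StringType)))
    val nested = flat.add("tags", ArrayType(StringType))

    // Flat schema of atomic types: the batched (ColumnarBatch) codegen path applies.
    println(canProcessBatches(flat, vectorizedParquetReader = true, wholeStageCodegen = true))   // true
    // ArrayType is not atomic: the scan falls back to the row-at-a-time while loop.
    println(canProcessBatches(nested, vectorizedParquetReader = true, wholeStageCodegen = true)) // false
  }
}
```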