Optimized UTF8String conversion for row tables and CollectLimit optimization in query routing

- use direct conversion methods from the store side to get UTF8String instead of
  copying the data twice
- also use direct primitive long calls for Date/Timestamp to get millis and micros
  respectively (see the conversion sketch below)
- don't use execute() for plans that override and optimize SparkPlan.executeCollect()
  in query routing (e.g. CollectLimitExec); such plans are expected to return results
  small enough not to cause trouble on the lead node (see the routing sketch below)
Sumedh Wale committed Oct 12, 2016
1 parent 1452f81 commit cf6976f
Showing 7 changed files with 141 additions and 73 deletions.
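
The Date/Timestamp part of the change drops the java.sql.Date/Timestamp round-trip:
the store now hands back primitive longs (millis for DATE, micros for TIMESTAMP), and
only the DATE value needs a further conversion to Catalyst's days-since-epoch Int. A
minimal worked example of that millis-to-days step is sketched below; it assumes only
a Spark 2.x catalyst jar on the classpath (for DateTimeUtils.MILLIS_PER_DAY), and
DateMillisSketch plus the sample values are illustrative, not part of the commit.

import java.util.TimeZone

import org.apache.spark.sql.catalyst.util.DateTimeUtils

object DateMillisSketch {

  // Same formula as Utils.millisToDays in this diff: shift UTC millis into
  // the target timezone, then floor-divide by one day so that dates before
  // 1970 (negative millis) still map to the correct day (SPARK-6785).
  def millisToDays(millisUtc: Long, tz: TimeZone): Int = {
    val millisLocal = millisUtc + tz.getOffset(millisUtc)
    Math.floor(millisLocal.toDouble / DateTimeUtils.MILLIS_PER_DAY).toInt
  }

  def main(args: Array[String]): Unit = {
    val utc = TimeZone.getTimeZone("UTC")
    println(millisToDays(0L, utc))         // 0  -> 1970-01-01
    println(millisToDays(86400000L, utc))  // 1  -> 1970-01-02
    println(millisToDays(-1L, utc))        // -1 -> 1969-12-31
  }
}

The TIMESTAMP case needs no such step: the micros value returned by
getAsTimestampMicros is already in Catalyst's representation and is written into the
mutable row unchanged.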
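
The query-routing part boils down to one predicate plus one branch. The compile-only
sketch below mirrors Utils.useExecuteCollect and the branch added to
SparkSQLExecuteImpl.packRows, using Spark 2.x classes only; QueryRoutingSketch and
routeQuery are made-up names, and the SnappyData serialization handlers from the diff
appear only in comments.

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.execution.command.ExecutedCommandExec
import org.apache.spark.sql.execution.{CollectLimitExec, LocalTableScanExec,
  SparkPlan, TakeOrderedAndProjectExec}

object QueryRoutingSketch {

  // Mirrors Utils.useExecuteCollect in this diff: these plans override
  // executeCollect() with a bounded, more efficient path (CollectLimitExec,
  // for example, applies the limit on individual partitions first).
  def useExecuteCollect(plan: SparkPlan): Boolean = plan match {
    case _: CollectLimitExec | _: ExecutedCommandExec |
         _: LocalTableScanExec | _: TakeOrderedAndProjectExec => true
    case _ => false
  }

  // Mirrors the branch added to SparkSQLExecuteImpl.packRows().
  def routeQuery(df: DataFrame): Unit = {
    val executedPlan = df.queryExecution.executedPlan
    if (useExecuteCollect(executedPlan)) {
      // Small, bounded result: collect directly on the lead node and
      // serialize it (InternalRowHandler.serializeRows in the diff).
      val rows = executedPlan.executeCollect()
      println(s"collected ${rows.length} rows via executeCollect()")
    } else {
      // General case: run a distributed job over the plan's RDD and let
      // ExecutionHandler serialize each partition where it is computed.
      val resultsRdd = executedPlan.execute()
      println(s"running a job over ${resultsRdd.partitions.length} partitions")
    }
  }
}

In the commit itself the executeCollect branch is additionally wrapped in
Utils.withNewExecutionId(df, ...) so that the jobs it spawns are tied to a SQL
execution id.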
@@ -45,7 +45,6 @@ import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, SnappyContext}
import org.apache.spark.storage.{RDDBlockId, StorageLevel}
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.SnappyUtils
import org.apache.spark.{Logging, SparkContext, SparkEnv}

@@ -77,8 +76,6 @@ class SparkSQLExecuteImpl(val sql: String,

private[this] val querySchema = df.schema

private[this] val resultsRdd = df.queryExecution.toRdd

private[this] lazy val colTypes = getColumnTypes

// check for query hint to serialize complex types as CLOBs
@@ -94,18 +91,48 @@
case None => (false, Array.empty[String])
}

private def handleLocalExecution(srh: SnappyResultHolder): Unit = {
// prepare SnappyResultHolder with all data and create new one
if (hdos.size > 0) {
val rawData = hdos.toByteArrayCopy
srh.fromSerializedData(rawData, rawData.length, null)
}
}

override def packRows(msg: LeadNodeExecutorMsg,
snappyResultHolder: SnappyResultHolder): Unit = {

var srh = snappyResultHolder
val isLocalExecution = msg.isLocallyExecuted
val bm = SparkEnv.get.blockManager
val partitionBlockIds = new Array[RDDBlockId](resultsRdd.partitions.length)
val serializeComplexType = !complexTypeAsClob && querySchema.exists(
_.dataType match {
case _: ArrayType | _: MapType | _: StructType => true
case _ => false
})
// for plans that override SparkPlan.executeCollect(), use the normal
// execution because those have much more efficient paths (e.g.
// limit will apply limit on individual partitions etc)
val executedPlan = df.queryExecution.executedPlan
if (Utils.useExecuteCollect(executedPlan)) {
val result = Utils.withNewExecutionId(df, {
val handler = new InternalRowHandler(sql, querySchema,
serializeComplexType, colTypes)
val rows = executedPlan.executeCollect()
handler.serializeRows(rows.iterator)
})
hdos.clearForReuse()
writeMetaData()
hdos.write(result)
if (isLocalExecution) {
handleLocalExecution(srh)
}
msg.lastResult(srh)
return
}

val resultsRdd = executedPlan.execute()
val bm = SparkEnv.get.blockManager
val partitionBlockIds = new Array[RDDBlockId](resultsRdd.partitions.length)
val handler = new ExecutionHandler(sql, querySchema, resultsRdd.id,
partitionBlockIds, serializeComplexType, colTypes)
var blockReadSuccess = false
@@ -159,11 +186,7 @@ class SparkSQLExecuteImpl(val sql: String,
writeMetaData()
}
if (isLocalExecution) {
// prepare SnappyResultHolder with all data and create new one
if (hdos.size > 0) {
val rawData = hdos.toByteArrayCopy
srh.fromSerializedData(rawData, rawData.length, null)
}
handleLocalExecution(srh)
}
msg.lastResult(srh)

@@ -495,19 +518,11 @@ object SparkSQLExecuteImpl {
}
}

class ExecutionHandler(sql: String, schema: StructType, rddId: Int,
partitionBlockIds: Array[RDDBlockId],
serializeComplexType: Boolean, rowStoreColTypes: Array[(Int, Int, Int)] = null) extends Serializable {

def apply(resultsRdd: RDD[InternalRow], df: DataFrame): Unit = {
Utils.withNewExecutionId(df.sparkSession, df.queryExecution) {
val sc = SnappyContext.globalSparkContext
sc.runJob(resultsRdd, rowIter _, resultHandler _)
}
}

private[snappydata] def rowIter(itr: Iterator[InternalRow]): Array[Byte] = {
class InternalRowHandler(sql: String, schema: StructType,
serializeComplexType: Boolean,
rowStoreColTypes: Array[(Int, Int, Int)] = null) extends Serializable {

final def serializeRows(itr: Iterator[InternalRow]): Array[Byte] = {
var numCols = -1
var numEightColGroups = -1
var numPartCols = -1
@@ -553,6 +568,19 @@ class ExecutionHandler(sql: String, schema: StructType, rddId: Int,
}
dos.toByteArray
}
}

final class ExecutionHandler(sql: String, schema: StructType, rddId: Int,
partitionBlockIds: Array[RDDBlockId], serializeComplexType: Boolean,
rowStoreColTypes: Array[(Int, Int, Int)] = null)
extends InternalRowHandler(sql, schema, serializeComplexType, rowStoreColTypes) {

def apply(resultsRdd: RDD[InternalRow], df: DataFrame): Unit = {
Utils.withNewExecutionId(df, {
val sc = SnappyContext.globalSparkContext
sc.runJob(resultsRdd, serializeRows _, resultHandler _)
})
}

private[snappydata] def resultHandler(partitionId: Int,
block: Array[Byte]): Unit = {
31 changes: 27 additions & 4 deletions core/src/main/scala/org/apache/spark/sql/collection/Utils.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.collection
import java.io.ObjectOutputStream
import java.nio.ByteBuffer
import java.sql.DriverManager
import java.util.TimeZone

import scala.annotation.tailrec
import scala.collection.{mutable, Map => SMap}
@@ -38,8 +39,9 @@ import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, MapData}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.execution.command.ExecutedCommandExec
import org.apache.spark.sql.execution.{CollectLimitExec, LocalTableScanExec, SparkPlan, TakeOrderedAndProjectExec}
import org.apache.spark.sql.execution.datasources.jdbc.{DriverRegistry, DriverWrapper}
import org.apache.spark.sql.execution.{QueryExecution, SQLExecution}
import org.apache.spark.sql.hive.SnappyStoreHiveCatalog
import org.apache.spark.sql.sources.CastLongTime
import org.apache.spark.sql.types._
@@ -553,9 +555,22 @@ object Utils {
driver
}

def withNewExecutionId[T](session: SparkSession,
queryExecution: QueryExecution)(body: => T): T = {
SQLExecution.withNewExecutionId(session, queryExecution)(body)
/**
* Wrap a DataFrame action to track all Spark jobs in the body so that
* we can connect them with an execution.
*/
def withNewExecutionId[T](df: DataFrame, body: => T): T = {
df.withNewExecutionId(body)
}

/**
* Return true if the plan overrides executeCollect to provide a more
* efficient version which should be preferred over execute().
*/
def useExecuteCollect(plan: SparkPlan): Boolean = plan match {
case _: CollectLimitExec | _: ExecutedCommandExec |
_: LocalTableScanExec | _: TakeOrderedAndProjectExec => true
case _ => false
}

def immutableMap[A, B](m: mutable.Map[A, B]): Map[A, B] = new Map[A, B] {
@@ -597,6 +612,14 @@ object Utils {
def createCatalystConverter(dataType: DataType): Any => Any =
CatalystTypeConverters.createToCatalystConverter(dataType)

// we should use the exact day as Int, for example, (year, month, day) -> day
def millisToDays(millisUtc: Long, tz: TimeZone): Int = {
// SPARK-6785: use Math.floor so negative number of days (dates before 1970)
// will correctly work as input for function toJavaDate(Int)
val millisLocal = millisUtc + tz.getOffset(millisUtc)
Math.floor(millisLocal.toDouble / DateTimeUtils.MILLIS_PER_DAY).toInt
}

def getGenericRowValues(row: GenericRow): Array[Any] = row.values

def newChunkedByteBuffer(chunks: Array[ByteBuffer]): ChunkedByteBuffer =
@@ -26,7 +26,7 @@ import com.pivotal.gemfirexd.internal.engine.store.{AbstractCompactExecRow, Resu

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.collection.Utils
import org.apache.spark.sql.store.StoreUtils
import org.apache.spark.sql.types.{DataType, Decimal, DecimalType, StructType}
import org.apache.spark.unsafe.Platform
@@ -42,6 +42,12 @@ abstract class CompactExecRowToMutableRow extends ResultNullHolder {
protected final val fieldTypes = StoreUtils.mapCatalystTypes(
schema, dataTypes)

final lazy val defaultCal = new GregorianCalendar(
ClientSharedData.DEFAULT_TIMEZONE, ClientSharedData.DEFAULT_LOCALE)

final lazy val defaultTZ =
ClientSharedData.DEFAULT_TIMEZONE.clone().asInstanceOf[java.util.TimeZone]

protected final def createInternalRow(execRow: AbstractCompactExecRow,
mutableRow: SpecificMutableRow): InternalRow = {
var i = 0
@@ -151,21 +157,19 @@ abstract class CompactExecRowToMutableRow extends ResultNullHolder {
case StoreUtils.DATE_TYPE =>
val cal = this.defaultCal
cal.clear()
// TODO: can avoid creating Date object, rather get long millis
val v = execRow.getAsDate(pos, cal, this)
if (v != null) {
mutableRow.setInt(i, DateTimeUtils.fromJavaDate(v))
val millis = execRow.getAsDateMillis(pos, cal, this)
if (!wasNull) {
mutableRow.setInt(i, Utils.millisToDays(millis, defaultTZ))
} else {
mutableRow.setNullAt(i)
wasNull = false
}
case StoreUtils.TIMESTAMP_TYPE =>
val cal = this.defaultCal
cal.clear()
// TODO: can avoid creating Timestamp object, rather get long nanos
val v = execRow.getAsTimestamp(pos, cal, this)
if (v != null) {
mutableRow.setLong(i, DateTimeUtils.fromJavaTimestamp(v))
val micros = execRow.getAsTimestampMicros(pos, cal, this)
if (!wasNull) {
mutableRow.setLong(i, micros)
} else {
mutableRow.setNullAt(i)
wasNull = false
@@ -222,9 +226,6 @@ class ResultNullHolder extends ResultWasNull {

final var wasNull: Boolean = _

final lazy val defaultCal = new GregorianCalendar(
ClientSharedData.DEFAULT_TIMEZONE, ClientSharedData.DEFAULT_LOCALE)

override final def setWasNull(): Unit = {
wasNull = true
}
@@ -235,3 +236,12 @@ class ResultNullHolder extends ResultWasNull {
result
}
}

final class ResultSetNullHolder extends ResultNullHolder {

val defaultCal = new GregorianCalendar(
ClientSharedData.DEFAULT_TIMEZONE, ClientSharedData.DEFAULT_LOCALE)

val defaultTZ =
ClientSharedData.DEFAULT_TIMEZONE.clone().asInstanceOf[java.util.TimeZone]
}
@@ -76,7 +76,7 @@ private[sql] final case class RowTableScan(
val numRows = ctx.freshName("numRows")
val row = ctx.freshName("row")
val holder = ctx.freshName("nullHolder")
val holderClass = classOf[ResultNullHolder].getName
val holderClass = classOf[ResultSetNullHolder].getName
val compactRowClass = classOf[AbstractCompactExecRow].getName
val baseSchema = baseRelation.schema
val columnsRowInput = output.map(a => genCodeCompactRowColumn(ctx,
@@ -139,15 +139,13 @@ private[sql] final case class RowTableScan(
val javaType = ctx.javaType(dataType)
val col = ctx.freshName("col")
val pos = ordinal + 1
var useHolder = true
val code = dataType match {
case IntegerType =>
s"final $javaType $col = $rowVar.getAsInt($pos, $holder);"
case StringType =>
// TODO: SW: optimize to store same full UTF8 format in GemXD
s"""
final $javaType $col = UTF8String.fromString($rowVar.getAsString(
$pos, $holder));
"""
useHolder = false
s"final $javaType $col = $rowVar.getAsUTF8String($ordinal);"
case LongType =>
s"final $javaType $col = $rowVar.getAsLong($pos, $holder);"
case BooleanType =>
@@ -161,44 +159,42 @@ private[sql] final case class RowTableScan(
case DoubleType =>
s"final $javaType $col = $rowVar.getAsDouble($pos, $holder);"
case d: DecimalType =>
useHolder = false
val decVar = ctx.freshName("dec")
s"""
final java.math.BigDecimal $decVar = $rowVar.getAsBigDecimal(
$pos, $holder);
$pos, null);
final $javaType $col = $decVar != null ? Decimal.apply($decVar,
${d.precision}, ${d.scale}) : null;
"""
case DateType =>
// TODO: optimize to avoid Date object and instead get millis
val cal = ctx.freshName("cal")
val date = ctx.freshName("date")
val dateMs = ctx.freshName("dateMillis")
val calClass = classOf[GregorianCalendar].getName
s"""
final $calClass $cal = $holder.defaultCal();
$cal.clear();
final java.sql.Date $date = $rowVar.getAsDate($pos, $cal, $holder);
final $javaType $col = $date != null ? org.apache.spark.sql
.catalyst.util.DateTimeUtils.fromJavaDate($date) : 0;
final long $dateMs = $rowVar.getAsDateMillis($ordinal, $cal, $holder);
final $javaType $col = org.apache.spark.sql.collection
.Utils.millisToDays($dateMs, $holder.defaultTZ());
"""
case TimestampType =>
// TODO: optimize to avoid object and instead get nanoseconds
val cal = ctx.freshName("cal")
val tsVar = ctx.freshName("ts")
val calClass = classOf[GregorianCalendar].getName
s"""
final $calClass $cal = $holder.defaultCal();
$cal.clear();
final java.sql.Timestamp $tsVar = $rowVar.getAsTimestamp($pos,
$cal, $holder);
final $javaType $col = $tsVar != null ? org.apache.spark.sql
.catalyst.util.DateTimeUtils.fromJavaTimestamp($tsVar) : 0L;
final $javaType $col = $rowVar.getAsTimestampMicros(
$ordinal, $cal, $holder);
"""
case BinaryType =>
s"final $javaType $col = $rowVar.getAsBytes($pos, $holder);"
useHolder = false
s"final $javaType $col = $rowVar.getAsBytes($pos, null);"
case _: ArrayType =>
useHolder = false
val bytes = ctx.freshName("bytes")
s"""
final byte[] $bytes = $rowVar.getAsBytes($pos, $holder);
final byte[] $bytes = $rowVar.getAsBytes($pos, null);
final $javaType $col;
if ($bytes != null) {
$col = new UnsafeArrayData();
@@ -208,9 +204,10 @@ private[sql] final case class RowTableScan(
}
"""
case _: MapType =>
useHolder = false
val bytes = ctx.freshName("bytes")
s"""
final byte[] $bytes = $rowVar.getAsBytes($pos, $holder);
final byte[] $bytes = $rowVar.getAsBytes($pos, null);
final $javaType $col;
if ($bytes != null) {
$col = new UnsafeMapData();
@@ -220,9 +217,10 @@ private[sql] final case class RowTableScan(
}
"""
case s: StructType =>
useHolder = false
val bytes = ctx.freshName("bytes")
s"""
final byte[] $bytes = $rowVar.getAsBytes($pos, $holder);
final byte[] $bytes = $rowVar.getAsBytes($pos, null);
final $javaType $col;
if ($bytes != null) {
$col = new UnsafeRow(${s.length});
@@ -232,14 +230,18 @@ private[sql] final case class RowTableScan(
}
"""
case _ =>
s"""
$javaType $col = ($javaType)$rowVar.getAsObject($pos, $holder);
"""
useHolder = false
s"$javaType $col = ($javaType)$rowVar.getAsObject($pos, null);"
}
if (nullable) {
val isNullVar = ctx.freshName("isNull")
ExprCode(s"$code\nfinal boolean $isNullVar = $holder.wasNullAndClear();",
isNullVar, col)
if (useHolder) {
ExprCode(s"$code\nfinal boolean $isNullVar = $holder.wasNullAndClear();",
isNullVar, col)
} else {
ExprCode(s"$code\nfinal boolean $isNullVar = $col == null;",
isNullVar, col)
}
} else {
ExprCode(code, "false", col)
}