[SNAP-1136] Pooled version of Kryo serializer which works for closures #426

Merged · 11 commits · Nov 28, 2016
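The changes below wire a pooled Kryo serializer into the cluster's test, lead and executor bootstrap paths (via `Utils.setDefaultSerializerAndCodec`), bump the Kryo dependency to 4.0.0, and make the query-routing closures (`InternalRowHandler`/`ExecutionHandler`) explicitly Kryo-serializable. The `PooledKryoSerializer` implementation itself is not part of this diff; the sketch below only illustrates, under assumed names, what pooling Kryo instances looks like with the Kryo 4 `KryoPool` API that the version bump makes available.

```scala
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.io.{Input, Output}
import com.esotericsoftware.kryo.pool.{KryoFactory, KryoPool}

// Hypothetical sketch: the PR's actual PooledKryoSerializer is not shown in this diff.
object KryoPoolSketch {

  private val factory = new KryoFactory {
    override def create(): Kryo = {
      val kryo = new Kryo()
      // allow classes (including closures) that were not explicitly registered
      kryo.setRegistrationRequired(false)
      kryo
    }
  }

  // soft references let idle Kryo instances be reclaimed under memory pressure
  private val pool: KryoPool = new KryoPool.Builder(factory).softReferences().build()

  def serialize(obj: AnyRef): Array[Byte] = {
    val kryo = pool.borrow()
    try {
      val output = new Output(4096, -1) // buffer grows as needed
      kryo.writeClassAndObject(output, obj)
      output.toBytes
    } finally pool.release(kryo)
  }

  def deserialize(bytes: Array[Byte]): AnyRef = {
    val kryo = pool.borrow()
    try kryo.readClassAndObject(new Input(bytes))
    finally pool.release(kryo)
  }
}
```

Pooling avoids paying Kryo's per-instance setup cost on every serialization while keeping each instance confined to one thread for as long as it is borrowed.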
2 changes: 2 additions & 0 deletions build.gradle
@@ -104,6 +104,7 @@ allprojects {
scalatestVersion = '2.2.6'
jettyVersion = '9.2.16.v20160414'
guavaVersion = '14.0.1'
kryoVersion = '4.0.0'
derbyVersion = '10.12.1.1'
pegdownVersion = '1.6.0'
snappyStoreVersion = '1.5.1'
@@ -279,6 +280,7 @@ subprojects {

include '**/*.class'
exclude '**/*DUnitTest.class'
exclude '**/*DUnitSingleTest.class'
exclude '**/*TestBase.class'

workingDir = "${testResultsBase}/junit"
@@ -17,7 +17,7 @@
package io.snappydata.cluster

import java.io.File
import java.sql.{DriverManager, Connection}
import java.sql.{Connection, DriverManager}
import java.util.Properties

import scala.collection.JavaConverters._
@@ -31,6 +31,7 @@ import io.snappydata.{Locator, Server, ServiceManager}
import org.slf4j.LoggerFactory

import org.apache.spark.sql.SnappyContext
import org.apache.spark.sql.collection.Utils
import org.apache.spark.{SparkConf, SparkContext}

/**
@@ -173,7 +174,7 @@ class ClusterManagerTestBase(s: String) extends DistributedTestBase(s) {
def getANetConnection(netPort: Int,
useGemXDURL: Boolean = false): Connection = {
val driver = "com.pivotal.gemfirexd.jdbc.ClientDriver"
Class.forName(driver).newInstance
Utils.classForName(driver).newInstance
var url: String = null
if (useGemXDURL) {
url = "jdbc:gemfirexd://localhost:" + netPort + "/"
@@ -222,6 +223,7 @@ object ClusterManagerTestBase {
conf.set("spark.sql.inMemoryColumnarStorage.batchSize", "3")
// conf.set("spark.executor.memory", "2g")
// conf.set("spark.shuffle.manager", "SORT")
Utils.setDefaultSerializerAndCodec(conf)

props.asScala.foreach({ case (k, v) =>
if (k.indexOf(".") < 0) {
@@ -273,6 +275,8 @@ object ClusterManagerTestBase {
cleanupTestData(null, null)
val sparkContext = SnappyContext.globalSparkContext
if (sparkContext != null) sparkContext.stop()
// clear system properties set explicitly
Utils.clearDefaultSerializerAndCodec()
}

def stopNetworkServers(): Unit = {
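The new `Utils.setDefaultSerializerAndCodec(conf)` / `Utils.clearDefaultSerializerAndCodec()` calls in this test base (and, further down, in `ExecutorInitiator` and `LeadImpl`) are how every `SparkConf` in the cluster is switched to the pooled serializer. Their bodies live in `org.apache.spark.sql.collection.Utils` and are not shown in this diff; the following is only a hedged guess at their shape, assuming they set `spark.serializer` and a compression codec both on the conf and as system properties so that confs created later inherit the same defaults. The serializer class name and codec used here are assumptions.

```scala
import org.apache.spark.SparkConf

// Hypothetical sketch only; the real helpers are in org.apache.spark.sql.collection.Utils.
object SerializerDefaultsSketch {
  // assumed fully-qualified name of the PR's pooled serializer
  private val pooledSerializer = "org.apache.spark.serializer.PooledKryoSerializer"

  def setDefaultSerializerAndCodec(conf: SparkConf): Unit = {
    conf.set("spark.serializer", pooledSerializer)
    conf.set("spark.io.compression.codec", "lz4")
    // publish as system properties too, so SparkConf instances created later
    // (e.g. executorConf/driverConf in ExecutorInitiator) pick up the same defaults
    System.setProperty("spark.serializer", pooledSerializer)
    System.setProperty("spark.io.compression.codec", "lz4")
  }

  // matches the teardown comment in the diff: "clear system properties set explicitly"
  def clearDefaultSerializerAndCodec(): Unit = {
    System.clearProperty("spark.serializer")
    System.clearProperty("spark.io.compression.codec")
  }
}
```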
@@ -35,6 +35,7 @@ import io.snappydata.gemxd.ClusterCallbacksImpl
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.executor.SnappyCoarseGrainedExecutorBackend
import org.apache.spark.sql.SnappyContext
import org.apache.spark.sql.collection.Utils
import org.apache.spark.{Logging, SparkCallbacks, SparkConf, SparkEnv}

/**
@@ -147,12 +148,14 @@ object ExecutorInitiator extends Logging {

// Fetch the driver's Spark properties.
val executorConf = new SparkConf
Utils.setDefaultSerializerAndCodec(executorConf)

val port = executorConf.getInt("spark.executor.port", 0)
val props = SparkCallbacks.fetchDriverProperty(executorHost,
executorConf, port, url)

val driverConf = new SparkConf()
val driverConf = new SparkConf
Utils.setDefaultSerializerAndCodec(driverConf)
// Specify a default directory for executor, if the local directory for executor
// is set via the executor conf,
// it will override this property later in the code
@@ -173,8 +176,8 @@ object ExecutorInitiator extends Logging {
// TODO: conf to this conf that was received from driver.

// If memory manager is not set, use Snappy unified memory manager
driverConf.set("spark.memory.manager",
driverConf.get("spark.memory.manager", SNAPPY_MEMORY_MANAGER))
driverConf.setIfMissing("spark.memory.manager",
SNAPPY_MEMORY_MANAGER)

val cores = driverConf.getInt("spark.executor.cores",
Runtime.getRuntime.availableProcessors() * 2)
@@ -242,6 +245,7 @@ object ExecutorInitiator extends Logging {
executorRunnable.stopTask = true
}
executorRunnable.setDriverDetails(None, null)
Utils.clearDefaultSerializerAndCodec()
}

def restartExecutor(): Unit = {
@@ -256,7 +260,7 @@ object ExecutorInitiator extends Logging {
// Avoid creation of executor inside the Gem accessor
// that is a Spark driver but has joined the gem system
// in the non embedded mode
if (SparkCallbacks.isDriver()) {
if (SparkCallbacks.isDriver) {
logInfo("Executor cannot be instantiated in this " +
"VM as a Spark driver is already running. ")
return
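One small cleanup in the `ExecutorInitiator` hunk above: `driverConf.set("spark.memory.manager", driverConf.get("spark.memory.manager", SNAPPY_MEMORY_MANAGER))` becomes `driverConf.setIfMissing(...)`, which states the intent directly and skips the redundant write when the key is already set. A minimal illustration using the standard `SparkConf` API; the manager class name below is a stand-in for the `SNAPPY_MEMORY_MANAGER` constant defined elsewhere in `ExecutorInitiator`.

```scala
import org.apache.spark.SparkConf

object SetIfMissingExample {
  // stand-in for ExecutorInitiator.SNAPPY_MEMORY_MANAGER, whose value is not shown in this diff
  val SNAPPY_MEMORY_MANAGER = "org.apache.spark.memory.SnappyUnifiedMemoryManager"

  def main(args: Array[String]): Unit = {
    val driverConf = new SparkConf(loadDefaults = false)

    // before: read with a default, then write the result back unconditionally
    driverConf.set("spark.memory.manager",
      driverConf.get("spark.memory.manager", SNAPPY_MEMORY_MANAGER))

    // after: write only when the key is absent; an existing value is left untouched
    driverConf.setIfMissing("spark.memory.manager", SNAPPY_MEMORY_MANAGER)

    println(driverConf.get("spark.memory.manager"))
  }
}
```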
@@ -20,6 +20,8 @@ import java.io.DataOutput
import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets

import com.esotericsoftware.kryo.io.{Input, Output}
import com.esotericsoftware.kryo.{Kryo, KryoSerializable}
import com.gemstone.gemfire.DataSerializer
import com.gemstone.gemfire.internal.shared.Version
import com.gemstone.gemfire.internal.{ByteArrayDataInput, InternalDataSerializer}
@@ -46,7 +48,7 @@ import org.apache.spark.sql.{DataFrame, SnappyContext}
import org.apache.spark.storage.{RDDBlockId, StorageLevel}
import org.apache.spark.unsafe.Platform
import org.apache.spark.util.SnappyUtils
import org.apache.spark.{Logging, SparkContext, SparkEnv}
import org.apache.spark.{Logging, SparkContext, SparkEnv, TaskContext}

/**
* Encapsulates a Spark execution for use in query routing from JDBC.
@@ -91,9 +93,10 @@ class SparkSQLExecuteImpl(val sql: String,
case None => (false, Array.empty[String])
}

private def handleLocalExecution(srh: SnappyResultHolder): Unit = {
private def handleLocalExecution(srh: SnappyResultHolder,
size: Int): Unit = {
// prepare SnappyResultHolder with all data and create new one
if (hdos.size > 0) {
if (size > 0) {
val rawData = hdos.toByteArrayCopy
srh.fromSerializedData(rawData, rawData.length, null)
}
@@ -118,13 +121,13 @@ class SparkSQLExecuteImpl(val sql: String,
val handler = new InternalRowHandler(sql, querySchema,
serializeComplexType, colTypes)
val rows = executedPlan.executeCollect()
handler.serializeRows(rows.iterator)
handler(null, rows.iterator)
})
hdos.clearForReuse()
writeMetaData()
hdos.write(result)
if (isLocalExecution) {
handleLocalExecution(srh)
handleLocalExecution(srh, hdos.size)
}
msg.lastResult(srh)
return
@@ -133,6 +136,7 @@ class SparkSQLExecuteImpl(val sql: String,
val resultsRdd = executedPlan.execute()
val bm = SparkEnv.get.blockManager
val partitionBlockIds = new Array[RDDBlockId](resultsRdd.partitions.length)

val handler = new ExecutionHandler(sql, querySchema, resultsRdd.id,
partitionBlockIds, serializeComplexType, colTypes)
var blockReadSuccess = false
@@ -162,10 +166,7 @@ class SparkSQLExecuteImpl(val sql: String,
if (dosSize > GemFireXDUtils.DML_MAX_CHUNK_SIZE) {
if (isLocalExecution) {
// prepare SnappyResultHolder with all data and create new one
if (dosSize > 0) {
val rawData = hdos.toByteArrayCopy
srh.fromSerializedData(rawData, rawData.length, null)
}
handleLocalExecution(srh, dosSize)
msg.sendResult(srh)
srh = new SnappyResultHolder(this)
} else {
@@ -186,7 +187,7 @@ class SparkSQLExecuteImpl(val sql: String,
writeMetaData()
}
if (isLocalExecution) {
handleLocalExecution(srh)
handleLocalExecution(srh, hdos.size)
}
msg.lastResult(srh)

@@ -518,11 +519,15 @@ object SparkSQLExecuteImpl {
}
}

class InternalRowHandler(sql: String, schema: StructType,
serializeComplexType: Boolean,
rowStoreColTypes: Array[(Int, Int, Int)] = null) extends Serializable {
class InternalRowHandler(private var sql: String,
private var schema: StructType,
private var serializeComplexType: Boolean,
private var rowStoreColTypes: Array[(Int, Int, Int)] = null)
extends ((TaskContext, Iterator[InternalRow]) => Array[Byte])
with Serializable with KryoSerializable {

final def serializeRows(itr: Iterator[InternalRow]): Array[Byte] = {
override def apply(context: TaskContext,
itr: Iterator[InternalRow]): Array[Byte] = {
var numCols = -1
var numEightColGroups = -1
var numPartCols = -1
@@ -568,17 +573,62 @@ class InternalRowHandler(sql: String, schema: StructType,
}
dos.toByteArray
}

override def write(kryo: Kryo, output: Output): Unit = {
output.writeString(sql)
kryo.writeObject(output, schema)
output.writeBoolean(serializeComplexType)
val colTypes = rowStoreColTypes
if (colTypes != null) {
val len = colTypes.length
output.writeVarInt(len, true)
var i = 0
while (i < len) {
val colType = colTypes(i)
output.writeVarInt(colType._1, false)
output.writeVarInt(colType._2, false)
output.writeVarInt(colType._3, false)
i += 1
}
} else {
output.writeVarInt(0, true)
}
}

override def read(kryo: Kryo, input: Input): Unit = {
sql = input.readString()
schema = kryo.readObject[StructType](input, classOf[StructType])
serializeComplexType = input.readBoolean()
val len = input.readVarInt(true)
if (len > 0) {
val colTypes = new Array[(Int, Int, Int)](len)
var i = 0
while (i < len) {
val colType1 = input.readVarInt(false)
val colType2 = input.readVarInt(false)
val colType3 = input.readVarInt(false)
colTypes(i) = (colType1, colType2, colType3)
i += 1
}
rowStoreColTypes = colTypes
} else {
rowStoreColTypes = null
}
}
}

final class ExecutionHandler(sql: String, schema: StructType, rddId: Int,
partitionBlockIds: Array[RDDBlockId], serializeComplexType: Boolean,
rowStoreColTypes: Array[(Int, Int, Int)] = null)
extends InternalRowHandler(sql, schema, serializeComplexType, rowStoreColTypes) {
final class ExecutionHandler(_sql: String, _schema: StructType, rddId: Int,
partitionBlockIds: Array[RDDBlockId], _serializeComplexType: Boolean,
_rowStoreColTypes: Array[(Int, Int, Int)])
extends InternalRowHandler(_sql, _schema, _serializeComplexType,
_rowStoreColTypes) with Serializable with KryoSerializable {

def this() = this(null, null, 0, null, false, null)

def apply(resultsRdd: RDD[InternalRow], df: DataFrame): Unit = {
Utils.withNewExecutionId(df, {
val sc = SnappyContext.globalSparkContext
sc.runJob(resultsRdd, serializeRows _, resultHandler _)
sc.runJob(resultsRdd, this, resultHandler _)
})
}

@@ -592,6 +642,9 @@ final class ExecutionHandler(sql: String, schema: StructType, rddId: Int,
partitionBlockIds(partitionId) = blockId
}
}

override def toString(): String =
s"ExecutionHandler: Iterator[InternalRow] => Array[Byte]"
}

object SnappyContextPerConnection {
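The main change in the file above turns `InternalRowHandler` from a plain `Serializable` helper with a `serializeRows` method into a `(TaskContext, Iterator[InternalRow]) => Array[Byte]` function that also implements Kryo's `KryoSerializable`. That lets the instance be passed straight to `SparkContext.runJob` (which expects exactly that function shape) and lets the pooled Kryo serializer ship it with explicit `write`/`read` methods instead of Java serialization. The toy class below shows the same pattern end to end; it is a sketch, not code from this PR, and it assumes a Kryo instance configured like the pooled one (registration not required, a zero-arg constructor available for `read`).

```scala
import com.esotericsoftware.kryo.{Kryo, KryoSerializable}
import com.esotericsoftware.kryo.io.{Input, Output}

// Toy function-object following the InternalRowHandler pattern: mutable fields
// plus explicit Kryo write/read, and a zero-arg constructor for Kryo to call.
class RowCounter(private var label: String)
    extends (Iterator[String] => Int) with Serializable with KryoSerializable {

  def this() = this(null) // used by Kryo to instantiate before read()

  override def apply(rows: Iterator[String]): Int = rows.count(_ != null)

  override def write(kryo: Kryo, output: Output): Unit =
    output.writeString(label)

  override def read(kryo: Kryo, input: Input): Unit =
    label = input.readString()
}

object RowCounterRoundTrip {
  def main(args: Array[String]): Unit = {
    val kryo = new Kryo()
    kryo.setRegistrationRequired(false)

    // serialize the function object, then restore it and invoke it
    val out = new Output(1024, -1)
    kryo.writeClassAndObject(out, new RowCounter("demo"))

    val restored = kryo.readClassAndObject(new Input(out.toBytes))
      .asInstanceOf[RowCounter]
    println(restored(Iterator("a", null, "b"))) // prints 2
  }
}
```

The `def this() = this(null, null, 0, null, false, null)` added to `ExecutionHandler` serves the same purpose as the zero-arg constructor here: Kryo instantiates the object first and then populates its fields through `read`.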
4 changes: 3 additions & 1 deletion cluster/src/main/scala/io/snappydata/impl/LeadImpl.scala
@@ -111,6 +111,7 @@ class LeadImpl extends ServerImpl with Lead with Logging {
setAppName("leaderLauncher").
set(Property.JobserverEnabled(), "true").
set("spark.scheduler.mode", "FAIR")
Utils.setDefaultSerializerAndCodec(conf)

// inspect user input and add appropriate prefixes
// if property doesn't contain '.'
@@ -240,12 +241,13 @@ class LeadImpl extends ServerImpl with Lead with Logging {
SnappyContext.flushSampleTables()
}

assert(sparkContext != null, "Mix and match of LeadService api " +
assert(sparkContext != null, "Mix and match of LeadService api " +
"and SparkContext is unsupported.")
if (!sparkContext.isStopped) {
sparkContext.stop()
sparkContext = null
}
Utils.clearDefaultSerializerAndCodec()

if (null != remoteInterpreterServerObj) {
val method: Method = remoteInterpreterServerClass.getMethod("isAlive")
7 changes: 4 additions & 3 deletions cluster/src/main/scala/org/apache/spark/SparkCallbacks.scala
@@ -23,8 +23,9 @@ import org.apache.spark.rpc.RpcEnv
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RetrieveSparkProps

/**
* Calls that are needed to be sent to snappy-cluster classes because the variables are private[spark]
*/
* Calls that are needed to be sent to snappy-cluster classes because
* the variables are private[spark]
*/
object SparkCallbacks {

def createExecutorEnv(
@@ -70,7 +71,7 @@ object SparkCallbacks {
SparkConf.isExecutorStartupConf(key)
}

def isDriver() : Boolean = {
def isDriver: Boolean = {
SparkEnv.get != null &&
SparkEnv.get.executorId == SparkContext.DRIVER_IDENTIFIER
}
@@ -22,13 +22,13 @@ import com.gemstone.gemfire.distributed.internal.membership.InternalDistributedM
import com.pivotal.gemfirexd.internal.engine.distributed.utils.GemFireXDUtils

import org.apache.spark.SparkContext
import org.apache.spark.internal.Logging
import org.apache.spark.rpc.{RpcEndpointAddress, RpcEnv}
import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd, SparkListenerBlockManagerAdded, SparkListenerBlockManagerRemoved, SparkListenerExecutorAdded, SparkListenerExecutorRemoved, TaskSchedulerImpl}
import org.apache.spark.sql.{BlockAndExecutorId, SnappyContext}

class SnappyCoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, override val rpcEnv: RpcEnv)
extends CoarseGrainedSchedulerBackend(scheduler, rpcEnv) with Logging {
class SnappyCoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl,
override val rpcEnv: RpcEnv)
extends CoarseGrainedSchedulerBackend(scheduler, rpcEnv) {

private val snappyAppId = "snappy-app-" + System.currentTimeMillis

@@ -92,7 +92,7 @@ class SnappyCoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, override
}

class BlockManagerIdListener(sc: SparkContext)
extends SparkListener with Logging {
extends SparkListener {

override def onExecutorAdded(
msg: SparkListenerExecutorAdded): Unit = synchronized {